137 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
块处理器基类模块
提供所有块处理器的基类和注册功能
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Type, Callable
import json
import logging
# 获取日志记录器
logger = logging.getLogger(__name__)
# 块处理器注册表
_block_handlers = {}
class BlockHandler(ABC):
"""
块处理器抽象基类
所有具体的块处理器都应该继承这个类
"""
@abstractmethod
async def execute(
self,
block: Dict[str, Any],
input_params: Dict[str, Any],
context: Any # TaskContext类型避免循环导入
) -> Dict[str, Any]:
"""
执行块
Args:
block: 块定义
input_params: 解析后的输入参数
context: 任务上下文
Returns:
Dict[str, Any]: 执行结果必须包含success字段
"""
pass
async def _record_task_log(
self,
block: Dict[str, Any],
result: Dict[str, Any],
context: Any, # TaskContext类型避免循环导入
log_type: str = "block_execution",
parent_log_id: str = None,
iteration_index: int = None
) -> str:
"""
记录任务日志
Args:
block: 块定义
result: 执行结果
context: 任务上下文
log_type: 日志类型 (iteration_start, iteration_end, block_execution, branch_execution)
parent_log_id: 父日志ID (用于建立层级关系)
iteration_index: 迭代索引 (循环中的第几次)
Returns:
str: 日志ID
"""
from sqlalchemy import insert
from data.models.tasklog import VWEDTaskLog
from data.session import get_async_session
import uuid
from datetime import datetime
try:
# 创建任务日志记录
task_log_id = str(uuid.uuid4())
# 自定义JSON序列化器处理datetime对象和bytes对象
def json_serializer(obj):
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, bytes):
return obj.decode('utf-8', errors='backslashreplace')
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
# 使用传入的参数,如果没有则从context获取
final_parent_log_id = parent_log_id if parent_log_id is not None else context.parent_log_id
final_iteration_index = iteration_index if iteration_index is not None else context.current_iteration_index
async with get_async_session() as session:
stmt = insert(VWEDTaskLog).values(
id=task_log_id,
level=1 if result.get("success", False) else 3, # 1: 信息, 3: 错误
message=json.dumps(result, ensure_ascii=False, default=json_serializer),
task_block_id=context.current_block_name or block.get("name", "unknown"), # 使用实际保存的block_name
task_id=context.task_def_id,
task_record_id=context.task_record_id,
parent_log_id=final_parent_log_id,
iteration_index=final_iteration_index,
block_record_id=context.block_record_id,
log_type=log_type
)
await session.execute(stmt)
await session.commit()
return task_log_id
except Exception as e:
logger.error(f"记录任务日志失败: {str(e)}")
return None
# 注册装饰器
def register_handler(block_type: str):
"""
注册块处理器的装饰器
Args:
block_type: 块类型
"""
def decorator(cls):
_block_handlers[block_type] = cls()
return cls
return decorator
# 获取块处理器
def get_block_handler(block_type: str) -> Optional[BlockHandler]:
"""
获取块处理器
Args:
block_type: 块类型
Returns:
Optional[BlockHandler]: 对应的块处理器如果不存在则返回None
"""
return _block_handlers.get(block_type)