137 lines
4.3 KiB
Python
137 lines
4.3 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
块处理器基类模块
|
||
提供所有块处理器的基类和注册功能
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, List, Any, Optional, Type, Callable
|
||
import json
|
||
import logging
|
||
|
||
# 获取日志记录器
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 块处理器注册表
|
||
_block_handlers = {}
|
||
|
||
class BlockHandler(ABC):
|
||
"""
|
||
块处理器抽象基类
|
||
所有具体的块处理器都应该继承这个类
|
||
"""
|
||
|
||
@abstractmethod
|
||
async def execute(
|
||
self,
|
||
block: Dict[str, Any],
|
||
input_params: Dict[str, Any],
|
||
context: Any # TaskContext类型,避免循环导入
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
执行块
|
||
|
||
Args:
|
||
block: 块定义
|
||
input_params: 解析后的输入参数
|
||
context: 任务上下文
|
||
|
||
Returns:
|
||
Dict[str, Any]: 执行结果,必须包含success字段
|
||
"""
|
||
pass
|
||
|
||
async def _record_task_log(
|
||
self,
|
||
block: Dict[str, Any],
|
||
result: Dict[str, Any],
|
||
context: Any, # TaskContext类型,避免循环导入
|
||
log_type: str = "block_execution",
|
||
parent_log_id: str = None,
|
||
iteration_index: int = None
|
||
) -> str:
|
||
"""
|
||
记录任务日志
|
||
|
||
Args:
|
||
block: 块定义
|
||
result: 执行结果
|
||
context: 任务上下文
|
||
log_type: 日志类型 (iteration_start, iteration_end, block_execution, branch_execution)
|
||
parent_log_id: 父日志ID (用于建立层级关系)
|
||
iteration_index: 迭代索引 (循环中的第几次)
|
||
|
||
Returns:
|
||
str: 日志ID
|
||
"""
|
||
from sqlalchemy import insert
|
||
from data.models.tasklog import VWEDTaskLog
|
||
from data.session import get_async_session
|
||
import uuid
|
||
from datetime import datetime
|
||
try:
|
||
# 创建任务日志记录
|
||
task_log_id = str(uuid.uuid4())
|
||
|
||
# 自定义JSON序列化器,处理datetime对象和bytes对象
|
||
def json_serializer(obj):
|
||
if isinstance(obj, datetime):
|
||
return obj.isoformat()
|
||
if isinstance(obj, bytes):
|
||
return obj.decode('utf-8', errors='backslashreplace')
|
||
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||
|
||
# 使用传入的参数,如果没有则从context获取
|
||
final_parent_log_id = parent_log_id if parent_log_id is not None else context.parent_log_id
|
||
final_iteration_index = iteration_index if iteration_index is not None else context.current_iteration_index
|
||
|
||
async with get_async_session() as session:
|
||
stmt = insert(VWEDTaskLog).values(
|
||
id=task_log_id,
|
||
level=1 if result.get("success", False) else 3, # 1: 信息, 3: 错误
|
||
message=json.dumps(result, ensure_ascii=False, default=json_serializer),
|
||
task_block_id=context.current_block_name or block.get("name", "unknown"), # 使用实际保存的block_name
|
||
task_id=context.task_def_id,
|
||
task_record_id=context.task_record_id,
|
||
parent_log_id=final_parent_log_id,
|
||
iteration_index=final_iteration_index,
|
||
block_record_id=context.block_record_id,
|
||
log_type=log_type
|
||
)
|
||
await session.execute(stmt)
|
||
await session.commit()
|
||
|
||
return task_log_id
|
||
except Exception as e:
|
||
logger.error(f"记录任务日志失败: {str(e)}")
|
||
return None
|
||
|
||
|
||
# 注册装饰器
|
||
def register_handler(block_type: str):
|
||
"""
|
||
注册块处理器的装饰器
|
||
|
||
Args:
|
||
block_type: 块类型
|
||
"""
|
||
def decorator(cls):
|
||
_block_handlers[block_type] = cls()
|
||
return cls
|
||
return decorator
|
||
|
||
|
||
# 获取块处理器
|
||
def get_block_handler(block_type: str) -> Optional[BlockHandler]:
|
||
"""
|
||
获取块处理器
|
||
|
||
Args:
|
||
block_type: 块类型
|
||
|
||
Returns:
|
||
Optional[BlockHandler]: 对应的块处理器,如果不存在则返回None
|
||
"""
|
||
return _block_handlers.get(block_type) |