317 lines
9.5 KiB
Python
Raw Permalink Normal View History

2025-04-30 16:57:46 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
任务上下文模块
提供任务执行过程中的上下文管理
"""
import json
import logging
from typing import Dict, List, Any, Optional
from datetime import datetime
from utils.logger import get_logger
# 获取日志记录器
logger = get_logger("services.execution.task_context")
class TaskContext:
"""
任务上下文类
管理任务执行过程中的数据变量和状态
"""
def __init__(
self,
task_record_id: str,
task_def_id: str,
input_params: Dict[str, Any],
variables: Dict[str, Any] = None,
token: str = None
):
"""
初始化任务上下文
Args:
task_record_id: 任务记录ID
task_def_id: 任务定义ID
input_params: 任务输入参数
variables: 任务变量默认为空字典
"""
self.task_record_id = task_record_id
self.task_def_id = task_def_id
self.input_params = input_params or {}
self.variables = variables or {}
self.variable_sources = {} # 记录每个变量的来源块
self.execution_path = [] # 执行路径
self.outputs = {} # 任务输出
self.block_outputs = {} # 各块的输出结果
self.error = None # 错误信息
self.start_time = datetime.now() # 开始时间
self.is_canceled = False # 是否被取消
self.current_block_id = None # 当前正在执行的块ID
self.current_block_name = None # 当前正在执行的块名称
self.skip_to_component_id = None # 需要跳转到的块ID
self._variables_need_sync = False # 变量是否需要同步到数据库
self.token = token # 任务令牌
def set_current_block(self, block_id: str, block_name: str):
"""
设置当前正在执行的块
Args:
block_id: 块ID
block_name: 块名称
"""
self.current_block_id = block_id
self.current_block_name = block_name
def get_variable(self, name: str, default: Any = None) -> Any:
"""
获取变量值
Args:
name: 变量名
default: 默认值如果变量不存在则返回此值
Returns:
Any: 变量值
"""
# 检查是否是任务输入参数
if name.startswith("taskInputs."):
param_name = name[len("taskInputs."):]
# 查找参数信息
if isinstance(self.input_params, list):
# 若input_params是列表格式需要查找参数信息并从defaultValue字段获取值
for param in self.input_params:
if param.get('name') == param_name:
return param.get('defaultValue', default)
return default
else:
# 旧的处理方式,直接从字典中获取值
return self.input_params.get(param_name, default)
# 检查是否是块输出
if name.startswith("outputs."):
output_path = name[len("outputs."):].split(".")
current = self.block_outputs
for key in output_path:
if key in current:
current = current[key]
else:
return default
return current
# 普通变量
return self.variables.get(name, default)
def set_variable(self, name: str, value: Any) -> None:
"""
设置变量值
Args:
name: 变量名
value: 变量值
"""
# 不允许修改任务输入参数
if name.startswith("taskInputs."):
logger.warning(f"尝试修改任务输入参数 {name},操作被忽略")
return
# 设置变量
self.variables[name] = value
# 标记需要同步变量
self._variables_need_sync = True
# 记录变量来源
if self.current_block_id and self.current_block_name:
self.variable_sources[name] = {
"block_id": self.current_block_id,
"block_name": self.current_block_name,
"timestamp": datetime.now().isoformat()
}
def get_variable_source(self, name: str) -> Optional[Dict[str, str]]:
"""
获取变量的来源信息
Args:
name: 变量名
Returns:
Optional[Dict[str, str]]: 变量来源信息如果不存在则返回None
"""
return self.variable_sources.get(name)
def get_block_variables(self, block_id: str = None, block_name: str = None) -> Dict[str, Any]:
"""
获取由指定块设置的所有变量
Args:
block_id: 块ID
block_name: 块名称
Returns:
Dict[str, Any]: 变量字典
"""
result = {}
# 如果既不指定block_id也不指定block_name返回空字典
if not block_id and not block_name:
return result
# 遍历所有变量来源
for var_name, source in self.variable_sources.items():
# 按块ID匹配
if block_id and source.get("block_id") == block_id:
result[var_name] = self.variables.get(var_name)
# 按块名称匹配
elif block_name and source.get("block_name") == block_name:
result[var_name] = self.variables.get(var_name)
return result
def set_block_output(self, block_id: str, output: Dict[str, Any]) -> None:
"""
设置块输出
Args:
block_id: 块ID
output: 输出值
"""
self.block_outputs[block_id] = output
def get_block_output(self, block_id: str) -> Any:
"""
获取块输出
Args:
block_id: 块ID
Returns:
Any: 块的输出值如果不存在则返回None
"""
return self.block_outputs.get(block_id)
def add_execution_path(self, block_id: str) -> None:
"""
添加执行路径
Args:
block_id: 块ID
"""
self.execution_path.append({
"blockId": block_id,
"timestamp": datetime.now().isoformat()
})
def set_error(self, error: str, block_id: str = None) -> None:
"""
设置错误信息
Args:
error: 错误信息
block_id: 发生错误的块ID
"""
self.error = {
"message": error,
"blockId": block_id,
"timestamp": datetime.now().isoformat()
}
def set_output(self, key: str, value: Any) -> None:
"""
设置任务输出
Args:
key: 输出键名
value: 输出值
"""
self.outputs[key] = value
def mark_canceled(self) -> None:
"""标记任务为已取消状态"""
self.is_canceled = True
def is_task_canceled(self) -> bool:
"""
检查任务是否被取消
Returns:
bool: 是否被取消
"""
return self.is_canceled
def get_execution_time(self) -> int:
"""
获取任务已执行时间毫秒
Returns:
int: 执行时间毫秒
"""
delta = datetime.now() - self.start_time
return int(delta.total_seconds() * 1000)
def set_skip_to(self, component_id: str) -> None:
"""
设置任务需要跳转到的块ID
Args:
component_id: 目标块ID
"""
self.skip_to_component_id = component_id
logger.info(f"设置任务跳转至组件: {component_id}")
def get_skip_to(self) -> Optional[str]:
"""
获取任务需要跳转到的块ID
Returns:
Optional[str]: 如果设置了跳转目标则返回目标块ID否则返回None
"""
return self.skip_to_component_id
def clear_skip_to(self) -> None:
"""
清除跳转标记
"""
self.skip_to_component_id = None
def to_dict(self) -> Dict[str, Any]:
"""
将上下文转换为字典
Returns:
Dict[str, Any]: 上下文字典
"""
return {
"taskRecordId": self.task_record_id,
"taskDefId": self.task_def_id,
"inputParams": self.input_params,
"variables": self.variables,
"variableSources": self.variable_sources,
"executionPath": self.execution_path,
"outputs": self.outputs,
"blockOutputs": self.block_outputs,
"error": self.error,
"startTime": self.start_time.isoformat(),
"executionTime": self.get_execution_time(),
"isCanceled": self.is_canceled,
"skipToComponentId": self.skip_to_component_id
}
def need_sync_variables(self) -> bool:
"""
检查变量是否需要同步到数据库
Returns:
bool: 是否需要同步
"""
return self._variables_need_sync
def mark_variables_synced(self) -> None:
"""
标记变量已同步到数据库
"""
self._variables_need_sync = False