tianfeng_task_modules/services/task_param_service.py
2025-03-18 18:34:03 +08:00

304 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
任务参数服务模块
负责任务输入参数的管理
"""
from typing import Dict, Any, List, Optional, Tuple
import json
from config.task_config import TaskInputParamConfig, TaskInputParamType
from config.database import db_session
from sqlalchemy import and_, or_, desc
class TaskParamService:
"""任务参数服务类"""
def __init__(self):
# 导入数据模型
from data.models.task_input_param import TaskInputParam
self.TaskInputParam = TaskInputParam
def get_param_types(self) -> List[Dict[str, Any]]:
"""获取所有支持的参数类型"""
param_types = []
for param_type in TaskInputParamType:
param_types.append({
"key": param_type.value,
"name": param_type.value,
"description": f"{param_type.value}类型参数"
})
return param_types
def get_task_input_params(self, task_id: str, instance_id: str = None) -> Tuple[List[Dict[str, Any]], Optional[str]]:
"""
获取任务输入参数
Args:
task_id: 任务ID
instance_id: 任务实例ID如果为None则尝试获取最新的实例
Returns:
Tuple[List[Dict[str, Any]], Optional[str]]: (任务输入参数列表, 实际使用的实例ID)
"""
# 如果没有提供实例ID尝试查找最新的实例
found_instance_id = instance_id
if not found_instance_id:
from data.models.task_instance import TaskInstance
with db_session() as session:
instance = session.query(TaskInstance).filter(
TaskInstance.task_id == task_id,
TaskInstance.is_deleted == False
).order_by(TaskInstance.updated_at.desc()).first()
if instance:
found_instance_id = instance.instance_id
else:
# 如果没有找到实例返回空列表和None
return [], None
# 从数据库中查询并获取任务的输入参数
with db_session() as session:
# 查询任务输入参数,过滤掉已删除的
db_params = session.query(self.TaskInputParam).filter(
and_(
self.TaskInputParam.instance_id == found_instance_id,
self.TaskInputParam.is_deleted == False
)
).order_by(self.TaskInputParam.sort_order).all()
# 将数据库查询结果转换为字典
params = [param.to_dict() for param in db_params]
return params, found_instance_id
def update_task_input_params(self, task_id: str, params: List[Dict[str, Any]], instance_id: str = None) -> Tuple[int, str, bool]:
"""
更新任务输入参数
Args:
task_id: 任务ID
params: 任务输入参数列表
instance_id: 任务实例ID如果为None则尝试获取最新的实例
Returns:
Tuple[int, str, bool]: (更新的参数数量, 实际使用的实例ID, 是否有数据变动)
"""
# 如果没有提供实例ID尝试查找最新的实例
found_instance_id = instance_id
if not found_instance_id:
from data.models.task_instance import TaskInstance
with db_session() as session:
instance = session.query(TaskInstance).filter(
TaskInstance.task_id == task_id,
TaskInstance.is_deleted == False
).order_by(TaskInstance.updated_at.desc()).first()
if instance:
found_instance_id = instance.instance_id
else:
# 创建新的任务实例
from data.models.task import Task
task = session.query(Task).filter(Task.task_id == task_id).first()
if not task:
raise ValueError(f"任务不存在: {task_id}")
instance = TaskInstance(
task_id=task_id,
name=task.name,
variables={},
priority=1,
input_params={},
block_outputs={},
context_params={}
)
session.add(instance)
session.commit()
found_instance_id = instance.instance_id
# 新建实例,肯定有数据变动
return 0, found_instance_id, True
# 过滤出非系统参数
custom_params = []
system_param_keys = [param["key"] for param in TaskInputParamConfig.get_system_params()]
for param in params:
# 检查是否为系统参数
if param["param_name"] in system_param_keys:
continue
# 确保param存在default_value键
if "default_value" not in param:
param["default_value"] = None
custom_params.append(param)
# 检测是否有数据变动的标志
has_changes = False
# 更新数据库
with db_session() as session:
# 查询当前任务实例的所有自定义参数
existing_params = session.query(self.TaskInputParam).filter(
and_(
self.TaskInputParam.instance_id == found_instance_id,
self.TaskInputParam.is_system == False,
self.TaskInputParam.is_deleted == False
)
).all()
# 创建一个映射,用于快速找到现有参数
existing_param_map = {param.param_name: param for param in existing_params}
# 检查参数数量是否变化
if len(existing_param_map) != len(custom_params):
has_changes = True
# 处理每个自定义参数
for i, param_data in enumerate(custom_params):
param_name = param_data["param_name"]
# 检查是否存在现有参数
if param_name in existing_param_map:
# 获取现有参数
param = existing_param_map[param_name]
# 检查参数是否有变化
if (param.label != param_data.get("label", "") or
param.param_type != param_data.get("param_type", "") or
param.required != param_data.get("required", False) or
param.description != param_data.get("description", "") or
param.sort_order != i):
has_changes = True
# 检查default_value是否有变化需要特殊处理
old_value = param.default_value
new_value = param_data.get("default_value")
# 尝试比较值考虑到None, 空字符串, 空列表等特殊情况)
if ((old_value is None and new_value not in [None, "", [], {}]) or
(new_value is None and old_value not in [None, "", [], {}]) or
(str(old_value) != str(new_value))):
has_changes = True
# 更新现有参数
param.label = param_data.get("label", param_name)
param.param_type = param_data.get("param_type", param.param_type)
param.required = param_data.get("required", False)
param.default_value = new_value
param.description = param_data.get("description", "")
param.sort_order = i
# 从映射中移除,以便后面知道哪些需要删除
del existing_param_map[param_name]
else:
# 创建新参数 - 有新增肯定是有变化的
has_changes = True
new_param = self.TaskInputParam(
instance_id=found_instance_id,
task_id=task_id, # 冗余存储任务ID便于查询
param_name=param_name,
label=param_data.get("label", param_name),
param_type=param_data.get("param_type", TaskInputParamType.STRING.value),
required=param_data.get("required", False),
default_value=param_data.get("default_value"),
description=param_data.get("description", ""),
is_system=False,
is_readonly=False,
sort_order=i
)
session.add(new_param)
# 如果有需要删除的参数,标记变化
if existing_param_map:
has_changes = True
# 标记需要删除的参数
for param in existing_param_map.values():
param.is_deleted = True
# 只有在有变化时才提交事务
if has_changes:
session.commit()
return len(custom_params), found_instance_id, has_changes
def delete_task_input_param(self, task_id: str, param_id: str, instance_id: str = None) -> bool:
"""
删除任务输入参数
Args:
task_id: 任务ID
param_id: 参数ID
instance_id: 任务实例ID如果为None则尝试获取最新的实例
Returns:
bool: 是否成功删除
"""
# 如果没有提供实例ID尝试查找最新的实例
if not instance_id:
from data.models.task_instance import TaskInstance
with db_session() as session:
instance = session.query(TaskInstance).filter(
TaskInstance.task_id == task_id,
TaskInstance.is_deleted == False
).order_by(TaskInstance.updated_at.desc()).first()
if instance:
instance_id = instance.instance_id
else:
return False # 如果找不到实例,返回删除失败
with db_session() as session:
# 查询参数
param = session.query(self.TaskInputParam).filter(
and_(
self.TaskInputParam.instance_id == instance_id,
self.TaskInputParam.param_id == param_id,
self.TaskInputParam.is_deleted == False
)
).first()
if not param:
return False
# 系统参数不允许删除
if param.is_system:
return False
# 标记为已删除
param.is_deleted = True
session.commit()
return True
def get_default_input_params(self) -> List[Dict[str, Any]]:
"""
获取默认的任务输入参数
Returns:
List[Dict[str, Any]]: 默认的任务输入参数列表
"""
# 返回一些常用的默认参数作为示例
return [
{
"param_name": "robotId",
"label": "机器人ID",
"param_type": TaskInputParamType.STRING,
"required": True,
"default_value": "",
"description": "执行任务的机器人ID"
},
{
"param_name": "targetPosition",
"label": "目标位置",
"param_type": TaskInputParamType.STRING,
"required": False,
"default_value": "",
"description": "任务执行的目标位置"
},
{
"param_name": "timeout",
"label": "超时时间",
"param_type": TaskInputParamType.INTEGER,
"required": False,
"default_value": 3600,
"description": "任务执行的超时时间(秒)"
}
]