115 lines
4.5 KiB
Python
115 lines
4.5 KiB
Python
# core/component.py
|
|
"""
|
|
组件基类和工厂
|
|
"""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, Type, Callable, List, Optional
|
|
import importlib
|
|
import inspect
|
|
import pkgutil
|
|
from .context import TaskContext
|
|
from .exceptions import ComponentError, ParameterError
|
|
from utils.logger import get_logger
|
|
|
|
# Module-level logger, obtained from the project's utils.logger helper and
# named after this module.
logger = get_logger(__name__)
|
|
|
|
class Component(ABC):
    """Abstract base class for components.

    A component is bound to one block (``block_id``) and receives that
    block's raw parameter dict. Subclasses implement :meth:`execute`. The
    helpers below resolve parameter references against the task context
    (the object returned by ``TaskContext.get_instance()``), store results
    back into it, and validate required parameters.
    """

    # Private sentinel returned by ``_lookup_path`` for "path not found", so
    # that a stored ``None`` value can be told apart from a missing key.
    _MISSING = object()

    def __init__(self, block_id: str, params: Dict[str, Any]):
        """Bind the component to a block and to the task context.

        Args:
            block_id: Identifier of the block this component executes.
            params: Raw parameters of the block. ``None`` is treated as an
                empty dict so that the lookups in the helpers do not fail.
        """
        self.block_id = block_id
        # A caller may pass None for a block without parameters; normalise it
        # so `in` / indexing in resolve_param and validate_required_params work.
        self.params = params if params is not None else {}
        self.context = TaskContext.get_instance()

    @abstractmethod
    def execute(self) -> Dict[str, Any]:
        """Run the component's logic.

        Returns:
            The component's result dict; its keys are defined by the subclass.
        """

    @classmethod
    def _lookup_path(cls, root: Any, keys: List[str]) -> Any:
        """Walk nested dicts along ``keys``; return ``_MISSING`` on any miss."""
        current = root
        for key in keys:
            # Only dicts are traversed; any other intermediate value (list,
            # str, None, ...) means the path cannot be resolved.
            if not isinstance(current, dict) or key not in current:
                return cls._MISSING
            current = current[key]
        return current

    def resolve_param(self, param_name: str, default=None) -> Any:
        """Return a parameter value, resolving variable and block references.

        String values of two forms are treated as references:

        * ``"${a.b.c}"`` -- the dotted path is looked up in
          ``context.variables`` by walking nested dicts. If any segment is
          missing, ``default`` is returned. A string counts as a variable
          reference only when it starts with ``${`` and ends with ``}``;
          references embedded in longer text are not interpolated.
        * ``"blocks.<block_id>.<key>[.<key>...]"`` -- the block's stored
          result in ``context.blocks`` is walked along the remaining keys.
          If the reference cannot be resolved, the raw string is returned
          unchanged.

        Args:
            param_name: Name of the parameter in ``self.params``.
            default: Returned when the parameter is absent, or when a
                ``${...}`` variable reference cannot be resolved.

        Returns:
            The resolved value, or the raw parameter value if it is not a
            resolvable reference.
        """
        if param_name not in self.params:
            return default

        value = self.params[param_name]
        if not isinstance(value, str):
            return value

        # Variable reference: ${path.to.var}
        if value.startswith("${") and value.endswith("}"):
            found = self._lookup_path(self.context.variables, value[2:-1].split("."))
            return default if found is self._MISSING else found

        # Block reference: blocks.<block_id>.<key>[.<key>...]
        if value.startswith("blocks."):
            parts = value.split(".")
            if len(parts) >= 3 and parts[1] in self.context.blocks:
                found = self._lookup_path(self.context.blocks[parts[1]], parts[2:])
                if found is not self._MISSING:
                    return found

        # Not a reference, or an unresolvable block reference: return as-is.
        return value

    def store_result(self, result: Dict[str, Any]) -> None:
        """Record ``result`` for this block via ``context.set_block_result``."""
        self.context.set_block_result(self.block_id, result)

    def validate_required_params(self, required_params: List[str]) -> None:
        """Ensure every name in ``required_params`` is present and not ``None``.

        Raises:
            ParameterError: for the first parameter that is missing or ``None``.
        """
        for param in required_params:
            if param not in self.params or self.params[param] is None:
                raise ParameterError(f"缺少必要参数: {param}")
|
|
|
|
class ComponentFactory:
    """Registry and factory for :class:`Component` subclasses.

    Component classes are registered under a string type name, either
    explicitly with the :meth:`register` decorator or by scanning a package
    with :meth:`auto_discover`. :meth:`create` then instantiates them by
    name. The registry is a single class-level dict.
    """

    # Component type name -> component class (class-level registry).
    _components: Dict[str, Type[Component]] = {}

    @classmethod
    def register(cls, component_type: str) -> Callable:
        """Return a class decorator registering a component as ``component_type``.

        Explicit registration is authoritative: it replaces any class already
        registered under the same name. A warning is logged when the previous
        class is a different one, so accidental name clashes become visible.
        """
        def decorator(component_class: Type[Component]) -> Type[Component]:
            previous = cls._components.get(component_type)
            if previous is not None and previous is not component_class:
                logger.warning(
                    f"组件类型 {component_type} 已注册为 {previous.__name__}, "
                    f"将被 {component_class.__name__} 覆盖"
                )
            cls._components[component_type] = component_class
            logger.info(f"注册组件: {component_type} -> {component_class.__name__}")
            return component_class

        return decorator

    @classmethod
    def create(cls, block_id: str, component_type: str, params: Dict[str, Any]) -> Component:
        """Instantiate the component registered under ``component_type``.

        Args:
            block_id: Passed through to the component constructor.
            component_type: Registered type name to look up.
            params: Passed through to the component constructor.

        Raises:
            ComponentError: if no component is registered under that name.
        """
        component_class = cls._components.get(component_type)
        if component_class is None:
            raise ComponentError(f"未知的组件类型: {component_type}")
        return component_class(block_id, params)

    @classmethod
    def get_component_types(cls) -> List[str]:
        """Return the names of all registered component types."""
        return list(cls._components)

    @classmethod
    def _register_discovered(cls, component_class: Type[Component]) -> None:
        """Register an auto-discovered class unless its name belongs to another class."""
        # Infer the type name from the class name, e.g. "ShellComponent" -> "shell".
        # NOTE: str.replace strips every occurrence of "Component", not only a
        # suffix; kept as-is so existing inferred type names stay stable.
        component_type = component_class.__name__.replace('Component', '').lower()
        existing = cls._components.get(component_type)
        if existing is not None and existing is not component_class:
            # Never silently replace an explicit @register (or an earlier
            # discovery) with a different class; the result would otherwise
            # depend on module import order.
            logger.warning(
                f"跳过自动注册: {component_type} 已注册为 {existing.__name__}, "
                f"忽略 {component_class.__name__}"
            )
            return
        cls._components[component_type] = component_class
        logger.info(f"自动注册组件: {component_type} -> {component_class.__name__}")

    @classmethod
    def auto_discover(cls, package_name: str) -> None:
        """Import the top-level modules of ``package_name`` and register their components.

        Each non-package submodule is imported, and every concrete
        (non-abstract) ``Component`` subclass among its members is
        registered under a type name inferred from its class name.
        Sub-packages are not descended into. A name already bound to a
        different class is left untouched and a warning is logged. Any
        exception while importing or scanning a submodule is logged, and
        that submodule is skipped.

        Raises:
            ImportError: if ``package_name`` itself cannot be imported.
            AttributeError: if ``package_name`` is a plain module rather
                than a package (it has no ``__path__``).
        """
        logger.info(f"自动发现组件: {package_name}")
        package = importlib.import_module(package_name)
        prefix = package.__name__ + '.'

        for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__, prefix):
            if is_pkg:
                continue
            try:
                module = importlib.import_module(module_name)
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    if (issubclass(obj, Component) and obj is not Component
                            and not inspect.isabstract(obj)):
                        cls._register_discovered(obj)
            except Exception as e:
                # Best-effort discovery: one broken module must not stop the rest.
                logger.error(f"加载组件模块失败: {module_name}, 错误: {str(e)}")