# 2025-03-17 14:58:05 +08:00
#
# 115 lines
# 4.5 KiB
# Python

# core/component.py
"""
组件基类和工厂
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Type, Callable, List, Optional
import importlib
import inspect
import pkgutil
from .context import TaskContext
from .exceptions import ComponentError, ParameterError
from utils.logger import get_logger
# Obtain the module-level logger, named after this module via __name__.
logger = get_logger(__name__)
class Component(ABC):
    """Base class for executable components.

    A component is bound to one block (``block_id``) and receives that
    block's raw parameters as a dict. Subclasses implement :meth:`execute`.
    The helper methods resolve parameter references against the task context
    returned by ``TaskContext.get_instance()`` and store results back into it.
    """

    # Private sentinel distinguishing "path not found" from a stored ``None``.
    _MISSING = object()

    def __init__(self, block_id: str, params: Dict[str, Any]):
        """Bind the component to its block and the task context.

        Args:
            block_id: Identifier of the block this component executes.
            params: Raw parameter mapping for the block. ``None`` is treated
                as an empty mapping.
        """
        self.block_id = block_id
        # Keep the caller's dict itself (no copy). Only substitute a fresh
        # dict when nothing was passed, so lookups never fail on None.
        self.params = params if params is not None else {}
        self.context = TaskContext.get_instance()

    @abstractmethod
    def execute(self) -> Dict[str, Any]:
        """Run the component's logic and return its result mapping."""
        pass

    @classmethod
    def _walk_path(cls, root: Any, parts: List[str]) -> Any:
        """Follow ``parts`` through nested dicts starting at ``root``.

        Returns the value reached, or ``cls._MISSING`` as soon as a step
        lands on a non-dict or a key that is not present.
        """
        current = root
        for part in parts:
            if not isinstance(current, dict) or part not in current:
                return cls._MISSING
            current = current[part]
        return current

    def resolve_param(self, param_name: str, default=None) -> Any:
        """Return the value of ``param_name``, resolving references.

        Supported reference forms (the whole string must be the reference):

        * ``"${a.b.c}"``: looked up in ``context.variables`` by walking nested
          dicts. Returns ``default`` if any segment is missing.
        * ``"blocks.<block_id>.<key>[.<subkey>...]"``: looked up in
          ``context.blocks[<block_id>]``, walking nested dicts for deeper
          paths. If it cannot be resolved, the raw string is returned
          unchanged.

        Any other value (including non-strings) is returned as-is. An absent
        parameter returns ``default``.
        """
        if param_name not in self.params:
            return default
        value = self.params[param_name]
        if not isinstance(value, str):
            return value

        # Variable reference: ${path.to.var}
        if value.startswith("${") and value.endswith("}"):
            found = self._walk_path(self.context.variables, value[2:-1].split("."))
            return default if found is self._MISSING else found

        # Block reference: blocks.<block_id>.<key>[.<subkey>...]
        if value.startswith("blocks."):
            parts = value.split(".")
            blocks = self.context.blocks
            if len(parts) >= 3 and parts[1] in blocks:
                found = self._walk_path(blocks[parts[1]], parts[2:])
                if found is not self._MISSING:
                    return found
        # Not a reference, or an unresolved block reference: keep the literal.
        return value

    def store_result(self, result: Dict[str, Any]) -> None:
        """Record ``result`` in the task context under this component's block id."""
        self.context.set_block_result(self.block_id, result)

    def validate_required_params(self, required_params: List[str]) -> None:
        """Ensure every name in ``required_params`` is present and not ``None``.

        Raises:
            ParameterError: for the first parameter that is missing or ``None``.
        """
        for param in required_params:
            if param not in self.params or self.params[param] is None:
                raise ParameterError(f"缺少必要参数: {param}")
class ComponentFactory:
    """Registry and factory for component classes.

    Component classes are registered under a string type name, either
    explicitly with the :meth:`register` decorator or by scanning a package
    with :meth:`auto_discover`. :meth:`create` instantiates a registered class
    for a given block.
    """

    # Class-level registry (shared, not per-instance): type name -> class.
    _components: Dict[str, Type[Component]] = {}

    # Class-name suffix dropped when auto_discover derives a type name.
    _CLASS_SUFFIX = "Component"

    @classmethod
    def register(cls, component_type: str) -> Callable:
        """Return a class decorator that registers the class as ``component_type``.

        Registering the same name again replaces the earlier class.
        """
        def decorator(component_class: Type[Component]) -> Type[Component]:
            cls._components[component_type] = component_class
            logger.info(f"注册组件: {component_type} -> {component_class.__name__}")
            return component_class
        return decorator

    @classmethod
    def create(cls, block_id: str, component_type: str, params: Dict[str, Any]) -> Component:
        """Instantiate the component registered as ``component_type``.

        Raises:
            ComponentError: if no component is registered under that name.
        """
        component_class = cls._components.get(component_type)
        if component_class is None:
            raise ComponentError(f"未知的组件类型: {component_type}")
        return component_class(block_id, params)

    @classmethod
    def get_component_types(cls) -> List[str]:
        """Return the names of all registered component types."""
        return list(cls._components)

    @classmethod
    def _derive_component_type(cls, class_name: str) -> str:
        """Derive a type name by dropping a trailing suffix and lowercasing.

        Only the trailing suffix is removed. ``HttpComponent`` becomes
        ``http``, but ``ComponentLoaderComponent`` becomes
        ``componentloader`` rather than ``loader``.
        """
        suffix = cls._CLASS_SUFFIX
        if class_name.endswith(suffix):
            class_name = class_name[: -len(suffix)]
        return class_name.lower()

    @classmethod
    def auto_discover(cls, package_name: str) -> None:
        """Import each direct, non-package submodule of ``package_name`` and register its components.

        Every concrete (non-abstract) subclass of ``Component`` found in a
        submodule's namespace is registered under the name returned by
        :meth:`_derive_component_type`. Classes the submodule merely imports
        are included. Sub-packages are not descended into. A submodule that
        fails while importing or registering is logged and skipped, so one
        broken module does not stop the others.

        Raises:
            Whatever ``importlib.import_module`` raises for ``package_name``
            itself, or ``AttributeError`` if it is a plain module without
            ``__path__``.
        """
        logger.info(f"自动发现组件: {package_name}")
        package = importlib.import_module(package_name)
        prefix = package.__name__ + '.'
        for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__, prefix):
            if is_pkg:
                continue
            try:
                module = importlib.import_module(module_name)
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    if (issubclass(obj, Component) and obj is not Component
                            and not inspect.isabstract(obj)):
                        component_type = cls._derive_component_type(obj.__name__)
                        cls._components[component_type] = obj
                        logger.info(f"自动注册组件: {component_type} -> {obj.__name__}")
            except Exception as e:
                logger.error(f"加载组件模块失败: {module_name}, 错误: {str(e)}")