430 lines
18 KiB
Python
430 lines
18 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
脚本注册中心服务
|
||
管理脚本实例的全局注册信息,支持多脚本并发运行
|
||
"""
|
||
|
||
import asyncio
import inspect
import time
from datetime import datetime
from typing import Dict, Any, Optional, List, Callable

from utils.logger import get_logger
|
||
|
||
# Module-level logger used by the registry below for registration and cleanup messages.
logger = get_logger("services.script_registry")
|
||
|
||
|
||
class VWEDGlobalRegistry:
    """VWED global registry (supports independent services per script).

    Holds every registration made by running scripts -- HTTP routes, custom
    functions, event listeners, timers and device handlers -- each tagged with
    the ``script_id`` that created it, so one script's registrations can be
    removed without touching the others (see ``clear_script_registrations``).

    NOTE(review): no locking is used; all mutation is plain dict/list updates,
    so concurrent use from several OS threads is not guarded.
    """

    def __init__(self):
        self.api_routes: Dict[str, Dict[str, Any]] = {}  # HTTP routes: {"METHOD:path": route_info}
        self.websocket_routes: Dict[str, Dict[str, Any]] = {}  # WebSocket routes
        self.tcp_servers: Dict[str, Dict[str, Any]] = {}  # TCP servers
        self.event_listeners: Dict[str, List[Dict[str, Any]]] = {}  # {event_name: [listener_info, ...]}, sorted by priority
        self.timers: Dict[str, Dict[str, Any]] = {}  # Timers: {timer_id: timer_info}
        self.registered_functions: Dict[str, Dict[str, Any]] = {}  # Custom functions: {name: function_info}
        self.device_handlers: Dict[str, Dict[str, Any]] = {}  # Device handlers: {device_id: handler_info}
        self.active_scripts: Dict[str, Dict[str, Any]] = {}  # Active scripts: {script_id: script_info}
        self.current_script_id: Optional[str] = None  # Default owner for register_* calls without script_id

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _handler_name(handler: Callable) -> str:
        """Return a printable name for *handler*.

        ``functools.partial`` objects and callable instances have no
        ``__name__``; fall back to ``__qualname__`` and finally ``repr`` so
        logging and serialization never raise ``AttributeError``.
        """
        name = getattr(handler, "__name__", None)
        if name is None:
            name = getattr(handler, "__qualname__", None)
        return name if name is not None else repr(handler)

    @staticmethod
    def _update_running_average(info: Dict[str, Any], avg_key: str, total_key: str, sample) -> None:
        """Fold *sample* into ``info[avg_key]`` using an exact running total.

        ``info["call_count"]`` must already include the current call. Keeping
        the exact total (instead of re-deriving it from a truncated average)
        prevents the average from drifting downward over many calls.
        """
        count = info["call_count"]
        total = info.get(total_key)
        if total is None:
            # Entry created without a total: rebuild it from the stored average.
            previous_avg = info.get(avg_key)
            total = 0 if previous_avg is None else previous_avg * (count - 1)
        total += sample
        info[total_key] = total
        info[avg_key] = int(round(total / count))

    @staticmethod
    def _remove_owned(registry: Dict[str, Dict[str, Any]], script_id: str) -> None:
        """Delete, in place, every entry of *registry* whose ``script_id`` matches."""
        for key in [k for k, v in registry.items() if v.get("script_id") == script_id]:
            del registry[key]

    @staticmethod
    def _count_owned(registry: Dict[str, Dict[str, Any]], script_id: str) -> int:
        """Count the entries of *registry* whose ``script_id`` matches."""
        return sum(1 for v in registry.values() if v.get("script_id") == script_id)

    # ------------------------------------------------------------- registration

    def register_api_route(self, path: str, method: str, handler: Callable,
                           script_id: str = None, description: str = "",
                           params: Dict = None, parameters: Dict = None, response_schema: Dict = None):
        """Register an HTTP route owned by a script.

        The route is stored under ``"{method}:{path}"`` (method used verbatim,
        no case normalization); registering the same key again replaces it.

        Args:
            path: URL path of the route.
            method: HTTP method.
            handler: Callable serving the route.
            script_id: Owning script; defaults to ``current_script_id``.
            description: Free-text description.
            params: Simplified parameter spec ``{name: default_value}``; converted
                into an object schema when ``parameters`` is not given.
            parameters: Full parameter schema; takes precedence over ``params``.
            response_schema: Optional response schema.
        """
        script_id = script_id or self.current_script_id
        route_key = f"{method}:{path}"

        if parameters is None:
            # Only the simplified form was supplied (or nothing at all).
            parameters = self._convert_simple_params_to_schema(params) if params is not None else {}

        self.api_routes[route_key] = {
            "path": path,
            "method": method,
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "params": params or {},  # original simplified parameters
            "parameters": parameters,  # full parameter schema
            "response_schema": response_schema or {},
            "registered_at": datetime.now().isoformat(),
            "call_count": 0,
            "last_called_at": None,
            "average_response_time_ms": None,
            "total_response_time_ms": 0,  # exact sum backing the average
        }

        logger.info(f"注册API路由: {method} {path} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_function(self, name: str, handler: Callable, script_id: str = None,
                          description: str = "", params: Dict = None, parameters: List[Dict] = None,
                          return_schema: Dict = None, tags: List[str] = None):
        """Register a custom function owned by a script.

        Args:
            name: Registry key; re-registering a name replaces the entry.
            handler: The function; ``is_async`` records whether it is a coroutine function.
            script_id: Owning script; defaults to ``current_script_id``.
            description: Free-text description.
            params: Simplified parameter spec ``{name: default_value}``; converted
                into a list schema when ``parameters`` is not given.
            parameters: Full list-form parameter schema; takes precedence.
            return_schema: Optional return-value schema.
            tags: Optional tags.
        """
        script_id = script_id or self.current_script_id
        is_async = asyncio.iscoroutinefunction(handler)

        if parameters is None:
            parameters = self._convert_simple_params_to_list_schema(params) if params is not None else []

        self.registered_functions[name] = {
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "params": params or {},  # original simplified parameters
            "parameters": parameters,  # full parameter schema
            "return_schema": return_schema or {},
            "is_async": is_async,
            "tags": tags or [],
            "registered_at": datetime.now().isoformat(),
            "call_count": 0,
            "success_count": 0,
            "error_count": 0,
            "last_called_at": None,
            "average_execution_time_ms": None,
            "total_execution_time_ms": 0,  # exact sum backing the average
        }

        logger.info(f"注册函数: {name} -> {self._handler_name(handler)} (script_id: {script_id}, async: {is_async})")

    def register_event_listener(self, event_name: str, handler: Callable,
                                script_id: str = None, priority: int = 1):
        """Register a listener for *event_name*.

        Listeners are kept sorted by ascending ``priority`` (stable, so equal
        priorities keep registration order) and are invoked in that order by
        ``emit_event``.
        """
        script_id = script_id or self.current_script_id

        listeners = self.event_listeners.setdefault(event_name, [])
        listeners.append({
            "handler": handler,
            "script_id": script_id,
            "priority": priority,
            "registered_at": datetime.now().isoformat(),
        })
        listeners.sort(key=lambda x: x["priority"])

        logger.info(f"注册事件监听器: {event_name} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_timer(self, timer_id: str, interval: int, handler: Callable,
                       script_id: str = None, repeat: bool = False, delay: int = 0):
        """Register a timer; ``next_run`` is set to now (``time.time()``) plus *delay*.

        This method only records the timer; it does not schedule anything itself.
        """
        script_id = script_id or self.current_script_id

        self.timers[timer_id] = {
            "handler": handler,
            "script_id": script_id,
            "interval": interval,
            "repeat": repeat,
            "delay": delay,
            "next_run": time.time() + delay,
            "last_run": None,
            "run_count": 0,
            "registered_at": datetime.now().isoformat(),
        }

        logger.info(f"注册定时任务: {timer_id} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_device_handler(self, device_id: str, device_type: str,
                                listen_topics: List[str], forward_topics: List[str],
                                handler: Callable, script_id: str = None, description: str = "",
                                device_brand: str = None, protocol_key: str = None,
                                auto_encode: bool = True, **kwargs):
        """Register a device handler owned by a script.

        Extra keyword arguments are merged into the stored entry last.
        NOTE(review): because they come last, kwargs can override built-in
        fields such as ``status`` or ``message_count`` -- confirm this is intended.
        """
        script_id = script_id or self.current_script_id
        is_async = asyncio.iscoroutinefunction(handler)

        self.device_handlers[device_id] = {
            "device_id": device_id,
            "device_type": device_type,
            "listen_topics": listen_topics,
            "forward_topics": forward_topics,
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "device_brand": device_brand,
            "protocol_key": protocol_key,
            "auto_encode": auto_encode,
            "is_async": is_async,
            "registered_at": datetime.now().isoformat(),
            "status": "running",
            "message_count": 0,
            "success_count": 0,
            "error_count": 0,
            "last_processed": None,
            "average_processing_time_ms": None,
            **kwargs,
        }

        brand_info = f" 品牌: {device_brand}" if device_brand else ""
        protocol_info = f" 协议: {protocol_key}" if protocol_key else ""
        logger.info(f"注册设备处理器: {device_id} ({device_type}{brand_info}{protocol_info}) -> {self._handler_name(handler)} (script_id: {script_id}, async: {is_async})")
        logger.info(f"监听topics: {listen_topics}, 转发topics: {forward_topics}")

    # ------------------------------------------------------------------ lookups

    def get_device_handler(self, device_id: str) -> Optional[Dict[str, Any]]:
        """Return the device handler entry for *device_id*, or None."""
        return self.device_handlers.get(device_id)

    def get_api_route(self, path: str, method: str = "GET") -> Optional[Dict[str, Any]]:
        """Return the route entry for ``"{method}:{path}"``, or None."""
        return self.api_routes.get(f"{method}:{path}")

    def get_registered_function(self, name: str) -> Optional[Callable]:
        """Return the handler registered under *name*, or None."""
        func_info = self.registered_functions.get(name)
        return func_info["handler"] if func_info else None

    def get_function_info(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the full registration entry for function *name*, or None."""
        return self.registered_functions.get(name)

    def get_all_device_handlers(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of all device handler entries."""
        return self.device_handlers.copy()

    # ------------------------------------------------------------------- events

    async def emit_event(self, event_name: str, data: Dict[str, Any]):
        """Invoke every listener of *event_name* with *data*, in priority order.

        A failing listener is logged and does not stop the others.
        """
        # Iterate over a snapshot: registering a listener appends to and
        # re-sorts the live list in place, which may happen from inside a
        # handler or from another task while a handler is being awaited.
        listeners = list(self.event_listeners.get(event_name, ()))

        for listener in listeners:
            try:
                result = listener["handler"](data)
                # Await anything awaitable the handler returns: covers async
                # functions as well as callable objects with an async
                # ``__call__`` and sync wrappers returning a coroutine.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.error(f"事件处理器执行失败: {event_name} -> {e}")

    # ------------------------------------------------------------------ cleanup

    def clear_script_registrations(self, script_id: str):
        """Remove every registration owned by *script_id* and its active-script record.

        NOTE(review): ``websocket_routes`` and ``tcp_servers`` are not filtered
        here (only ``clear_all`` empties them) -- confirm whether they should be.
        """
        self._remove_owned(self.api_routes, script_id)
        self._remove_owned(self.registered_functions, script_id)

        for event_name in list(self.event_listeners):
            remaining = [listener for listener in self.event_listeners[event_name]
                         if listener.get("script_id") != script_id]
            if remaining:
                self.event_listeners[event_name] = remaining
            else:
                # Drop emptied events so names from unloaded scripts do not pile up.
                del self.event_listeners[event_name]

        self._remove_owned(self.timers, script_id)
        self._remove_owned(self.device_handlers, script_id)
        self.active_scripts.pop(script_id, None)

        logger.info(f"已清理脚本 {script_id} 的所有注册项")

    def clear_all(self):
        """Remove every registration of every kind, including active-script records."""
        self.api_routes.clear()
        self.websocket_routes.clear()
        self.tcp_servers.clear()
        self.event_listeners.clear()
        self.timers.clear()
        self.registered_functions.clear()
        self.device_handlers.clear()
        self.active_scripts.clear()

        logger.info("已清理所有注册项")

    # ----------------------------------------------------------- script status

    def set_current_script(self, script_id: str):
        """Set the script id used as default owner by the register_* methods."""
        self.current_script_id = script_id

    def add_active_script(self, script_id: str, script_path: str, file_id: int = None):
        """Record *script_id* as running, with its path, optional file id and a heartbeat."""
        self.active_scripts[script_id] = {
            "script_path": script_path,
            "file_id": file_id,
            "started_at": datetime.now().isoformat(),
            "status": "running",
            "heartbeat": time.time(),
        }

    def update_script_heartbeat(self, script_id: str):
        """Refresh the heartbeat (``time.time()``) of an active script; unknown ids are ignored."""
        if script_id in self.active_scripts:
            self.active_scripts[script_id]["heartbeat"] = time.time()

    # -------------------------------------------------------------- statistics

    def get_script_registrations(self, script_id: str) -> Dict[str, int]:
        """Return per-kind counts of the registrations owned by *script_id*."""
        event_count = sum(
            1
            for listeners in self.event_listeners.values()
            for listener in listeners
            if listener.get("script_id") == script_id
        )
        return {
            "apis": self._count_owned(self.api_routes, script_id),
            "functions": self._count_owned(self.registered_functions, script_id),
            "events": event_count,
            "timers": self._count_owned(self.timers, script_id),
            "devices": self._count_owned(self.device_handlers, script_id),
        }

    def get_all_registrations(self) -> Dict[str, Any]:
        """Return a snapshot of all registrations with handlers replaced by their names."""
        def with_name(entry: Dict[str, Any]) -> Dict[str, Any]:
            return {**entry, "handler": self._handler_name(entry["handler"])}

        return {
            "api_routes": {k: with_name(v) for k, v in self.api_routes.items()},
            "registered_functions": {k: with_name(v) for k, v in self.registered_functions.items()},
            "event_listeners": {k: [with_name(listener) for listener in v]
                                for k, v in self.event_listeners.items()},
            "timers": {k: with_name(v) for k, v in self.timers.items()},
            "device_handlers": {k: with_name(v) for k, v in self.device_handlers.items()},
            # Copy so callers cannot mutate the registry through the snapshot.
            "active_scripts": self.active_scripts.copy(),
        }

    def update_api_call_stats(self, route_key: str, response_time_ms: int):
        """Record one call of route *route_key* (``"METHOD:path"``); unknown keys are ignored."""
        route_info = self.api_routes.get(route_key)
        if route_info is None:
            return
        route_info["call_count"] += 1
        route_info["last_called_at"] = datetime.now().isoformat()
        self._update_running_average(route_info, "average_response_time_ms",
                                     "total_response_time_ms", response_time_ms)

    def update_function_call_stats(self, function_name: str, execution_time_ms: int, success: bool = True):
        """Record one call of *function_name* and whether it succeeded; unknown names are ignored."""
        func_info = self.registered_functions.get(function_name)
        if func_info is None:
            return
        func_info["call_count"] += 1
        func_info["last_called_at"] = datetime.now().isoformat()
        func_info["success_count" if success else "error_count"] += 1
        self._update_running_average(func_info, "average_execution_time_ms",
                                     "total_execution_time_ms", execution_time_ms)

    # -------------------------------------------------------- schema conversion

    def _convert_simple_params_to_schema(self, simple_params: Dict[str, Any]) -> Dict[str, Any]:
        """Convert ``{name: default}`` into an object schema (used for API routes).

        Input:  ``{"a": 0, "b": 0, "operation": "add"}``
        Output: ``{"type": "object", "properties": {...}, "required": [...]}``
        A parameter whose default is None is listed as required.
        """
        properties = {
            name: {"type": self._infer_parameter_type(default), "default": default}
            for name, default in simple_params.items()
        }
        required = [name for name, default in simple_params.items() if default is None]
        return {
            "type": "object",
            "properties": properties,
            "required": required,
        }

    def _convert_simple_params_to_list_schema(self, simple_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert ``{name: default}`` into a list schema (used for functions).

        Input:  ``{"a": 0, "b": 0, "operation": "add"}``
        Output: ``[{"name": "a", "type": "integer", "default": 0, "required": False}, ...]``
        A parameter whose default is None is marked required.
        """
        return [
            {
                "name": name,
                "type": self._infer_parameter_type(default),
                "default": default,
                "required": default is None,
            }
            for name, default in simple_params.items()
        ]

    def _infer_parameter_type(self, value: Any) -> str:
        """Infer a JSON-schema-like type name from a default value.

        ``bool`` is checked before ``int`` because bool subclasses int. Lists
        and tuples map to "array"; unrecognized types fall back to "string".
        """
        if value is None:
            return "any"
        if isinstance(value, bool):
            return "boolean"
        if isinstance(value, int):
            return "integer"
        if isinstance(value, float):
            return "number"
        if isinstance(value, str):
            return "string"
        if isinstance(value, (list, tuple)):
            return "array"
        if isinstance(value, dict):
            return "object"
        return "string"  # default to string
|
||
|
||
|
||
# Module-level registry instance, created at import time and returned by get_global_registry().
_global_registry: VWEDGlobalRegistry = VWEDGlobalRegistry()
|
||
|
||
|
||
def get_global_registry() -> VWEDGlobalRegistry:
    """Return the module-level registry instance.

    Every call returns the same ``_global_registry`` object, which is created
    when this module is imported.
    """
    return _global_registry