367 lines
15 KiB
Python
367 lines
15 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
脚本注册中心服务
|
|||
|
管理脚本实例的全局注册信息,支持多脚本并发运行
|
|||
|
"""
|
|||
|
|
|||
|
import asyncio
|
|||
|
import time
|
|||
|
from datetime import datetime
|
|||
|
from typing import Dict, Any, Optional, List, Callable
|
|||
|
from utils.logger import get_logger
|
|||
|
|
|||
|
# Module-level logger for this service, obtained from the project's get_logger helper
# under the name "services.script_registry".
logger = get_logger("services.script_registry")
|
|||
|
|
|||
|
|
|||
|
class VWEDGlobalRegistry:
    """VWED global registry (supports independent services for multiple scripts).

    Holds the HTTP routes, custom functions, event listeners and timers that running
    scripts register. Every entry records the ``script_id`` of its owner so that
    :meth:`clear_script_registrations` can tear down one script without touching the
    others. The key namespaces (route key, function name, event name, timer id) are
    shared by all scripts: registering an existing key replaces the previous entry
    (a warning is logged when the previous entry belonged to a different script).
    """

    def __init__(self):
        # HTTP routes: {"METHOD:path": route_info}
        self.api_routes: Dict[str, Dict[str, Any]] = {}
        # WebSocket routes (not populated by any method of this class)
        self.websocket_routes: Dict[str, Dict[str, Any]] = {}
        # TCP services (not populated by any method of this class)
        self.tcp_servers: Dict[str, Dict[str, Any]] = {}
        # Event listeners: {event_name: [listener_info, ...]}, kept sorted by priority
        self.event_listeners: Dict[str, List[Dict[str, Any]]] = {}
        # Timers: {timer_id: timer_info}
        self.timers: Dict[str, Dict[str, Any]] = {}
        # Custom functions: {name: function_info}
        self.registered_functions: Dict[str, Dict[str, Any]] = {}
        # Active scripts: {script_id: script_info}
        self.active_scripts: Dict[str, Dict[str, Any]] = {}
        # Owner assigned to register_* calls that do not pass a script_id
        self.current_script_id: Optional[str] = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _handler_name(handler: Callable) -> str:
        """Return a printable name for *handler*.

        Plain functions expose ``__name__``, but ``functools.partial`` objects and
        callable instances do not, so reading ``handler.__name__`` directly would
        raise ``AttributeError`` for them.
        """
        name = getattr(handler, "__name__", None)
        if name is None:
            # functools.partial keeps the wrapped callable in ``.func``.
            name = getattr(getattr(handler, "func", None), "__name__", None)
        return name or type(handler).__name__

    @staticmethod
    def _warn_on_cross_script_overwrite(existing: Optional[Dict[str, Any]], key: str,
                                        script_id: Optional[str]) -> None:
        """Log a warning when registering *key* replaces an entry owned by another script."""
        if existing is not None and existing.get("script_id") != script_id:
            logger.warning(
                f"注册项被其他脚本覆盖: {key} "
                f"(原 script_id: {existing.get('script_id')}, 新 script_id: {script_id})"
            )

    @staticmethod
    def _remove_owned_by(entries: Dict[str, Dict[str, Any]], script_id: Optional[str]) -> None:
        """Delete, in place, every entry of *entries* whose ``script_id`` equals *script_id*."""
        for key in [k for k, v in entries.items() if v.get("script_id") == script_id]:
            del entries[key]

    @staticmethod
    def _count_owned_by(entries: Dict[str, Dict[str, Any]], script_id: Optional[str]) -> int:
        """Count the entries of *entries* whose ``script_id`` equals *script_id*."""
        return sum(1 for info in entries.values() if info.get("script_id") == script_id)

    @staticmethod
    def _record_duration(info: Dict[str, Any], average_key: str, total_key: str,
                         duration_ms: float) -> None:
        """Add *duration_ms* to the running total stored in *info* and refresh the average.

        ``info["call_count"]`` must already include the current call. The average is
        computed from an exact running total rather than from the previous (already
        truncated) average, so truncation error does not accumulate across calls.
        Entries that lack *total_key* get it reconstructed from the stored average.
        The average is stored as an ``int`` (truncated), including on the first call.
        """
        count = info["call_count"]
        total = info.get(total_key)
        if total is None:
            total = (info.get(average_key) or 0) * (count - 1)
        total += duration_ms
        info[total_key] = total
        info[average_key] = int(total / count)

    # ------------------------------------------------------------- registration

    def register_api_route(self, path: str, method: str, handler: Callable,
                           script_id: Optional[str] = None, description: str = "",
                           params: Optional[Dict] = None, parameters: Optional[Dict] = None,
                           response_schema: Optional[Dict] = None):
        """Register an HTTP route handler, tagged with its owning script.

        Args:
            path: URL path of the route.
            method: HTTP method. It is combined with *path* into the key
                ``"{method}:{path}"`` verbatim (no case normalization), so lookups
                must use the same spelling.
            handler: Callable that serves the route.
            script_id: Owning script; falls back to ``self.current_script_id`` when falsy.
            description: Free-text description stored with the route.
            params: Simplified parameter spec ``{name: default_value}``. It is converted
                into an object schema when *parameters* is not given.
            parameters: Full parameter schema; takes precedence over *params*.
            response_schema: Response schema, stored as-is.
        """
        script_id = script_id or self.current_script_id
        route_key = f"{method}:{path}"

        # Accept the simplified params={"key": default_value} format.
        if params is not None and parameters is None:
            parameters = self._convert_simple_params_to_schema(params)
        elif parameters is None:
            parameters = {}

        self._warn_on_cross_script_overwrite(self.api_routes.get(route_key), route_key, script_id)

        self.api_routes[route_key] = {
            "path": path,
            "method": method,
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "params": params or {},  # original simplified params
            "parameters": parameters,  # full parameter schema
            "response_schema": response_schema or {},
            "registered_at": datetime.now().isoformat(),
            "call_count": 0,
            "last_called_at": None,
            "average_response_time_ms": None,
            "total_response_time_ms": 0,  # exact running total backing the average
        }

        logger.info(f"注册API路由: {method} {path} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_function(self, name: str, handler: Callable, script_id: Optional[str] = None,
                          description: str = "", params: Optional[Dict] = None,
                          parameters: Optional[List[Dict]] = None,
                          return_schema: Optional[Dict] = None, tags: Optional[List[str]] = None):
        """Register a custom function under *name*, tagged with its owning script.

        Args:
            name: Global function name (shared across scripts).
            handler: The callable. ``is_async`` is detected with
                ``asyncio.iscoroutinefunction``, which may report False for callable
                objects whose ``__call__`` is async.
            script_id: Owning script; falls back to ``self.current_script_id`` when falsy.
            description: Free-text description.
            params: Simplified parameter spec ``{name: default_value}``. It is converted
                into a list schema when *parameters* is not given.
            parameters: Full list-style parameter schema; takes precedence over *params*.
            return_schema: Return-value schema, stored as-is.
            tags: Optional list of tags.
        """
        script_id = script_id or self.current_script_id

        is_async = asyncio.iscoroutinefunction(handler)

        # Accept the simplified params={"key": default_value} format.
        if params is not None and parameters is None:
            parameters = self._convert_simple_params_to_list_schema(params)
        elif parameters is None:
            parameters = []

        self._warn_on_cross_script_overwrite(self.registered_functions.get(name), name, script_id)

        self.registered_functions[name] = {
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "params": params or {},  # original simplified params
            "parameters": parameters,  # full parameter schema
            "return_schema": return_schema or {},
            "is_async": is_async,
            "tags": tags or [],
            "registered_at": datetime.now().isoformat(),
            "call_count": 0,
            "success_count": 0,
            "error_count": 0,
            "last_called_at": None,
            "average_execution_time_ms": None,
            "total_execution_time_ms": 0,  # exact running total backing the average
        }

        logger.info(f"注册函数: {name} -> {self._handler_name(handler)} (script_id: {script_id}, async: {is_async})")

    def register_event_listener(self, event_name: str, handler: Callable,
                                script_id: Optional[str] = None, priority: int = 1):
        """Add *handler* as a listener for *event_name*, tagged with its owning script.

        Listeners are kept sorted by ascending *priority* (lower values run first).
        Because the sort is stable, listeners with equal priority keep their
        registration order.
        """
        script_id = script_id or self.current_script_id

        listener_info = {
            "handler": handler,
            "script_id": script_id,
            "priority": priority,
            "registered_at": datetime.now().isoformat(),
        }

        listeners = self.event_listeners.setdefault(event_name, [])
        listeners.append(listener_info)
        listeners.sort(key=lambda item: item["priority"])

        logger.info(f"注册事件监听器: {event_name} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_timer(self, timer_id: str, interval: int, handler: Callable,
                       script_id: Optional[str] = None, repeat: bool = False, delay: int = 0):
        """Record a timer definition under *timer_id*.

        This class only stores the timer; nothing here schedules or runs it.
        ``next_run`` is set to ``time.time() + delay``, so *delay* is in seconds.
        The unit of *interval* is not established here (TODO confirm with the scheduler).
        """
        script_id = script_id or self.current_script_id

        self._warn_on_cross_script_overwrite(self.timers.get(timer_id), timer_id, script_id)

        self.timers[timer_id] = {
            "handler": handler,
            "script_id": script_id,
            "interval": interval,
            "repeat": repeat,
            "delay": delay,
            "next_run": time.time() + delay,
            "last_run": None,
            "run_count": 0,
            "registered_at": datetime.now().isoformat(),
        }

        logger.info(f"注册定时任务: {timer_id} -> {self._handler_name(handler)} (script_id: {script_id})")

    # ------------------------------------------------------------------ lookups

    def get_api_route(self, path: str, method: str = "GET") -> Optional[Dict[str, Any]]:
        """Return the route info registered for ``"{method}:{path}"``, or None."""
        return self.api_routes.get(f"{method}:{path}")

    def get_registered_function(self, name: str) -> Optional[Callable]:
        """Return the handler registered under *name*, or None."""
        func_info = self.registered_functions.get(name)
        return func_info["handler"] if func_info else None

    def get_function_info(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the full info dict registered under *name*, or None."""
        return self.registered_functions.get(name)

    # ------------------------------------------------------------------- events

    async def emit_event(self, event_name: str, data: Dict[str, Any]):
        """Call every listener of *event_name* with *data*, in priority order.

        Each handler is called. If it returns an awaitable (a coroutine function, or
        a callable object whose ``__call__`` is async), that awaitable is awaited. A
        listener that raises ``Exception`` is logged and does not stop the others.
        """
        import inspect  # only needed here

        # Iterate over a snapshot: a handler may register another listener for the
        # same event, which appends to and re-sorts the live list in place.
        listeners = list(self.event_listeners.get(event_name, ()))

        for listener in listeners:
            try:
                result = listener["handler"](data)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.error(f"事件处理器执行失败: {event_name} -> {e}")

    # ------------------------------------------------------------------ cleanup

    def clear_script_registrations(self, script_id: str):
        """Remove every route, function, listener and timer owned by *script_id*.

        Also drops the script from ``active_scripts`` and forgets events left with no
        listeners. NOTE(review): ``websocket_routes`` and ``tcp_servers`` are not
        filtered here (only ``clear_all`` empties them). Confirm whether another
        component releases those per script.
        """
        self._remove_owned_by(self.api_routes, script_id)
        self._remove_owned_by(self.registered_functions, script_id)

        for event_name in list(self.event_listeners):
            remaining = [listener for listener in self.event_listeners[event_name]
                         if listener.get("script_id") != script_id]
            if remaining:
                self.event_listeners[event_name] = remaining
            else:
                del self.event_listeners[event_name]

        self._remove_owned_by(self.timers, script_id)

        self.active_scripts.pop(script_id, None)

        logger.info(f"已清理脚本 {script_id} 的所有注册项")

    def clear_all(self):
        """Empty every registry table, including WebSocket routes and TCP servers."""
        self.api_routes.clear()
        self.websocket_routes.clear()
        self.tcp_servers.clear()
        self.event_listeners.clear()
        self.timers.clear()
        self.registered_functions.clear()
        self.active_scripts.clear()

        logger.info("已清理所有注册项")

    # ------------------------------------------------------------ script status

    def set_current_script(self, script_id: str):
        """Set the owner used by subsequent register_* calls that omit script_id."""
        self.current_script_id = script_id

    def add_active_script(self, script_id: str, script_path: str, file_id: Optional[int] = None):
        """Record *script_id* as running, with its path, optional file id and a fresh heartbeat."""
        self.active_scripts[script_id] = {
            "script_path": script_path,
            "file_id": file_id,
            "started_at": datetime.now().isoformat(),
            "status": "running",
            "heartbeat": time.time(),
        }

    def update_script_heartbeat(self, script_id: str):
        """Set the heartbeat of an active script to ``time.time()``; unknown ids are ignored."""
        if script_id in self.active_scripts:
            self.active_scripts[script_id]["heartbeat"] = time.time()

    # ------------------------------------------------------------------ reports

    def get_script_registrations(self, script_id: str) -> Dict[str, int]:
        """Return how many routes, functions, listeners and timers *script_id* owns."""
        event_count = sum(
            1
            for listeners in self.event_listeners.values()
            for listener in listeners
            if listener.get("script_id") == script_id
        )
        return {
            "apis": self._count_owned_by(self.api_routes, script_id),
            "functions": self._count_owned_by(self.registered_functions, script_id),
            "events": event_count,
            "timers": self._count_owned_by(self.timers, script_id),
        }

    def get_all_registrations(self) -> Dict[str, Any]:
        """Return a serializable snapshot of the registry.

        Handlers are replaced by their names. Callables without ``__name__`` are also
        supported. Active-script records are copied, so mutating the snapshot does not
        alter the registry.
        """
        name_of = self._handler_name
        return {
            "api_routes": {k: {**v, "handler": name_of(v["handler"])}
                           for k, v in self.api_routes.items()},
            "registered_functions": {k: {**v, "handler": name_of(v["handler"])}
                                     for k, v in self.registered_functions.items()},
            "event_listeners": {k: [{**listener, "handler": name_of(listener["handler"])}
                                    for listener in v]
                                for k, v in self.event_listeners.items()},
            "timers": {k: {**v, "handler": name_of(v["handler"])}
                       for k, v in self.timers.items()},
            "active_scripts": {k: dict(v) for k, v in self.active_scripts.items()},
        }

    # --------------------------------------------------------------- call stats

    def update_api_call_stats(self, route_key: str, response_time_ms: int):
        """Count one call of *route_key* and fold *response_time_ms* into its average.

        Unknown route keys are ignored.
        """
        route_info = self.api_routes.get(route_key)
        if route_info is None:
            return
        route_info["call_count"] += 1
        route_info["last_called_at"] = datetime.now().isoformat()
        self._record_duration(route_info, "average_response_time_ms",
                              "total_response_time_ms", response_time_ms)

    def update_function_call_stats(self, function_name: str, execution_time_ms: int, success: bool = True):
        """Count one call of *function_name* (success or error) and update its average time.

        Unknown function names are ignored.
        """
        func_info = self.registered_functions.get(function_name)
        if func_info is None:
            return
        func_info["call_count"] += 1
        func_info["last_called_at"] = datetime.now().isoformat()
        func_info["success_count" if success else "error_count"] += 1
        self._record_duration(func_info, "average_execution_time_ms",
                              "total_execution_time_ms", execution_time_ms)

    # ---------------------------------------------------------- param schemas

    def _convert_simple_params_to_schema(self, simple_params: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the simplified parameter format into an object schema (used for API routes).

        Input:  ``{"a": 0, "b": 0, "operation": "add"}``
        Output: ``{"type": "object", "properties": {...}, "required": [...]}``.
        A parameter is listed as required only when its default value is None.
        """
        properties = {}
        required = []

        for param_name, default_value in simple_params.items():
            properties[param_name] = {
                "type": self._infer_parameter_type(default_value),
                "default": default_value,
            }
            if default_value is None:
                required.append(param_name)

        return {
            "type": "object",
            "properties": properties,
            "required": required,
        }

    def _convert_simple_params_to_list_schema(self, simple_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert the simplified parameter format into a list schema (used for functions).

        Input:  ``{"a": 0, "b": 0, "operation": "add"}``
        Output: ``[{"name": "a", "type": "integer", "default": 0, "required": False}, ...]``
        A parameter is marked required only when its default value is None.
        """
        return [
            {
                "name": param_name,
                "type": self._infer_parameter_type(default_value),
                "default": default_value,
                "required": default_value is None,
            }
            for param_name, default_value in simple_params.items()
        ]

    def _infer_parameter_type(self, value: Any) -> str:
        """Infer a schema type name from a default value.

        ``bool`` is checked before ``int`` because bool is an int subclass. Lists and
        tuples map to "array". Unrecognized types fall back to "string".
        """
        if value is None:
            return "any"
        elif isinstance(value, bool):
            return "boolean"
        elif isinstance(value, int):
            return "integer"
        elif isinstance(value, float):
            return "number"
        elif isinstance(value, str):
            return "string"
        elif isinstance(value, (list, tuple)):
            return "array"
        elif isinstance(value, dict):
            return "object"
        else:
            return "string"  # fallback: treat unknown types as strings
|
|||
|
|
|||
|
|
|||
|
# Process-wide registry shared by every script; it is built once, when this module is imported.
_global_registry = VWEDGlobalRegistry()


def get_global_registry() -> VWEDGlobalRegistry:
    """Hand out the single module-level VWEDGlobalRegistry instance.

    Every call returns the same object, so all callers observe the same registrations.
    """
    shared = _global_registry
    return shared
|