VWED_server/services/online_script/script_registry_service.py

430 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
脚本注册中心服务
管理脚本实例的全局注册信息,支持多脚本并发运行
"""
import asyncio
import time
from datetime import datetime
from typing import Dict, Any, Optional, List, Callable
from utils.logger import get_logger
# Module-level logger used by the registry class below for all registration
# and cleanup messages.
logger = get_logger("services.script_registry")
class VWEDGlobalRegistry:
    """In-memory registry of everything online scripts register.

    Every entry (API route, function, event listener, timer, device handler)
    is tagged with the ``script_id`` that registered it, so one script's
    registrations can be counted or removed without touching another's.
    """

    def __init__(self):
        # HTTP routes, keyed by "<method>:<path>".
        self.api_routes: Dict[str, Dict[str, Any]] = {}
        # WebSocket routes. No method of this class writes to this table;
        # it is only emptied by clear_all().
        self.websocket_routes: Dict[str, Dict[str, Any]] = {}
        # TCP servers. Same as above: only emptied by clear_all().
        self.tcp_servers: Dict[str, Dict[str, Any]] = {}
        # Event listeners: event name -> listener records sorted by priority.
        self.event_listeners: Dict[str, List[Dict[str, Any]]] = {}
        # Timers, keyed by timer id.
        self.timers: Dict[str, Dict[str, Any]] = {}
        # Custom functions, keyed by function name.
        self.registered_functions: Dict[str, Dict[str, Any]] = {}
        # Device handlers, keyed by device id.
        self.device_handlers: Dict[str, Dict[str, Any]] = {}
        # Scripts currently recorded as running, keyed by script id.
        self.active_scripts: Dict[str, Dict[str, Any]] = {}
        # Fallback script id used when a register_* call omits script_id.
        self.current_script_id: Optional[str] = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _handler_name(handler: Any) -> str:
        """Return a printable name for ``handler`` without ever raising.

        Plain functions expose ``__name__``. ``functools.partial`` objects
        and instances that define ``__call__`` do not, so fall back to the
        wrapped callable's name (``partial.func``) and finally to the type
        name.
        """
        name = getattr(handler, "__name__", None)
        if name is None:
            name = getattr(getattr(handler, "func", None), "__name__", None)
        if name is None:
            name = type(handler).__name__
        return name

    @staticmethod
    def _count_for_script(entries, script_id: str) -> int:
        """Count the records in ``entries`` owned by ``script_id``."""
        return sum(1 for entry in entries if entry.get("script_id") == script_id)

    @staticmethod
    def _drop_script_entries(table: Dict[str, Dict[str, Any]], script_id: str) -> None:
        """Delete, in place, every record in ``table`` owned by ``script_id``."""
        # Collect the keys first: a dict cannot shrink while being iterated.
        doomed = [key for key, entry in table.items()
                  if entry.get("script_id") == script_id]
        for key in doomed:
            del table[key]

    @staticmethod
    def _running_average(previous: Optional[int], count: int, sample: int) -> int:
        """Fold ``sample`` into a running mean over ``count`` samples.

        ``count`` already includes the new sample. The first sample is
        stored as given; later results are truncated with ``int()``.
        NOTE(review): truncating on every update can let the stored mean
        drift below the true mean; kept as-is to preserve existing output.
        """
        if previous is None:
            return sample
        return int((previous * (count - 1) + sample) / count)

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def register_api_route(self, path: str, method: str, handler: Callable,
                           script_id: Optional[str] = None, description: str = "",
                           params: Optional[Dict] = None,
                           parameters: Optional[Dict] = None,
                           response_schema: Optional[Dict] = None):
        """Register an HTTP route under the key ``"<method>:<path>"``.

        Args:
            path: Route path.
            method: HTTP method; used verbatim in the route key.
            handler: Callable stored for this route.
            script_id: Owning script; defaults to ``current_script_id``.
            description: Free-form description.
            params: Simplified ``{name: default}`` parameter definition.
            parameters: Full parameter schema. When omitted it is derived
                from ``params`` if given, otherwise it is ``{}``.
            response_schema: Optional response schema (stored as ``{}`` when
                falsy).

        An existing route with the same key is overwritten.
        """
        script_id = script_id or self.current_script_id
        route_key = f"{method}:{path}"

        if parameters is None:
            # Expand the simplified {name: default} form when it was supplied.
            parameters = ({} if params is None
                          else self._convert_simple_params_to_schema(params))

        self.api_routes[route_key] = {
            "path": path,
            "method": method,
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "params": params or {},  # original simplified definition
            "parameters": parameters,  # full parameter schema
            "response_schema": response_schema or {},
            "registered_at": datetime.now().isoformat(),
            "call_count": 0,
            "last_called_at": None,
            "average_response_time_ms": None
        }
        logger.info(f"注册API路由: {method} {path} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_function(self, name: str, handler: Callable,
                          script_id: Optional[str] = None, description: str = "",
                          params: Optional[Dict] = None,
                          parameters: Optional[List[Dict]] = None,
                          return_schema: Optional[Dict] = None,
                          tags: Optional[List[str]] = None):
        """Register a custom function under ``name``.

        Args:
            name: Registry key; an existing entry with this name is replaced.
            handler: Callable to store. Whether it is a coroutine function is
                recorded in the ``is_async`` field.
            script_id: Owning script; defaults to ``current_script_id``.
            description: Free-form description.
            params: Simplified ``{name: default}`` parameter definition.
            parameters: Full list-style parameter schema. When omitted it is
                derived from ``params`` if given, otherwise it is ``[]``.
            return_schema: Optional return schema (stored as ``{}`` when falsy).
            tags: Optional tags (stored as ``[]`` when falsy).
        """
        script_id = script_id or self.current_script_id
        is_async = asyncio.iscoroutinefunction(handler)

        if parameters is None:
            parameters = ([] if params is None
                          else self._convert_simple_params_to_list_schema(params))

        self.registered_functions[name] = {
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "params": params or {},  # original simplified definition
            "parameters": parameters,  # full parameter schema
            "return_schema": return_schema or {},
            "is_async": is_async,
            "tags": tags or [],
            "registered_at": datetime.now().isoformat(),
            "call_count": 0,
            "success_count": 0,
            "error_count": 0,
            "last_called_at": None,
            "average_execution_time_ms": None
        }
        logger.info(f"注册函数: {name} -> {self._handler_name(handler)} (script_id: {script_id}, async: {is_async})")

    def register_event_listener(self, event_name: str, handler: Callable,
                                script_id: Optional[str] = None, priority: int = 1):
        """Add a listener for ``event_name``.

        Listeners are kept sorted by ascending ``priority``. The sort is
        stable, so listeners with equal priority stay in registration order.
        """
        script_id = script_id or self.current_script_id
        listeners = self.event_listeners.setdefault(event_name, [])
        listeners.append({
            "handler": handler,
            "script_id": script_id,
            "priority": priority,
            "registered_at": datetime.now().isoformat()
        })
        listeners.sort(key=lambda record: record["priority"])
        logger.info(f"注册事件监听器: {event_name} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_timer(self, timer_id: str, interval: int, handler: Callable,
                       script_id: Optional[str] = None, repeat: bool = False,
                       delay: int = 0):
        """Record a timer definition under ``timer_id``.

        Only bookkeeping happens here: ``next_run`` is initialised to
        ``time.time() + delay`` and ``interval`` / ``repeat`` are stored as
        given. Nothing in this class executes the timer.
        """
        script_id = script_id or self.current_script_id
        self.timers[timer_id] = {
            "handler": handler,
            "script_id": script_id,
            "interval": interval,
            "repeat": repeat,
            "delay": delay,
            "next_run": time.time() + delay,
            "last_run": None,
            "run_count": 0,
            "registered_at": datetime.now().isoformat()
        }
        logger.info(f"注册定时任务: {timer_id} -> {self._handler_name(handler)} (script_id: {script_id})")

    def register_device_handler(self, device_id: str, device_type: str,
                                listen_topics: List[str], forward_topics: List[str],
                                handler: Callable, script_id: Optional[str] = None,
                                description: str = "",
                                device_brand: Optional[str] = None,
                                protocol_key: Optional[str] = None,
                                auto_encode: bool = True, **kwargs):
        """Register a device handler under ``device_id``.

        Args:
            device_id: Registry key; an existing entry is replaced.
            device_type: Device type label.
            listen_topics: Topics stored as the handler's inputs.
            forward_topics: Topics stored as the handler's outputs.
            handler: Callable to store; coroutine-ness is recorded in
                ``is_async``.
            script_id: Owning script; defaults to ``current_script_id``.
            description: Free-form description.
            device_brand: Optional brand label.
            protocol_key: Optional protocol identifier.
            auto_encode: Stored as given.
            **kwargs: Extra fields merged into the record last, so they
                override any built-in field of the same name.
        """
        script_id = script_id or self.current_script_id
        is_async = asyncio.iscoroutinefunction(handler)

        self.device_handlers[device_id] = {
            "device_id": device_id,
            "device_type": device_type,
            "listen_topics": listen_topics,
            "forward_topics": forward_topics,
            "handler": handler,
            "script_id": script_id,
            "description": description,
            "device_brand": device_brand,
            "protocol_key": protocol_key,
            "auto_encode": auto_encode,
            "is_async": is_async,
            "registered_at": datetime.now().isoformat(),
            "status": "running",
            "message_count": 0,
            "success_count": 0,
            "error_count": 0,
            "last_processed": None,
            "average_processing_time_ms": None,
            **kwargs
        }

        brand_info = f" 品牌: {device_brand}" if device_brand else ""
        protocol_info = f" 协议: {protocol_key}" if protocol_key else ""
        logger.info(f"注册设备处理器: {device_id} ({device_type}{brand_info}{protocol_info}) -> {self._handler_name(handler)} (script_id: {script_id}, async: {is_async})")
        logger.info(f"监听topics: {listen_topics}, 转发topics: {forward_topics}")

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def get_device_handler(self, device_id: str) -> Optional[Dict[str, Any]]:
        """Return the record for ``device_id``, or ``None`` if unknown."""
        return self.device_handlers.get(device_id)

    def get_api_route(self, path: str, method: str = "GET") -> Optional[Dict[str, Any]]:
        """Return the route record for ``method`` + ``path``, or ``None``."""
        return self.api_routes.get(f"{method}:{path}")

    def get_registered_function(self, name: str) -> Optional[Callable]:
        """Return the handler registered as ``name``, or ``None``."""
        func_info = self.registered_functions.get(name)
        return func_info["handler"] if func_info else None

    def get_function_info(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the full record registered as ``name``, or ``None``."""
        return self.registered_functions.get(name)

    def get_all_device_handlers(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of the device handler table."""
        return self.device_handlers.copy()

    # ------------------------------------------------------------------
    # Events
    # ------------------------------------------------------------------

    async def emit_event(self, event_name: str, data: Dict[str, Any]):
        """Call every listener of ``event_name`` with ``data``, in order.

        Coroutine functions are awaited; other callables are called
        directly. An exception raised by one listener is logged and does
        not stop the remaining listeners.
        """
        # Iterate over a snapshot: a handler may register another listener
        # for this event, which appends to and re-sorts the live list.
        for listener in list(self.event_listeners.get(event_name, [])):
            try:
                handler = listener["handler"]
                if asyncio.iscoroutinefunction(handler):
                    await handler(data)
                else:
                    handler(data)
            except Exception as e:
                logger.error(f"事件处理器执行失败: {event_name} -> {e}")

    # ------------------------------------------------------------------
    # Cleanup and script bookkeeping
    # ------------------------------------------------------------------

    def clear_script_registrations(self, script_id: str):
        """Remove everything registered by ``script_id``.

        Covers API routes, functions, event listeners, timers, device
        handlers and the script's ``active_scripts`` record.
        NOTE(review): ``websocket_routes`` and ``tcp_servers`` are not
        touched here; confirm whether they are cleaned up elsewhere.
        """
        self._drop_script_entries(self.api_routes, script_id)
        self._drop_script_entries(self.registered_functions, script_id)

        # Event names are kept even when their listener list becomes empty.
        for event_name, listeners in self.event_listeners.items():
            self.event_listeners[event_name] = [
                listener for listener in listeners
                if listener.get("script_id") != script_id
            ]

        self._drop_script_entries(self.timers, script_id)
        self._drop_script_entries(self.device_handlers, script_id)
        self.active_scripts.pop(script_id, None)
        logger.info(f"已清理脚本 {script_id} 的所有注册项")

    def clear_all(self):
        """Empty every registry table; ``current_script_id`` is left as is."""
        for table in (self.api_routes, self.websocket_routes, self.tcp_servers,
                      self.event_listeners, self.timers,
                      self.registered_functions, self.device_handlers,
                      self.active_scripts):
            table.clear()
        logger.info("已清理所有注册项")

    def set_current_script(self, script_id: str):
        """Set the script id used when a register_* call omits ``script_id``."""
        self.current_script_id = script_id

    def add_active_script(self, script_id: str, script_path: str,
                          file_id: Optional[int] = None):
        """Record ``script_id`` as running, with an initial heartbeat of now."""
        self.active_scripts[script_id] = {
            "script_path": script_path,
            "file_id": file_id,
            "started_at": datetime.now().isoformat(),
            "status": "running",
            "heartbeat": time.time()
        }

    def update_script_heartbeat(self, script_id: str):
        """Refresh the heartbeat of ``script_id``; unknown ids are ignored."""
        record = self.active_scripts.get(script_id)
        if record is not None:
            record["heartbeat"] = time.time()

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def get_script_registrations(self, script_id: str) -> Dict[str, int]:
        """Return how many entries of each kind ``script_id`` has registered."""
        return {
            "apis": self._count_for_script(self.api_routes.values(), script_id),
            "functions": self._count_for_script(
                self.registered_functions.values(), script_id),
            "events": sum(self._count_for_script(listeners, script_id)
                          for listeners in self.event_listeners.values()),
            "timers": self._count_for_script(self.timers.values(), script_id),
            "devices": self._count_for_script(
                self.device_handlers.values(), script_id)
        }

    def get_all_registrations(self) -> Dict[str, Any]:
        """Return a snapshot of all tables with handlers replaced by names.

        Each record is shallow-copied and its ``handler`` callable is
        replaced by a printable name. ``active_scripts`` is returned as the
        live dict, not a copy.
        """
        def named(record: Dict[str, Any]) -> Dict[str, Any]:
            return {**record, "handler": self._handler_name(record["handler"])}

        return {
            "api_routes": {k: named(v) for k, v in self.api_routes.items()},
            "registered_functions": {k: named(v)
                                     for k, v in self.registered_functions.items()},
            "event_listeners": {k: [named(l) for l in v]
                                for k, v in self.event_listeners.items()},
            "timers": {k: named(v) for k, v in self.timers.items()},
            "device_handlers": {k: named(v)
                                for k, v in self.device_handlers.items()},
            "active_scripts": self.active_scripts
        }

    # ------------------------------------------------------------------
    # Call statistics
    # ------------------------------------------------------------------

    def update_api_call_stats(self, route_key: str, response_time_ms: int):
        """Record one call of ``route_key``; unknown keys are ignored."""
        route_info = self.api_routes.get(route_key)
        if route_info is None:
            return
        route_info["call_count"] += 1
        route_info["last_called_at"] = datetime.now().isoformat()
        route_info["average_response_time_ms"] = self._running_average(
            route_info["average_response_time_ms"],
            route_info["call_count"],
            response_time_ms,
        )

    def update_function_call_stats(self, function_name: str,
                                   execution_time_ms: int, success: bool = True):
        """Record one call of ``function_name``; unknown names are ignored.

        ``success`` selects whether ``success_count`` or ``error_count`` is
        incremented.
        """
        func_info = self.registered_functions.get(function_name)
        if func_info is None:
            return
        func_info["call_count"] += 1
        func_info["last_called_at"] = datetime.now().isoformat()
        func_info["success_count" if success else "error_count"] += 1
        func_info["average_execution_time_ms"] = self._running_average(
            func_info["average_execution_time_ms"],
            func_info["call_count"],
            execution_time_ms,
        )

    # ------------------------------------------------------------------
    # Simplified-parameter conversion
    # ------------------------------------------------------------------

    def _convert_simple_params_to_schema(self, simple_params: Dict[str, Any]) -> Dict[str, Any]:
        """Expand ``{name: default}`` into an object-style schema (API routes).

        Example input: ``{"a": 0, "b": 0, "operation": "add"}``.
        A parameter whose default is ``None`` is listed as required.
        """
        properties = {
            param_name: {
                "type": self._infer_parameter_type(default_value),
                "default": default_value
            }
            for param_name, default_value in simple_params.items()
        }
        required = [param_name
                    for param_name, default_value in simple_params.items()
                    if default_value is None]
        return {
            "type": "object",
            "properties": properties,
            "required": required
        }

    def _convert_simple_params_to_list_schema(self, simple_params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Expand ``{name: default}`` into a list-style schema (functions).

        Example output item: ``{"name": "a", "type": "integer",
        "default": 0, "required": False}``. ``required`` is true exactly
        when the default is ``None``.
        """
        return [
            {
                "name": param_name,
                "type": self._infer_parameter_type(default_value),
                "default": default_value,
                "required": default_value is None
            }
            for param_name, default_value in simple_params.items()
        ]

    def _infer_parameter_type(self, value: Any) -> str:
        """Map a default value to a schema type name.

        ``None`` maps to ``"any"``; unrecognised types fall back to
        ``"string"``.
        """
        if value is None:
            return "any"
        # bool must be tested before int because bool is a subclass of int.
        type_names = (
            (bool, "boolean"),
            (int, "integer"),
            (float, "number"),
            (str, "string"),
            (list, "array"),
            (dict, "object"),
        )
        for python_type, schema_name in type_names:
            if isinstance(value, python_type):
                return schema_name
        return "string"
# Global registry instance, created once when this module is imported and
# handed out by get_global_registry().
_global_registry = VWEDGlobalRegistry()
def get_global_registry() -> VWEDGlobalRegistry:
    """Return the module-level ``VWEDGlobalRegistry`` instance.

    Every call returns the same object, so registrations made through it
    are visible to all callers.
    """
    return _global_registry