#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Script WebSocket service.

Provides real-time log push and script status monitoring over WebSocket.
"""

import json
import asyncio
from datetime import datetime
from typing import Dict, Any, List, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect

from utils.logger import get_logger

logger = get_logger("services.script_websocket")


class ScriptWebSocketManager:
    """Manage script WebSocket connections and per-script subscriptions.

    Three structures are kept in sync:

    * ``active_connections``: ``connection_id -> info`` (the ``WebSocket``,
      client metadata and the set of subscribed script ids).
    * ``script_subscriptions``: ``script_id -> {connection_id, ...}``.
    * ``connection_subscriptions``: ``connection_id -> {script_id, ...}``.

    Broadcasts are not sent inline. They are put on ``broadcast_queue`` and
    delivered by a background task that ``connect`` (re)starts on demand.
    """

    def __init__(self):
        # Active WebSocket connections: {connection_id: connection_info}
        self.active_connections: Dict[str, Dict[str, Any]] = {}
        # Script subscriptions: {script_id: set(connection_ids)}
        self.script_subscriptions: Dict[str, Set[str]] = {}
        # Connection subscriptions: {connection_id: set(script_ids)}
        self.connection_subscriptions: Dict[str, Set[str]] = {}
        # Pending broadcast jobs, consumed by _broadcast_worker.
        # NOTE(review): before Python 3.10 asyncio.Queue binds to the loop
        # current at construction. The module-level instance below is built
        # at import time, so confirm the runtime is 3.10+.
        self.broadcast_queue = asyncio.Queue()
        # Background delivery task. It is created lazily in connect() because
        # asyncio.create_task needs a running event loop.
        self._broadcast_task = None

    async def connect(self, websocket: WebSocket, client_id: str, client_type: str = "web") -> str:
        """Accept a WebSocket and register it as a new connection.

        Also ensures the broadcast worker is running and sends a ``welcome``
        message to the client.

        Args:
            websocket: The incoming, not yet accepted, WebSocket.
            client_id: Caller-supplied client identifier.
            client_type: Client category, used as the connection id prefix.

        Returns:
            The generated id ``"{client_type}_{client_id}_{epoch_ms}"``.

        Raises:
            Any exception from accepting or registering is logged and re-raised.
        """
        try:
            await websocket.accept()

            connection_id = f"{client_type}_{client_id}_{int(datetime.now().timestamp() * 1000)}"

            self.active_connections[connection_id] = {
                "websocket": websocket,
                "client_id": client_id,
                "client_type": client_type,
                "connected_at": datetime.now(),
                "last_ping": datetime.now(),
                "subscribed_scripts": set(),
            }
            self.connection_subscriptions[connection_id] = set()

            # The worker exits once it is idle with no connections, so
            # restart it whenever a new connection arrives.
            if self._broadcast_task is None or self._broadcast_task.done():
                self._broadcast_task = asyncio.create_task(self._broadcast_worker())

            logger.info(f"WebSocket连接建立: {connection_id} ({client_type})")

            await self.send_to_connection(connection_id, {
                "type": "welcome",
                "connection_id": connection_id,
                "server_time": datetime.now().isoformat(),
                "message": "连接建立成功",
            })

            return connection_id

        except Exception as e:
            logger.error(f"WebSocket连接建立失败: {e}", exc_info=True)
            raise

    async def disconnect(self, connection_id: str):
        """Forget a connection and silently drop all of its subscriptions.

        This method is idempotent: unknown or already-removed ids are ignored.
        No message is sent to the departing client. Its socket is normally
        already closed, and a failed send would re-enter ``disconnect``
        through ``send_to_connection``.
        """
        try:
            if self.active_connections.pop(connection_id, None) is None:
                return

            for script_id in self.connection_subscriptions.pop(connection_id, set()):
                self._remove_script_subscriber(script_id, connection_id)

            logger.info(f"WebSocket连接断开: {connection_id}")

        except Exception as e:
            logger.error(f"WebSocket断开连接处理失败: {e}", exc_info=True)

    def _remove_script_subscriber(self, script_id: str, connection_id: str) -> None:
        """Remove a connection from one script's subscriber set; drop the set when empty."""
        subscribers = self.script_subscriptions.get(script_id)
        if subscribers is None:
            return
        subscribers.discard(connection_id)
        if not subscribers:
            del self.script_subscriptions[script_id]

    def _drop_subscription(self, connection_id: str, script_id: str) -> None:
        """Remove one (connection, script) pair from all bookkeeping without notifying anyone."""
        self._remove_script_subscriber(script_id, connection_id)

        subscribed = self.connection_subscriptions.get(connection_id)
        if subscribed is not None:
            subscribed.discard(script_id)

        connection_info = self.active_connections.get(connection_id)
        if connection_info is not None:
            connection_info["subscribed_scripts"].discard(script_id)

    async def subscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Subscribe an active connection to a script's log/status stream.

        Returns:
            True on success. False if the connection is unknown or an error
            occurred (the error is logged).
        """
        try:
            if connection_id not in self.active_connections:
                return False

            self.script_subscriptions.setdefault(script_id, set()).add(connection_id)
            self.connection_subscriptions[connection_id].add(script_id)
            self.active_connections[connection_id]["subscribed_scripts"].add(script_id)

            logger.info(f"订阅脚本: {connection_id} -> {script_id}")

            await self.send_to_connection(connection_id, {
                "type": "subscription_success",
                "script_id": script_id,
                "message": f"成功订阅脚本 {script_id} 的日志",
            })

            return True

        except Exception as e:
            logger.error(f"订阅脚本失败: {e}", exc_info=True)
            return False

    async def unsubscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Unsubscribe a connection from a script and send it a confirmation.

        Returns:
            True unless an unexpected error occurred (the error is logged).
        """
        try:
            self._drop_subscription(connection_id, script_id)

            logger.info(f"取消订阅脚本: {connection_id} -> {script_id}")

            await self.send_to_connection(connection_id, {
                "type": "unsubscription_success",
                "script_id": script_id,
                "message": f"已取消订阅脚本 {script_id} 的日志",
            })

            return True

        except Exception as e:
            logger.error(f"取消订阅脚本失败: {e}", exc_info=True)
            return False

    async def send_to_connection(self, connection_id: str, message: Dict[str, Any]):
        """Serialize ``message`` as JSON and send it to one connection.

        Non-JSON values are stringified via ``default=str``. On
        ``WebSocketDisconnect`` the connection is removed via ``disconnect``.

        Returns:
            True if the text was sent. False if the connection is unknown or
            the send failed.
        """
        try:
            connection_info = self.active_connections.get(connection_id)
            if connection_info is None:
                return False

            message_text = json.dumps(message, ensure_ascii=False, default=str)
            await connection_info["websocket"].send_text(message_text)
            return True

        except WebSocketDisconnect:
            logger.warning(f"连接已断开: {connection_id}")
            await self.disconnect(connection_id)
            return False
        except Exception as e:
            logger.error(f"发送消息失败: {connection_id} -> {e}", exc_info=True)
            return False

    async def broadcast_to_script_subscribers(self, script_id: str, message: Dict[str, Any]):
        """Queue ``message`` for delivery to every current subscriber of ``script_id``.

        The payload sent is a copy of ``message`` with ``"script_id"`` set. The
        caller's dict is not modified. Subscribers are snapshotted now, so
        later (un)subscriptions do not affect this broadcast.
        """
        try:
            subscribers = self.script_subscriptions.get(script_id)
            if not subscribers:
                return

            await self.broadcast_queue.put({
                "type": "script_broadcast",
                "script_id": script_id,
                "subscribers": set(subscribers),
                "message": {**message, "script_id": script_id},
            })

        except Exception as e:
            logger.error(f"广播消息失败: {script_id} -> {e}", exc_info=True)

    async def broadcast_script_log(self, script_id: str, log_level: str, log_message: str, **kwargs):
        """Broadcast a ``script_log`` event. Extra kwargs are merged into the payload."""
        message = {
            "type": "script_log",
            "timestamp": datetime.now().isoformat(),
            "level": log_level,
            "message": log_message,
            **kwargs,
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def broadcast_script_status(self, script_id: str, status: str, message: str = "", **kwargs):
        """Broadcast a ``script_status`` event. Extra kwargs are merged into the payload."""
        status_message = {
            "type": "script_status",
            "timestamp": datetime.now().isoformat(),
            "status": status,
            "message": message,
            **kwargs,
        }
        await self.broadcast_to_script_subscribers(script_id, status_message)

    async def broadcast_function_execution(self, script_id: str, function_name: str,
                                           execution_type: str, result: Dict[str, Any]):
        """Broadcast a ``function_execution`` event carrying ``result``."""
        message = {
            "type": "function_execution",
            "timestamp": datetime.now().isoformat(),
            "function_name": function_name,
            "execution_type": execution_type,
            "result": result,
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def _broadcast_worker(self):
        """Drain ``broadcast_queue`` until idle for 1s with no active connections.

        Per-job errors are logged and the loop continues. ``connect``
        restarts the worker after it exits.
        """
        try:
            while True:
                try:
                    broadcast_task = await asyncio.wait_for(
                        self.broadcast_queue.get(),
                        timeout=1.0,
                    )

                    if broadcast_task["type"] == "script_broadcast":
                        await self._process_script_broadcast(broadcast_task)

                except asyncio.TimeoutError:
                    # Idle: stop when nobody is connected.
                    if not self.active_connections:
                        break
                except Exception as e:
                    logger.error(f"广播工作线程错误: {e}", exc_info=True)

        except Exception as e:
            logger.error(f"广播工作线程异常退出: {e}", exc_info=True)

    async def _process_script_broadcast(self, broadcast_task: Dict[str, Any]):
        """Deliver one queued broadcast to its snapshot of subscribers.

        Subscribers whose send fails are removed from the script's
        subscription silently. They are not sent an unsubscribe notice,
        because the connection just failed.
        """
        script_id = broadcast_task["script_id"]
        message = broadcast_task["message"]

        failed_connections = []
        for connection_id in broadcast_task["subscribers"]:
            if not await self.send_to_connection(connection_id, message):
                failed_connections.append(connection_id)

        for connection_id in failed_connections:
            self._drop_subscription(connection_id, script_id)

    async def get_connection_status(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of connections and per-script subscriber counts."""
        return {
            "total_connections": len(self.active_connections),
            "total_script_subscriptions": len(self.script_subscriptions),
            "connections": [
                {
                    "connection_id": conn_id,
                    "client_id": conn_info["client_id"],
                    "client_type": conn_info["client_type"],
                    "connected_at": conn_info["connected_at"].isoformat(),
                    "subscribed_scripts": list(conn_info["subscribed_scripts"]),
                }
                for conn_id, conn_info in self.active_connections.items()
            ],
            "script_subscriptions": {
                script_id: len(subscribers)
                for script_id, subscribers in self.script_subscriptions.items()
            },
        }

    async def ping_connections(self):
        """Send a ``ping`` message to every connection and disconnect those that fail."""
        ping_message = {
            "type": "ping",
            "timestamp": datetime.now().isoformat(),
        }

        # Iterate over a snapshot: a failed send may call disconnect(), which
        # removes entries from active_connections while we loop.
        failed_connections = []
        for connection_id in list(self.active_connections):
            if not await self.send_to_connection(connection_id, ping_message):
                failed_connections.append(connection_id)

        for connection_id in failed_connections:
            await self.disconnect(connection_id)


# Global WebSocket manager instance, shared by the whole process.
_websocket_manager = ScriptWebSocketManager()


def get_websocket_manager() -> ScriptWebSocketManager:
    """Return the process-wide ScriptWebSocketManager instance."""
    return _websocket_manager