VWED_server/services/script_websocket_service.py

327 lines
12 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
脚本WebSocket服务
提供实时日志推送和脚本状态监控
"""
import json
import asyncio
from datetime import datetime
from typing import Dict, Any, List, Optional, Set
from fastapi import WebSocket, WebSocketDisconnect
from utils.logger import get_logger
logger = get_logger("services.script_websocket")
class ScriptWebSocketManager:
    """Track script-log WebSocket connections and fan out messages to them.

    State is kept in three mappings that are updated together:

    * ``active_connections`` -- ``{connection_id: connection_info}``; the info
      dict holds the socket, client id/type, timestamps and the set of
      subscribed script ids.
    * ``script_subscriptions`` -- ``{script_id: {connection_id, ...}}``.
    * ``connection_subscriptions`` -- ``{connection_id: {script_id, ...}}``.

    Broadcasts are enqueued on ``broadcast_queue`` and delivered by a
    background task (``_broadcast_worker``). ``connect`` starts that task on
    demand, and the task stops once no connections remain.
    """

    def __init__(self):
        # Active WebSocket connections: {connection_id: connection_info}
        self.active_connections: Dict[str, Dict[str, Any]] = {}
        # Script subscriptions: {script_id: set(connection_ids)}
        self.script_subscriptions: Dict[str, Set[str]] = {}
        # Per-connection subscriptions: {connection_id: set(script_ids)}
        self.connection_subscriptions: Dict[str, Set[str]] = {}
        # Broadcast queue, created lazily by the ``broadcast_queue`` property.
        # The manager is instantiated at import time, before any event loop
        # runs; on Python < 3.10 an asyncio.Queue binds to the loop that is
        # current at construction, which need not be the server's loop.
        self._broadcast_queue: Optional[asyncio.Queue] = None
        # Background delivery task; not started here -- ``connect`` starts it.
        self._broadcast_task: Optional[asyncio.Task] = None

    @property
    def broadcast_queue(self) -> asyncio.Queue:
        """Queue of pending broadcast jobs, created on first access."""
        if self._broadcast_queue is None:
            self._broadcast_queue = asyncio.Queue()
        return self._broadcast_queue

    @broadcast_queue.setter
    def broadcast_queue(self, queue: asyncio.Queue) -> None:
        # Kept so code that assigned the former plain attribute still works.
        self._broadcast_queue = queue

    async def connect(self, websocket: WebSocket, client_id: str,
                      client_type: str = "web") -> str:
        """Accept ``websocket`` and register it as a new connection.

        Args:
            websocket: The incoming WebSocket to accept.
            client_id: Client identifier; embedded in the connection id.
            client_type: Client category; embedded in the connection id and
                stored in the connection info. Defaults to ``"web"``.

        Returns:
            The connection id, ``"{client_type}_{client_id}_{epoch_millis}"``,
            with a ``_N`` suffix appended if that id is already in use.

        Raises:
            Any exception raised while accepting or registering; it is logged
            and re-raised.
        """
        try:
            await websocket.accept()
            now = datetime.now()
            base_id = f"{client_type}_{client_id}_{int(now.timestamp() * 1000)}"
            # Two connections from the same client within one millisecond
            # would otherwise share an id, and the second registration would
            # silently overwrite the first.
            connection_id = base_id
            suffix = 1
            while connection_id in self.active_connections:
                connection_id = f"{base_id}_{suffix}"
                suffix += 1

            self.active_connections[connection_id] = {
                "websocket": websocket,
                "client_id": client_id,
                "client_type": client_type,
                "connected_at": now,
                "last_ping": now,
                "subscribed_scripts": set()
            }
            self.connection_subscriptions[connection_id] = set()

            # (Re)start the broadcast worker; it exits on its own when idle
            # with no connections.
            if self._broadcast_task is None or self._broadcast_task.done():
                self._broadcast_task = asyncio.create_task(self._broadcast_worker())

            logger.info(f"WebSocket连接建立: {connection_id} ({client_type})")

            # Welcome message
            await self.send_to_connection(connection_id, {
                "type": "welcome",
                "connection_id": connection_id,
                "server_time": datetime.now().isoformat(),
                "message": "连接建立成功"
            })
            return connection_id
        except Exception as e:
            logger.error(f"WebSocket连接建立失败: {e}", exc_info=True)
            raise

    async def disconnect(self, connection_id: str) -> None:
        """Unregister ``connection_id`` and drop all of its subscriptions.

        Unknown or already-removed ids are a no-op, so repeated calls are
        harmless. No messages are sent to the connection being torn down:
        sending could fail and re-enter this method via
        ``send_to_connection``.
        """
        try:
            # Remove the connection first so any concurrent send to it is a
            # no-op and re-entrant calls return immediately.
            connection_info = self.active_connections.pop(connection_id, None)
            if connection_info is None:
                return
            script_ids = (self.connection_subscriptions.pop(connection_id, set())
                          | connection_info["subscribed_scripts"])
            for script_id in script_ids:
                self._discard_script_subscriber(script_id, connection_id)
            logger.info(f"WebSocket连接断开: {connection_id}")
        except Exception as e:
            logger.error(f"WebSocket断开连接处理失败: {e}", exc_info=True)

    def _discard_script_subscriber(self, script_id: str, connection_id: str) -> None:
        """Remove ``connection_id`` from ``script_id``'s subscribers; prune empty sets."""
        subscribers = self.script_subscriptions.get(script_id)
        if subscribers is None:
            return
        subscribers.discard(connection_id)
        if not subscribers:
            del self.script_subscriptions[script_id]

    async def subscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Subscribe ``connection_id`` to log messages for ``script_id``.

        Returns:
            ``True`` once subscribed (a confirmation is sent to the client);
            ``False`` if the connection is unknown or an error occurred.
        """
        try:
            connection_info = self.active_connections.get(connection_id)
            if connection_info is None:
                return False

            self.script_subscriptions.setdefault(script_id, set()).add(connection_id)
            self.connection_subscriptions.setdefault(connection_id, set()).add(script_id)
            connection_info["subscribed_scripts"].add(script_id)

            logger.info(f"订阅脚本: {connection_id} -> {script_id}")

            # Subscription confirmation
            await self.send_to_connection(connection_id, {
                "type": "subscription_success",
                "script_id": script_id,
                "message": f"成功订阅脚本 {script_id} 的日志"
            })
            return True
        except Exception as e:
            logger.error(f"订阅脚本失败: {e}", exc_info=True)
            return False

    async def unsubscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Unsubscribe ``connection_id`` from ``script_id`` and confirm to the client.

        Returns:
            ``True`` after removing the subscription (even if none existed);
            ``False`` if an error occurred.
        """
        try:
            self._discard_script_subscriber(script_id, connection_id)
            if connection_id in self.connection_subscriptions:
                self.connection_subscriptions[connection_id].discard(script_id)
            if connection_id in self.active_connections:
                self.active_connections[connection_id]["subscribed_scripts"].discard(script_id)

            logger.info(f"取消订阅脚本: {connection_id} -> {script_id}")

            # Unsubscription confirmation (a no-op if the connection is gone)
            await self.send_to_connection(connection_id, {
                "type": "unsubscription_success",
                "script_id": script_id,
                "message": f"已取消订阅脚本 {script_id} 的日志"
            })
            return True
        except Exception as e:
            logger.error(f"取消订阅脚本失败: {e}", exc_info=True)
            return False

    async def send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON text to one connection.

        Non-JSON values are stringified (``default=str``). On
        ``WebSocketDisconnect`` the connection is unregistered.

        Returns:
            ``True`` if the text was sent; ``False`` if the connection is
            unknown or sending failed.
        """
        try:
            connection_info = self.active_connections.get(connection_id)
            if connection_info is None:
                return False
            websocket = connection_info["websocket"]
            message_text = json.dumps(message, ensure_ascii=False, default=str)
            await websocket.send_text(message_text)
            return True
        except WebSocketDisconnect:
            logger.warning(f"连接已断开: {connection_id}")
            await self.disconnect(connection_id)
            return False
        except Exception as e:
            logger.error(f"发送消息失败: {connection_id} -> {e}", exc_info=True)
            return False

    async def broadcast_to_script_subscribers(self, script_id: str, message: Dict[str, Any]) -> None:
        """Queue ``message`` (tagged with ``script_id``) for all current subscribers.

        Does nothing if the script has no subscribers. The caller's
        ``message`` dict is not modified.
        """
        try:
            subscribers = self.script_subscriptions.get(script_id)
            if not subscribers:
                return
            await self.broadcast_queue.put({
                "type": "script_broadcast",
                "script_id": script_id,
                # Snapshot so later (un)subscribes don't affect this job.
                "subscribers": set(subscribers),
                # Copy rather than mutate the caller's dict.
                "message": {**message, "script_id": script_id}
            })
        except Exception as e:
            logger.error(f"广播消息失败: {script_id} -> {e}", exc_info=True)

    async def broadcast_script_log(self, script_id: str, log_level: str,
                                   log_message: str, **kwargs) -> None:
        """Broadcast a ``script_log`` message; extra ``kwargs`` become message fields."""
        message = {
            "type": "script_log",
            "timestamp": datetime.now().isoformat(),
            "level": log_level,
            "message": log_message,
            **kwargs
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def broadcast_script_status(self, script_id: str, status: str,
                                      message: str = "", **kwargs) -> None:
        """Broadcast a ``script_status`` message; extra ``kwargs`` become message fields."""
        status_message = {
            "type": "script_status",
            "timestamp": datetime.now().isoformat(),
            "status": status,
            "message": message,
            **kwargs
        }
        await self.broadcast_to_script_subscribers(script_id, status_message)

    async def broadcast_function_execution(self, script_id: str, function_name: str,
                                           execution_type: str, result: Dict[str, Any]) -> None:
        """Broadcast a ``function_execution`` message carrying ``result``."""
        message = {
            "type": "function_execution",
            "timestamp": datetime.now().isoformat(),
            "function_name": function_name,
            "execution_type": execution_type,
            "result": result
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def _broadcast_worker(self) -> None:
        """Background asyncio task: deliver queued broadcast jobs until idle.

        Waits on the queue with a 1-second timeout. On a timeout with no
        active connections, the task returns (``connect`` restarts it).
        Per-job errors are logged and do not stop the loop.
        """
        try:
            while True:
                try:
                    broadcast_task = await asyncio.wait_for(
                        self.broadcast_queue.get(), timeout=1.0
                    )
                    if broadcast_task["type"] == "script_broadcast":
                        await self._process_script_broadcast(broadcast_task)
                except asyncio.TimeoutError:
                    # Idle: stop once nobody is connected.
                    if not self.active_connections:
                        break
                except Exception as e:
                    logger.error(f"广播工作线程错误: {e}", exc_info=True)
        except Exception as e:
            logger.error(f"广播工作线程异常退出: {e}", exc_info=True)

    async def _process_script_broadcast(self, broadcast_task: Dict[str, Any]) -> None:
        """Send one queued broadcast job to each of its subscribers.

        Connections whose send fails are disconnected. This drops all of
        their subscriptions and registration, and sends nothing further to
        the broken socket.
        """
        message = broadcast_task["message"]
        failed_connections = [
            connection_id
            for connection_id in broadcast_task["subscribers"]
            if not await self.send_to_connection(connection_id, message)
        ]
        # Clean up failed connections (no-op for ones already removed).
        for connection_id in failed_connections:
            await self.disconnect(connection_id)

    async def get_connection_status(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of connections and subscription counts."""
        return {
            "total_connections": len(self.active_connections),
            "total_script_subscriptions": len(self.script_subscriptions),
            "connections": [
                {
                    "connection_id": conn_id,
                    "client_id": conn_info["client_id"],
                    "client_type": conn_info["client_type"],
                    "connected_at": conn_info["connected_at"].isoformat(),
                    "subscribed_scripts": list(conn_info["subscribed_scripts"])
                }
                for conn_id, conn_info in self.active_connections.items()
            ],
            "script_subscriptions": {
                script_id: len(subscribers)
                for script_id, subscribers in self.script_subscriptions.items()
            }
        }

    async def ping_connections(self) -> None:
        """Send a ``ping`` message to every connection; disconnect those that fail."""
        ping_message = {
            "type": "ping",
            "timestamp": datetime.now().isoformat()
        }
        failed_connections = []
        # Iterate over a snapshot: each send awaits, and a failed send (or
        # another coroutine) may add or remove connections meanwhile.
        for connection_id in list(self.active_connections):
            success = await self.send_to_connection(connection_id, ping_message)
            if not success:
                failed_connections.append(connection_id)
        # Clean up failed connections
        for connection_id in failed_connections:
            await self.disconnect(connection_id)
# Module-level singleton: every importer of this module shares this manager.
_websocket_manager = ScriptWebSocketManager()


def get_websocket_manager() -> ScriptWebSocketManager:
    """Return the module-level ScriptWebSocketManager instance.

    Every call returns the same object, so connection and subscription
    state is shared by all callers.
    """
    shared_manager = _websocket_manager
    return shared_manager