394 lines
16 KiB
Python
394 lines
16 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
脚本WebSocket服务
|
||
提供实时日志推送和脚本状态监控
|
||
"""
|
||
|
||
import json
|
||
import asyncio
|
||
from datetime import datetime
|
||
from typing import Dict, Any, List, Optional, Set
|
||
from fastapi import WebSocket, WebSocketDisconnect
|
||
from websockets.exceptions import ConnectionClosedError
|
||
from utils.logger import get_logger
|
||
|
||
# Logger channel name used by this service module.
_LOGGER_NAME = "services.script_websocket"
logger = get_logger(_LOGGER_NAME)
|
||
|
||
|
||
class ScriptWebSocketManager:
    """Manage script WebSocket connections and per-script log subscriptions.

    Three bookkeeping structures are kept in sync:

    * ``active_connections``: connection_id -> info dict (websocket, client
      metadata, timestamps and the set of subscribed script ids).
    * ``script_subscriptions``: script_id -> set of subscribed connection ids.
    * ``connection_subscriptions``: connection_id -> set of subscribed script ids.

    Broadcasts are not sent inline. They are queued on ``broadcast_queue`` and
    delivered by a background worker task. ``connect`` starts the worker (or
    restarts it if it has finished). The worker exits on its own after the
    queue has been idle for one second while no connections remain.
    """

    def __init__(self):
        # Active WebSocket connections: {connection_id: connection_info}
        self.active_connections: Dict[str, Dict[str, Any]] = {}

        # Script subscriptions: {script_id: set(connection_ids)}
        self.script_subscriptions: Dict[str, Set[str]] = {}

        # Per-connection subscriptions: {connection_id: set(script_ids)}
        self.connection_subscriptions: Dict[str, Set[str]] = {}

        # Pending broadcast jobs, consumed by _broadcast_worker.
        self.broadcast_queue = asyncio.Queue()

        # Background broadcast worker; created lazily by connect().
        self._broadcast_task: Optional[asyncio.Task] = None

    async def connect(self, websocket: WebSocket, client_id: str,
                      client_type: str = "web") -> str:
        """Accept a WebSocket and register it as a new connection.

        Args:
            websocket: The incoming, not yet accepted, WebSocket.
            client_id: Caller-supplied client identifier (embedded in the id).
            client_type: Client category (embedded in the id).

        Returns:
            The generated connection id. The welcome message is best-effort.
            If sending it fails, ``send_to_connection`` has already cleaned
            up the connection, but the id is still returned.

        Raises:
            Exception: Any error raised while accepting or registering the
                connection is logged and re-raised.
        """
        try:
            await websocket.accept()

            now = datetime.now()
            connection_id = self._new_connection_id(client_id, client_type, now)

            self.active_connections[connection_id] = {
                "websocket": websocket,
                "client_id": client_id,
                "client_type": client_type,
                "connected_at": now,
                "last_ping": now,
                "subscribed_scripts": set(),
            }
            self.connection_subscriptions[connection_id] = set()

            # (Re)start the broadcast worker. It stops by itself when idle with
            # no connections, so it may have finished since the last connect.
            if self._broadcast_task is None or self._broadcast_task.done():
                self._broadcast_task = asyncio.create_task(self._broadcast_worker())

            logger.info(f"WebSocket连接建立: {connection_id} ({client_type})")

            # Greet the client.
            await self.send_to_connection(connection_id, {
                "type": "welcome",
                "connection_id": connection_id,
                "server_time": datetime.now().isoformat(),
                "message": "连接建立成功",
            })

            return connection_id

        except Exception as e:
            logger.error(f"WebSocket连接建立失败: {e}", exc_info=True)
            raise

    def _new_connection_id(self, client_id: str, client_type: str,
                           now: datetime) -> str:
        """Build a connection id that is unique among the active connections."""
        base_id = f"{client_type}_{client_id}_{int(now.timestamp() * 1000)}"
        connection_id = base_id
        suffix = 1
        # Two connects from the same client within one millisecond would
        # otherwise get the same id and silently overwrite the earlier entry.
        while connection_id in self.active_connections:
            connection_id = f"{base_id}_{suffix}"
            suffix += 1
        return connection_id

    async def disconnect(self, connection_id: str):
        """Unregister a connection and close its WebSocket if still open.

        No confirmation messages are sent. The entry is removed from
        ``active_connections`` before any await. Repeated or overlapping calls
        for the same id therefore return immediately instead of failing. Any
        error is logged as a warning and not propagated.
        """
        try:
            connection_info = self.active_connections.pop(connection_id, None)
            if connection_info is None:
                return

            # Drop all subscriptions silently; the socket may already be gone.
            for script_id in list(self.connection_subscriptions.get(connection_id, ())):
                await self._silent_unsubscribe_script(connection_id, script_id)
            self.connection_subscriptions.pop(connection_id, None)

            # Best-effort close if neither side has disconnected yet.
            websocket = connection_info["websocket"]
            try:
                if hasattr(websocket, 'client_state') and hasattr(websocket, 'application_state'):
                    from starlette.websockets import WebSocketState
                    if (websocket.client_state != WebSocketState.DISCONNECTED and
                            websocket.application_state != WebSocketState.DISCONNECTED):
                        await websocket.close()
            except Exception as close_error:
                # The peer may already have gone away; closing is best-effort.
                logger.debug(f"关闭WebSocket时出错(已忽略): {connection_id} -> {close_error}")

            logger.debug(f"WebSocket连接已清理: {connection_id}")

        except Exception as e:
            logger.warning(f"WebSocket断开连接处理失败: {connection_id} -> {e}")

    def _remove_subscription(self, connection_id: str, script_id: str) -> None:
        """Remove the (connection, script) pair from all bookkeeping structures."""
        subscribers = self.script_subscriptions.get(script_id)
        if subscribers is not None:
            subscribers.discard(connection_id)
            if not subscribers:
                # Don't keep empty sets for scripts nobody is watching.
                del self.script_subscriptions[script_id]

        subscribed = self.connection_subscriptions.get(connection_id)
        if subscribed is not None:
            subscribed.discard(script_id)

        connection_info = self.active_connections.get(connection_id)
        if connection_info is not None:
            connection_info["subscribed_scripts"].discard(script_id)

    async def _silent_unsubscribe_script(self, connection_id: str, script_id: str):
        """Unsubscribe a connection from a script without notifying the client."""
        try:
            self._remove_subscription(connection_id, script_id)
            logger.debug(f"静默取消订阅: {connection_id} -> {script_id}")
        except Exception as e:
            logger.warning(f"静默取消订阅失败: {connection_id} -> {script_id}: {e}")

    async def subscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Subscribe an active connection to a script's log stream.

        Returns:
            False if the connection is unknown or an error occurs, else True.
            A confirmation message is sent to the client on success.
        """
        try:
            connection_info = self.active_connections.get(connection_id)
            if connection_info is None:
                return False

            self.script_subscriptions.setdefault(script_id, set()).add(connection_id)
            self.connection_subscriptions.setdefault(connection_id, set()).add(script_id)
            connection_info["subscribed_scripts"].add(script_id)

            logger.info(f"订阅脚本: {connection_id} -> {script_id}")

            await self.send_to_connection(connection_id, {
                "type": "subscription_success",
                "script_id": script_id,
                "message": f"成功订阅脚本 {script_id} 的日志",
            })

            return True

        except Exception as e:
            logger.error(f"订阅脚本失败: {e}", exc_info=True)
            return False

    async def unsubscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Unsubscribe a connection from a script and confirm to the client.

        Returns:
            True unless an error occurs. Unknown ids are not an error; the
            confirmation send simply does nothing for them.
        """
        try:
            self._remove_subscription(connection_id, script_id)

            logger.info(f"取消订阅脚本: {connection_id} -> {script_id}")

            await self.send_to_connection(connection_id, {
                "type": "unsubscription_success",
                "script_id": script_id,
                "message": f"已取消订阅脚本 {script_id} 的日志",
            })

            return True

        except Exception as e:
            logger.error(f"取消订阅脚本失败: {e}", exc_info=True)
            return False

    async def send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
        """Serialize ``message`` as JSON and send it to a single connection.

        Values json cannot encode natively are converted with ``str``.

        Returns:
            True if the text frame was sent. False if the id is unknown or
            sending failed. On failure the connection is cleaned up via
            ``disconnect()``; exceptions are never propagated.
        """
        connection_info = self.active_connections.get(connection_id)
        if connection_info is None:
            return False
        websocket = connection_info["websocket"]

        try:
            # Skip sockets that already report a disconnected side.
            if hasattr(websocket, 'client_state') and hasattr(websocket, 'application_state'):
                from starlette.websockets import WebSocketState
                if (websocket.client_state == WebSocketState.DISCONNECTED or
                        websocket.application_state == WebSocketState.DISCONNECTED):
                    logger.debug(f"WebSocket连接已断开,跳过发送: {connection_id}")
                    await self.disconnect(connection_id)
                    return False

            message_text = json.dumps(message, ensure_ascii=False, default=str)
            await websocket.send_text(message_text)
            return True

        except WebSocketDisconnect:
            logger.debug(f"WebSocket连接已断开: {connection_id}")
        except ConnectionClosedError:
            logger.debug(f"WebSocket连接已关闭: {connection_id}")
        except RuntimeError as e:
            # ASGI servers raise RuntimeError (e.g. "Unexpected ASGI message")
            # when sending after the socket was closed; that case is expected.
            error_text = str(e)
            if "websocket.send" in error_text and (
                    "websocket.close" in error_text or "response already completed" in error_text):
                logger.debug(f"WebSocket连接已关闭,无法发送消息: {connection_id}")
            else:
                logger.warning(f"WebSocket发送消息时出现运行时错误: {connection_id} -> {e}")
        except Exception as e:
            logger.warning(f"发送消息失败: {connection_id} -> {e}")

        # Every failure path ends here. Drop the connection so later sends do
        # not keep hitting a dead socket.
        await self.disconnect(connection_id)
        return False

    async def broadcast_to_script_subscribers(self, script_id: str, message: Dict[str, Any]):
        """Queue ``message`` for delivery to every subscriber of ``script_id``.

        A shallow copy of ``message`` with ``script_id`` added is queued, so
        the caller's dict is neither mutated nor affected by later changes to
        it. The subscriber set is snapshotted at enqueue time. Does nothing
        when the script has no subscribers. Errors are logged, not raised.
        """
        try:
            subscribers = self.script_subscriptions.get(script_id)
            if not subscribers:
                return

            payload = {**message, "script_id": script_id}

            await self.broadcast_queue.put({
                "type": "script_broadcast",
                "script_id": script_id,
                "subscribers": set(subscribers),
                "message": payload,
            })

        except Exception as e:
            logger.error(f"广播消息失败: {script_id} -> {e}", exc_info=True)

    async def broadcast_script_log(self, script_id: str, log_level: str,
                                   log_message: str, **kwargs):
        """Broadcast a ``script_log`` message; extra kwargs become message fields."""
        message = {
            "type": "script_log",
            "timestamp": datetime.now().isoformat(),
            "level": log_level,
            "message": log_message,
            **kwargs,
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def broadcast_script_status(self, script_id: str, status: str,
                                      message: str = "", **kwargs):
        """Broadcast a ``script_status`` message; extra kwargs become message fields."""
        status_message = {
            "type": "script_status",
            "timestamp": datetime.now().isoformat(),
            "status": status,
            "message": message,
            **kwargs,
        }
        await self.broadcast_to_script_subscribers(script_id, status_message)

    async def broadcast_function_execution(self, script_id: str, function_name: str,
                                           execution_type: str, result: Dict[str, Any]):
        """Broadcast a ``function_execution`` message carrying ``result``."""
        message = {
            "type": "function_execution",
            "timestamp": datetime.now().isoformat(),
            "function_name": function_name,
            "execution_type": execution_type,
            "result": result,
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def _broadcast_worker(self):
        """Consume queued broadcast jobs until idle with no connections.

        Waits up to one second per job. On a timeout with no active
        connections the worker returns; ``connect`` restarts it later. Errors
        while processing a single job are logged and the loop continues.
        """
        try:
            while True:
                try:
                    broadcast_task = await asyncio.wait_for(
                        self.broadcast_queue.get(), timeout=1.0
                    )
                    if broadcast_task["type"] == "script_broadcast":
                        await self._process_script_broadcast(broadcast_task)

                except asyncio.TimeoutError:
                    # Idle: stop if nobody is connected any more.
                    if not self.active_connections:
                        break
                except Exception as e:
                    logger.error(f"广播工作线程错误: {e}", exc_info=True)

        except Exception as e:
            logger.error(f"广播工作线程异常退出: {e}", exc_info=True)

    async def _process_script_broadcast(self, broadcast_task: Dict[str, Any]):
        """Send one queued broadcast to its snapshotted subscribers, in order."""
        script_id = broadcast_task["script_id"]
        subscribers = broadcast_task["subscribers"]
        message = broadcast_task["message"]

        failed_connections = []
        for connection_id in subscribers:
            if not await self.send_to_connection(connection_id, message):
                failed_connections.append(connection_id)

        # Drop any remaining subscription for failed ids without messaging
        # them (they may have disconnected after the snapshot was taken).
        for connection_id in failed_connections:
            await self._silent_unsubscribe_script(connection_id, script_id)

    async def get_connection_status(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary of connections and subscriptions."""
        return {
            "total_connections": len(self.active_connections),
            "total_script_subscriptions": len(self.script_subscriptions),
            "connections": [
                {
                    "connection_id": conn_id,
                    "client_id": conn_info["client_id"],
                    "client_type": conn_info["client_type"],
                    "connected_at": conn_info["connected_at"].isoformat(),
                    "subscribed_scripts": list(conn_info["subscribed_scripts"]),
                }
                for conn_id, conn_info in self.active_connections.items()
            ],
            "script_subscriptions": {
                script_id: len(subscribers)
                for script_id, subscribers in self.script_subscriptions.items()
            },
        }

    async def ping_connections(self):
        """Send a ``ping`` message to every active connection.

        Connections whose send fails are disconnected by
        ``send_to_connection`` itself.
        """
        ping_message = {
            "type": "ping",
            "timestamp": datetime.now().isoformat(),
        }

        # Iterate over a snapshot. A failed send disconnects the connection,
        # which removes it from active_connections while we are looping, and
        # each await can also let other coroutines (dis)connect.
        for connection_id in list(self.active_connections):
            await self.send_to_connection(connection_id, ping_message)
|
||
|
||
|
||
# Process-wide manager instance, created once when this module is imported.
# NOTE(review): ScriptWebSocketManager.__init__ creates an asyncio.Queue; on
# Python < 3.10 that binds to the event loop current at import time, which may
# not be the server's loop. Confirm the target Python version.
_websocket_manager = ScriptWebSocketManager()


def get_websocket_manager() -> ScriptWebSocketManager:
    """Return the shared ScriptWebSocketManager.

    Every call returns the same module-level instance, so all callers observe
    the same connections and subscriptions.
    """
    manager = _websocket_manager
    return manager