327 lines
12 KiB
Python
327 lines
12 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
脚本WebSocket服务
|
|
提供实时日志推送和脚本状态监控
|
|
"""
|
|
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from typing import Dict, Any, List, Optional, Set
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from utils.logger import get_logger
|
|
|
|
logger = get_logger("services.script_websocket")
|
|
|
|
|
|
class ScriptWebSocketManager:
    """Manage script-log WebSocket connections and per-script subscriptions.

    Bookkeeping lives in three structures that are kept in sync:

    * ``active_connections``       -- connection_id -> connection info dict
      (websocket, client_id, client_type, connected_at, last_ping,
      subscribed_scripts)
    * ``script_subscriptions``     -- script_id -> set of connection_ids
    * ``connection_subscriptions`` -- connection_id -> set of script_ids

    Broadcasts are queued on ``broadcast_queue`` and delivered by a background
    task (``_broadcast_worker``) that ``connect`` starts on demand.
    """

    def __init__(self):
        # Active WebSocket connections: {connection_id: connection_info}
        self.active_connections: Dict[str, Dict[str, Any]] = {}

        # Script subscriptions: {script_id: set(connection_ids)}
        self.script_subscriptions: Dict[str, Set[str]] = {}

        # Connection subscriptions: {connection_id: set(script_ids)}
        self.connection_subscriptions: Dict[str, Set[str]] = {}

        # Pending broadcast jobs, consumed by _broadcast_worker.
        self.broadcast_queue = asyncio.Queue()

        # Handle of the background broadcast task; created lazily in connect().
        self._broadcast_task: Optional[asyncio.Task] = None

    def _new_connection_id(self, client_id: str, client_type: str) -> str:
        """Build a connection id of the form ``{type}_{client}_{epoch_ms}`` that is not in use."""
        base_id = f"{client_type}_{client_id}_{int(datetime.now().timestamp() * 1000)}"
        connection_id = base_id
        suffix = 1
        # Two connects from the same client within one millisecond would
        # otherwise get the same id and silently overwrite each other.
        while connection_id in self.active_connections:
            connection_id = f"{base_id}_{suffix}"
            suffix += 1
        return connection_id

    def _drop_subscription(self, connection_id: str, script_id: str) -> None:
        """Remove one (connection, script) pair from all bookkeeping without notifying the client."""
        subscribers = self.script_subscriptions.get(script_id)
        if subscribers is not None:
            subscribers.discard(connection_id)
            if not subscribers:
                # Do not keep empty subscriber sets around.
                del self.script_subscriptions[script_id]

        scripts = self.connection_subscriptions.get(connection_id)
        if scripts is not None:
            scripts.discard(script_id)

        conn_info = self.active_connections.get(connection_id)
        if conn_info is not None:
            conn_info["subscribed_scripts"].discard(script_id)

    async def connect(self, websocket: WebSocket, client_id: str,
                      client_type: str = "web") -> str:
        """Accept *websocket* and register it as a new connection.

        Args:
            websocket: Incoming WebSocket; ``accept()`` is awaited here.
            client_id: Client identifier, embedded in the connection id.
            client_type: Client category, embedded in the connection id.

        Returns:
            The new connection id (``{client_type}_{client_id}_{epoch_ms}``,
            with an ``_N`` suffix only if that id is already taken).

        Raises:
            Any exception from accepting/registering is logged and re-raised.
        """
        try:
            await websocket.accept()

            connection_id = self._new_connection_id(client_id, client_type)
            now = datetime.now()

            self.active_connections[connection_id] = {
                "websocket": websocket,
                "client_id": client_id,
                "client_type": client_type,
                "connected_at": now,
                "last_ping": now,
                "subscribed_scripts": set(),
            }
            self.connection_subscriptions[connection_id] = set()

            # (Re)start the broadcast worker; it exits by itself once it has
            # been idle with no connections left.
            if self._broadcast_task is None or self._broadcast_task.done():
                self._broadcast_task = asyncio.create_task(self._broadcast_worker())

            logger.info(f"WebSocket连接建立: {connection_id} ({client_type})")

            # Greet the client with its assigned connection id.
            await self.send_to_connection(connection_id, {
                "type": "welcome",
                "connection_id": connection_id,
                "server_time": datetime.now().isoformat(),
                "message": "连接建立成功"
            })

            return connection_id

        except Exception as e:
            logger.error(f"WebSocket连接建立失败: {e}", exc_info=True)
            raise

    async def disconnect(self, connection_id: str):
        """Unregister *connection_id* and drop all of its script subscriptions.

        Idempotent and safe to call re-entrantly (e.g. from
        ``send_to_connection`` during a broadcast): only the first call has
        an effect. The socket is not closed here, and no per-script
        unsubscription messages are sent to it.
        """
        try:
            # Pop first so nested or repeated calls see the connection as gone.
            conn_info = self.active_connections.pop(connection_id, None)
            if conn_info is None:
                return

            subscribed_scripts = set(self.connection_subscriptions.pop(connection_id, ()))
            subscribed_scripts |= set(conn_info.get("subscribed_scripts", ()))
            for script_id in subscribed_scripts:
                self._drop_subscription(connection_id, script_id)

            logger.info(f"WebSocket连接断开: {connection_id}")

        except Exception as e:
            logger.error(f"WebSocket断开连接处理失败: {e}", exc_info=True)

    async def subscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Subscribe *connection_id* to log/status broadcasts for *script_id*.

        Returns:
            True on success (a confirmation message is sent), False if the
            connection is unknown or bookkeeping failed.
        """
        try:
            conn_info = self.active_connections.get(connection_id)
            if conn_info is None:
                return False

            self.script_subscriptions.setdefault(script_id, set()).add(connection_id)
            self.connection_subscriptions.setdefault(connection_id, set()).add(script_id)
            conn_info["subscribed_scripts"].add(script_id)

            logger.info(f"订阅脚本: {connection_id} -> {script_id}")

            await self.send_to_connection(connection_id, {
                "type": "subscription_success",
                "script_id": script_id,
                "message": f"成功订阅脚本 {script_id} 的日志"
            })

            return True

        except Exception as e:
            logger.error(f"订阅脚本失败: {e}", exc_info=True)
            return False

    async def unsubscribe_script(self, connection_id: str, script_id: str) -> bool:
        """Unsubscribe *connection_id* from *script_id* and send a confirmation.

        Returns:
            True unless an unexpected error occurs (unknown ids are not an error).
        """
        try:
            self._drop_subscription(connection_id, script_id)

            logger.info(f"取消订阅脚本: {connection_id} -> {script_id}")

            await self.send_to_connection(connection_id, {
                "type": "unsubscription_success",
                "script_id": script_id,
                "message": f"已取消订阅脚本 {script_id} 的日志"
            })

            return True

        except Exception as e:
            logger.error(f"取消订阅脚本失败: {e}", exc_info=True)
            return False

    async def send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
        """Send *message* as JSON text to one connection.

        Non-JSON-native values are stringified (``default=str``).

        Returns:
            True if sent; False if the connection is unknown or sending failed.
            On ``WebSocketDisconnect`` the connection is also unregistered.
        """
        try:
            connection_info = self.active_connections.get(connection_id)
            if connection_info is None:
                return False

            websocket = connection_info["websocket"]
            message_text = json.dumps(message, ensure_ascii=False, default=str)
            await websocket.send_text(message_text)
            return True

        except WebSocketDisconnect:
            logger.warning(f"连接已断开: {connection_id}")
            await self.disconnect(connection_id)
            return False
        except Exception as e:
            logger.error(f"发送消息失败: {connection_id} -> {e}", exc_info=True)
            return False

    async def broadcast_to_script_subscribers(self, script_id: str, message: Dict[str, Any]):
        """Queue *message* for delivery to every current subscriber of *script_id*.

        A copy of *message* with ``script_id`` added is queued; the caller's
        dict is not modified. Does nothing if the script has no subscribers.
        """
        try:
            subscribers = self.script_subscriptions.get(script_id)
            if not subscribers:
                return

            await self.broadcast_queue.put({
                "type": "script_broadcast",
                "script_id": script_id,
                # Snapshot: later (un)subscribes must not affect this delivery.
                "subscribers": set(subscribers),
                "message": {**message, "script_id": script_id},
            })

        except Exception as e:
            logger.error(f"广播消息失败: {script_id} -> {e}", exc_info=True)

    async def broadcast_script_log(self, script_id: str, log_level: str,
                                   log_message: str, **kwargs):
        """Broadcast a ``script_log`` message; extra kwargs are merged into it."""
        message = {
            "type": "script_log",
            "timestamp": datetime.now().isoformat(),
            "level": log_level,
            "message": log_message,
            **kwargs
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def broadcast_script_status(self, script_id: str, status: str,
                                      message: str = "", **kwargs):
        """Broadcast a ``script_status`` change; extra kwargs are merged into it."""
        status_message = {
            "type": "script_status",
            "timestamp": datetime.now().isoformat(),
            "status": status,
            "message": message,
            **kwargs
        }
        await self.broadcast_to_script_subscribers(script_id, status_message)

    async def broadcast_function_execution(self, script_id: str, function_name: str,
                                           execution_type: str, result: Dict[str, Any]):
        """Broadcast a ``function_execution`` result for *function_name*."""
        message = {
            "type": "function_execution",
            "timestamp": datetime.now().isoformat(),
            "function_name": function_name,
            "execution_type": execution_type,
            "result": result
        }
        await self.broadcast_to_script_subscribers(script_id, message)

    async def _broadcast_worker(self):
        """Background task: deliver queued broadcasts.

        Polls the queue with a 1 s timeout and exits once a timeout occurs
        while there are no active connections (``connect`` restarts it).
        """
        try:
            while True:
                try:
                    broadcast_task = await asyncio.wait_for(
                        self.broadcast_queue.get(), timeout=1.0
                    )

                    if broadcast_task["type"] == "script_broadcast":
                        await self._process_script_broadcast(broadcast_task)

                except asyncio.TimeoutError:
                    # Idle: stop if nobody is connected anymore.
                    if not self.active_connections:
                        break
                except Exception as e:
                    logger.error(f"广播工作线程错误: {e}", exc_info=True)

        except Exception as e:
            logger.error(f"广播工作线程异常退出: {e}", exc_info=True)

    async def _process_script_broadcast(self, broadcast_task: Dict[str, Any]):
        """Send one queued broadcast to each subscriber, sequentially.

        Subscribers whose send fails are dropped from the script's
        subscription without being sent another message.
        """
        script_id = broadcast_task["script_id"]
        subscribers = broadcast_task["subscribers"]
        message = broadcast_task["message"]

        failed_connections = []
        for connection_id in subscribers:
            if not await self.send_to_connection(connection_id, message):
                failed_connections.append(connection_id)

        # The send just failed, so notifying these connections is pointless.
        for connection_id in failed_connections:
            self._drop_subscription(connection_id, script_id)

    async def get_connection_status(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of connections and subscription counts."""
        return {
            "total_connections": len(self.active_connections),
            "total_script_subscriptions": len(self.script_subscriptions),
            "connections": [
                {
                    "connection_id": conn_id,
                    "client_id": conn_info["client_id"],
                    "client_type": conn_info["client_type"],
                    "connected_at": conn_info["connected_at"].isoformat(),
                    "subscribed_scripts": list(conn_info["subscribed_scripts"])
                }
                for conn_id, conn_info in self.active_connections.items()
            ],
            "script_subscriptions": {
                script_id: len(subscribers)
                for script_id, subscribers in self.script_subscriptions.items()
            }
        }

    async def ping_connections(self):
        """Send a ``ping`` message to every connection and disconnect those that fail."""
        ping_message = {
            "type": "ping",
            "timestamp": datetime.now().isoformat()
        }

        failed_connections = []
        # Iterate over a snapshot: sending awaits, and disconnects or new
        # connects during that await would otherwise mutate the dict mid-loop.
        for connection_id in list(self.active_connections):
            if not await self.send_to_connection(connection_id, ping_message):
                failed_connections.append(connection_id)

        for connection_id in failed_connections:
            await self.disconnect(connection_id)
|
|
|
|
|
|
# Single module-level WebSocket manager, handed out by get_websocket_manager().
# NOTE(review): this is constructed at import time, so its asyncio.Queue is
# created before any event loop runs; on Python < 3.10 that can bind the queue
# to a different loop than the server's -- confirm the target Python version.
_websocket_manager: ScriptWebSocketManager = ScriptWebSocketManager()
|
|
|
|
|
|
def get_websocket_manager() -> ScriptWebSocketManager:
    """Return the shared module-level ScriptWebSocketManager.

    Every call yields the same object, so all connections and subscriptions
    are tracked by one manager.
    """
    manager = _websocket_manager
    return manager