238 lines
9.1 KiB
Python
238 lines
9.1 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
WebSocket模块
|
|||
|
提供WebSocket客户端通信相关功能
|
|||
|
"""
|
|||
|
|
|||
|
from typing import List
|
|||
|
from utils.logger import get_logger
|
|||
|
|
|||
|
logger = get_logger("services.online_script.websocket_module")
|
|||
|
|
|||
|
|
|||
|
class VWEDWebSocketModule:
|
|||
|
"""WebSocket模块"""
|
|||
|
|
|||
|
def __init__(self, script_id: str):
|
|||
|
self.script_id = script_id
|
|||
|
self._manager = None
|
|||
|
|
|||
|
def _get_manager(self):
|
|||
|
"""获取WebSocket连接管理器(延迟导入)"""
|
|||
|
if self._manager is None:
|
|||
|
try:
|
|||
|
from routes.websocket_api import manager
|
|||
|
self._manager = manager
|
|||
|
except ImportError as e:
|
|||
|
logger.error(f"导入WebSocket管理器失败: {e}")
|
|||
|
raise Exception(f"WebSocket管理器不可用: {e}")
|
|||
|
return self._manager
|
|||
|
|
|||
|
def send_msg_to_wsc_by_client_ip(self, msg: str, ip: str) -> None:
|
|||
|
"""
|
|||
|
根据IP向客户端发送消息
|
|||
|
|
|||
|
Args:
|
|||
|
msg: 要发送的消息
|
|||
|
ip: 接收消息的客户端IP
|
|||
|
|
|||
|
Raises:
|
|||
|
Exception: 发送失败时抛出异常
|
|||
|
"""
|
|||
|
try:
|
|||
|
manager = self._get_manager()
|
|||
|
|
|||
|
# 从连接管理器中查找对应IP的WebSocket连接
|
|||
|
found_connections = []
|
|||
|
|
|||
|
# 遍历所有活跃连接,查找匹配IP的连接
|
|||
|
for _, connections in manager.active_connections.items():
|
|||
|
for websocket in connections:
|
|||
|
try:
|
|||
|
# 获取WebSocket客户端IP
|
|||
|
client_host = websocket.client.host if hasattr(websocket, 'client') else None
|
|||
|
if client_host == ip:
|
|||
|
found_connections.append(websocket)
|
|||
|
except Exception:
|
|||
|
continue
|
|||
|
|
|||
|
# 同样检查库位状态连接
|
|||
|
for _, connections in manager.storage_location_connections.items():
|
|||
|
for websocket in connections:
|
|||
|
try:
|
|||
|
# 获取WebSocket客户端IP
|
|||
|
client_host = websocket.client.host if hasattr(websocket, 'client') else None
|
|||
|
if client_host == ip:
|
|||
|
found_connections.append(websocket)
|
|||
|
except Exception:
|
|||
|
continue
|
|||
|
|
|||
|
if not found_connections:
|
|||
|
raise Exception(f"未找到IP为 {ip} 的WebSocket客户端连接")
|
|||
|
|
|||
|
# 向找到的所有连接发送消息
|
|||
|
import asyncio
|
|||
|
|
|||
|
async def send_async():
|
|||
|
failed_count = 0
|
|||
|
for websocket in found_connections:
|
|||
|
try:
|
|||
|
await manager.send_personal_message(msg, websocket)
|
|||
|
except Exception as e:
|
|||
|
failed_count += 1
|
|||
|
logger.error(f"向IP {ip} 发送消息失败: {e}")
|
|||
|
|
|||
|
if failed_count == len(found_connections):
|
|||
|
raise Exception(f"向IP {ip} 的所有连接发送消息都失败")
|
|||
|
|
|||
|
# 检查是否在事件循环中
|
|||
|
try:
|
|||
|
asyncio.get_running_loop()
|
|||
|
# 如果在事件循环中,创建任务
|
|||
|
asyncio.create_task(send_async())
|
|||
|
except RuntimeError:
|
|||
|
# 不在事件循环中,创建新的事件循环
|
|||
|
asyncio.run(send_async())
|
|||
|
|
|||
|
logger.info(f"成功向IP {ip} 发送消息: {msg}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"根据IP发送WebSocket消息失败: {e}")
|
|||
|
raise e
|
|||
|
|
|||
|
def send_msg_to_wsc_by_client_name(self, msg: str, client_name: str) -> None:
|
|||
|
"""
|
|||
|
根据客户端名称向客户端发送消息
|
|||
|
|
|||
|
Args:
|
|||
|
msg: 要发送的消息
|
|||
|
client_name: 接收消息的客户端名称
|
|||
|
|
|||
|
Raises:
|
|||
|
Exception: 发送失败时抛出异常
|
|||
|
"""
|
|||
|
try:
|
|||
|
manager = self._get_manager()
|
|||
|
|
|||
|
# 从连接管理器中查找对应名称的WebSocket连接
|
|||
|
found_connections = []
|
|||
|
|
|||
|
# 遍历所有活跃连接,查找匹配名称的连接
|
|||
|
# 注意:当前的ConnectionManager没有存储客户端名称信息
|
|||
|
# 这里使用task_record_id作为客户端标识符
|
|||
|
for task_record_id, connections in manager.active_connections.items():
|
|||
|
if task_record_id == client_name:
|
|||
|
found_connections.extend(connections)
|
|||
|
|
|||
|
# 同样检查库位状态连接,使用scene_id作为客户端标识符
|
|||
|
for scene_id, connections in manager.storage_location_connections.items():
|
|||
|
if scene_id == client_name:
|
|||
|
found_connections.extend(connections)
|
|||
|
|
|||
|
if not found_connections:
|
|||
|
raise Exception(f"未找到名称为 {client_name} 的WebSocket客户端连接")
|
|||
|
|
|||
|
# 向找到的所有连接发送消息
|
|||
|
import asyncio
|
|||
|
|
|||
|
async def send_async():
|
|||
|
failed_count = 0
|
|||
|
for websocket in found_connections:
|
|||
|
try:
|
|||
|
await manager.send_personal_message(msg, websocket)
|
|||
|
except Exception as e:
|
|||
|
failed_count += 1
|
|||
|
logger.error(f"向客户端 {client_name} 发送消息失败: {e}")
|
|||
|
|
|||
|
if failed_count == len(found_connections):
|
|||
|
raise Exception(f"向客户端 {client_name} 的所有连接发送消息都失败")
|
|||
|
|
|||
|
# 检查是否在事件循环中
|
|||
|
try:
|
|||
|
asyncio.get_running_loop()
|
|||
|
# 如果在事件循环中,创建任务
|
|||
|
asyncio.create_task(send_async())
|
|||
|
except RuntimeError:
|
|||
|
# 不在事件循环中,创建新的事件循环
|
|||
|
asyncio.run(send_async())
|
|||
|
|
|||
|
logger.info(f"成功向客户端 {client_name} 发送消息: {msg}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"根据客户端名称发送WebSocket消息失败: {e}")
|
|||
|
raise e
|
|||
|
|
|||
|
def get_websocket_client_ip(self) -> List[str]:
|
|||
|
"""
|
|||
|
获取所有WebSocket客户端的IP
|
|||
|
|
|||
|
Returns:
|
|||
|
包含所有客户端IP的字符串列表
|
|||
|
|
|||
|
Raises:
|
|||
|
Exception: 获取失败时抛出异常
|
|||
|
"""
|
|||
|
try:
|
|||
|
manager = self._get_manager()
|
|||
|
client_ips = []
|
|||
|
|
|||
|
# 遍历所有活跃连接,收集客户端IP
|
|||
|
for _, connections in manager.active_connections.items():
|
|||
|
for websocket in connections:
|
|||
|
try:
|
|||
|
# 获取WebSocket客户端IP
|
|||
|
client_host = websocket.client.host if hasattr(websocket, 'client') else None
|
|||
|
if client_host and client_host not in client_ips:
|
|||
|
client_ips.append(client_host)
|
|||
|
except Exception:
|
|||
|
continue
|
|||
|
|
|||
|
# 同样检查库位状态连接
|
|||
|
for _, connections in manager.storage_location_connections.items():
|
|||
|
for websocket in connections:
|
|||
|
try:
|
|||
|
# 获取WebSocket客户端IP
|
|||
|
client_host = websocket.client.host if hasattr(websocket, 'client') else None
|
|||
|
if client_host and client_host not in client_ips:
|
|||
|
client_ips.append(client_host)
|
|||
|
except Exception:
|
|||
|
continue
|
|||
|
|
|||
|
logger.info(f"获取到 {len(client_ips)} 个WebSocket客户端IP")
|
|||
|
return client_ips
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"获取WebSocket客户端IP失败: {e}")
|
|||
|
raise e
|
|||
|
|
|||
|
def get_websocket_client_name(self) -> List[str]:
|
|||
|
"""
|
|||
|
获取所有WebSocket客户端名称
|
|||
|
|
|||
|
Returns:
|
|||
|
包含所有客户端名称的字符串列表
|
|||
|
|
|||
|
Raises:
|
|||
|
Exception: 获取失败时抛出异常
|
|||
|
"""
|
|||
|
try:
|
|||
|
manager = self._get_manager()
|
|||
|
client_names = []
|
|||
|
|
|||
|
# 收集任务记录连接的客户端名称(使用task_record_id作为名称)
|
|||
|
client_names.extend(list(manager.active_connections.keys()))
|
|||
|
|
|||
|
# 收集库位状态连接的客户端名称(使用scene_id作为名称)
|
|||
|
client_names.extend(list(manager.storage_location_connections.keys()))
|
|||
|
|
|||
|
# 去重
|
|||
|
client_names = list(set(client_names))
|
|||
|
|
|||
|
logger.info(f"获取到 {len(client_names)} 个WebSocket客户端名称")
|
|||
|
return client_names
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"获取WebSocket客户端名称失败: {e}")
|
|||
|
raise e
|