313 lines
10 KiB
Python
313 lines
10 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
WebSocket服务演示和测试
|
|||
|
启动一个完整的WebSocket服务器,模拟客户端连接,测试内置函数功能
|
|||
|
"""
|
|||
|
|
|||
|
import asyncio
|
|||
|
import json
|
|||
|
import threading
|
|||
|
import time
|
|||
|
import sys
|
|||
|
import os
|
|||
|
from datetime import datetime
|
|||
|
from typing import Dict, Set
|
|||
|
|
|||
|
# 添加项目根目录到Python路径
|
|||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|||
|
parent_dir = os.path.dirname(current_dir)
|
|||
|
if parent_dir not in sys.path:
|
|||
|
sys.path.insert(0, parent_dir)
|
|||
|
|
|||
|
# 导入FastAPI和相关模块
|
|||
|
try:
|
|||
|
import uvicorn
|
|||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|||
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
import websockets
|
|||
|
except ImportError as e:
|
|||
|
print(f"需要安装依赖包: {e}")
|
|||
|
print("请运行: pip install fastapi uvicorn websockets")
|
|||
|
sys.exit(1)
|
|||
|
|
|||
|
# 导入WebSocket模块
|
|||
|
try:
|
|||
|
from services.online_script.script_vwed_objects import create_vwed_object
|
|||
|
from utils.logger import get_logger
|
|||
|
except ImportError as e:
|
|||
|
print(f"导入项目模块失败: {e}")
|
|||
|
print("请确保从项目根目录运行此脚本")
|
|||
|
sys.exit(1)
|
|||
|
|
|||
|
logger = get_logger("tests.websocket_service_demo")
|
|||
|
|
|||
|
|
|||
|
class WebSocketTestServer:
|
|||
|
"""WebSocket测试服务器"""
|
|||
|
|
|||
|
def __init__(self, port: int = 8899):
|
|||
|
self.port = port
|
|||
|
self.app = FastAPI(title="WebSocket测试服务器")
|
|||
|
self.connections: Dict[str, WebSocket] = {}
|
|||
|
self.client_ips: Set[str] = set()
|
|||
|
|
|||
|
# 添加CORS支持
|
|||
|
self.app.add_middleware(
|
|||
|
CORSMiddleware,
|
|||
|
allow_origins=["*"],
|
|||
|
allow_credentials=True,
|
|||
|
allow_methods=["*"],
|
|||
|
allow_headers=["*"],
|
|||
|
)
|
|||
|
|
|||
|
self.setup_routes()
|
|||
|
|
|||
|
def setup_routes(self):
|
|||
|
"""设置路由"""
|
|||
|
|
|||
|
@self.app.websocket("/ws/{client_id}")
|
|||
|
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
|||
|
await websocket.accept()
|
|||
|
|
|||
|
# 获取客户端IP
|
|||
|
client_ip = websocket.client.host if hasattr(websocket, 'client') else "unknown"
|
|||
|
self.connections[client_id] = websocket
|
|||
|
self.client_ips.add(client_ip)
|
|||
|
|
|||
|
logger.info(f"客户端连接: {client_id} (IP: {client_ip})")
|
|||
|
|
|||
|
try:
|
|||
|
# 发送欢迎消息
|
|||
|
await websocket.send_text(json.dumps({
|
|||
|
"type": "welcome",
|
|||
|
"message": f"欢迎 {client_id}",
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}))
|
|||
|
|
|||
|
# 保持连接并监听消息
|
|||
|
while True:
|
|||
|
try:
|
|||
|
data = await websocket.receive_text()
|
|||
|
logger.info(f"收到来自 {client_id} 的消息: {data}")
|
|||
|
|
|||
|
# 回显消息
|
|||
|
await websocket.send_text(json.dumps({
|
|||
|
"type": "echo",
|
|||
|
"original_message": data,
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}))
|
|||
|
|
|||
|
except WebSocketDisconnect:
|
|||
|
break
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"WebSocket处理错误: {e}")
|
|||
|
finally:
|
|||
|
# 清理连接
|
|||
|
if client_id in self.connections:
|
|||
|
del self.connections[client_id]
|
|||
|
logger.info(f"客户端断开连接: {client_id}")
|
|||
|
|
|||
|
@self.app.get("/status")
|
|||
|
async def get_status():
|
|||
|
"""获取服务器状态"""
|
|||
|
return {
|
|||
|
"active_connections": len(self.connections),
|
|||
|
"client_ids": list(self.connections.keys()),
|
|||
|
"client_ips": list(self.client_ips),
|
|||
|
"timestamp": datetime.now().isoformat()
|
|||
|
}
|
|||
|
|
|||
|
def start_server(self):
|
|||
|
"""启动服务器"""
|
|||
|
logger.info(f"启动WebSocket测试服务器,端口: {self.port}")
|
|||
|
|
|||
|
# 在单独线程中运行服务器
|
|||
|
def run_server():
|
|||
|
uvicorn.run(self.app, host="0.0.0.0", port=self.port, log_level="warning")
|
|||
|
|
|||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
|||
|
server_thread.start()
|
|||
|
|
|||
|
# 等待服务器启动
|
|||
|
time.sleep(2)
|
|||
|
logger.info("WebSocket测试服务器已启动")
|
|||
|
|
|||
|
return server_thread
|
|||
|
|
|||
|
|
|||
|
class WebSocketTestClient:
|
|||
|
"""WebSocket测试客户端"""
|
|||
|
|
|||
|
def __init__(self, client_id: str, server_url: str = "ws://localhost:8899"):
|
|||
|
self.client_id = client_id
|
|||
|
self.server_url = f"{server_url}/ws/{client_id}"
|
|||
|
self.websocket = None
|
|||
|
self.messages = []
|
|||
|
self.connected = False
|
|||
|
|
|||
|
async def connect(self):
|
|||
|
"""连接到服务器"""
|
|||
|
try:
|
|||
|
self.websocket = await websockets.connect(self.server_url)
|
|||
|
self.connected = True
|
|||
|
logger.info(f"客户端 {self.client_id} 已连接到服务器")
|
|||
|
|
|||
|
# 启动消息监听
|
|||
|
asyncio.create_task(self.listen_messages())
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"客户端 {self.client_id} 连接失败: {e}")
|
|||
|
raise
|
|||
|
|
|||
|
async def listen_messages(self):
|
|||
|
"""监听服务器消息"""
|
|||
|
try:
|
|||
|
async for message in self.websocket:
|
|||
|
data = json.loads(message)
|
|||
|
self.messages.append(data)
|
|||
|
logger.info(f"客户端 {self.client_id} 收到消息: {data}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"客户端 {self.client_id} 消息监听错误: {e}")
|
|||
|
finally:
|
|||
|
self.connected = False
|
|||
|
|
|||
|
async def send_message(self, message: str):
|
|||
|
"""发送消息到服务器"""
|
|||
|
if self.websocket and self.connected:
|
|||
|
await self.websocket.send(message)
|
|||
|
logger.info(f"客户端 {self.client_id} 发送消息: {message}")
|
|||
|
|
|||
|
async def disconnect(self):
|
|||
|
"""断开连接"""
|
|||
|
if self.websocket:
|
|||
|
await self.websocket.close()
|
|||
|
self.connected = False
|
|||
|
logger.info(f"客户端 {self.client_id} 已断开连接")
|
|||
|
|
|||
|
|
|||
|
async def test_websocket_builtin_functions():
|
|||
|
"""测试WebSocket内置函数"""
|
|||
|
logger.info("=== 测试WebSocket内置函数 ===")
|
|||
|
|
|||
|
try:
|
|||
|
# 创建VWED对象
|
|||
|
vwed = create_vwed_object("test_script_websocket")
|
|||
|
|
|||
|
# 等待客户端连接稳定
|
|||
|
await asyncio.sleep(1)
|
|||
|
|
|||
|
# 测试获取客户端信息
|
|||
|
logger.info("--- 测试获取客户端信息 ---")
|
|||
|
try:
|
|||
|
client_ips = vwed.websocket.get_websocket_client_ip()
|
|||
|
client_names = vwed.websocket.get_websocket_client_name()
|
|||
|
|
|||
|
logger.info(f"获取到的客户端IP: {client_ips}")
|
|||
|
logger.info(f"获取到的客户端名称: {client_names}")
|
|||
|
except Exception as e:
|
|||
|
logger.warning(f"获取客户端信息失败 (可能是因为使用了不同的连接管理器): {e}")
|
|||
|
|
|||
|
# 测试发送消息功能
|
|||
|
logger.info("--- 测试发送消息功能 ---")
|
|||
|
|
|||
|
# 注意:由于我们使用的是独立的测试服务器,这些测试可能会失败
|
|||
|
# 但可以展示如何使用这些函数
|
|||
|
|
|||
|
try:
|
|||
|
# 测试根据IP发送消息
|
|||
|
vwed.websocket.send_msg_to_wsc_by_client_ip("来自VWED脚本的消息", "127.0.0.1")
|
|||
|
logger.info("尝试根据IP发送消息")
|
|||
|
except Exception as e:
|
|||
|
logger.info(f"根据IP发送消息测试 (预期可能失败): {e}")
|
|||
|
|
|||
|
try:
|
|||
|
# 测试根据客户端名称发送消息
|
|||
|
vwed.websocket.send_msg_to_wsc_by_client_name("来自VWED脚本的消息", "test_client_1")
|
|||
|
logger.info("尝试根据客户端名称发送消息")
|
|||
|
except Exception as e:
|
|||
|
logger.info(f"根据客户端名称发送消息测试 (预期可能失败): {e}")
|
|||
|
|
|||
|
logger.info("WebSocket内置函数测试完成")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"WebSocket内置函数测试失败: {e}")
|
|||
|
|
|||
|
|
|||
|
async def run_websocket_demo():
|
|||
|
"""运行WebSocket演示"""
|
|||
|
logger.info("开始WebSocket服务演示")
|
|||
|
|
|||
|
# 启动测试服务器
|
|||
|
server = WebSocketTestServer()
|
|||
|
server_thread = server.start_server()
|
|||
|
|
|||
|
try:
|
|||
|
# 创建测试客户端
|
|||
|
clients = []
|
|||
|
client_tasks = []
|
|||
|
|
|||
|
# 创建多个客户端
|
|||
|
for i in range(3):
|
|||
|
client_id = f"test_client_{i+1}"
|
|||
|
client = WebSocketTestClient(client_id)
|
|||
|
clients.append(client)
|
|||
|
|
|||
|
# 连接客户端
|
|||
|
await client.connect()
|
|||
|
await asyncio.sleep(0.5) # 错开连接时间
|
|||
|
|
|||
|
# 客户端发送测试消息
|
|||
|
for i, client in enumerate(clients):
|
|||
|
await client.send_message(f"测试消息来自客户端 {i+1}")
|
|||
|
await asyncio.sleep(0.2)
|
|||
|
|
|||
|
# 等待消息处理
|
|||
|
await asyncio.sleep(2)
|
|||
|
|
|||
|
# 测试WebSocket内置函数
|
|||
|
await test_websocket_builtin_functions()
|
|||
|
|
|||
|
# 展示客户端收到的消息
|
|||
|
logger.info("=== 客户端消息汇总 ===")
|
|||
|
for client in clients:
|
|||
|
logger.info(f"客户端 {client.client_id} 收到 {len(client.messages)} 条消息")
|
|||
|
for msg in client.messages:
|
|||
|
logger.info(f" - {msg}")
|
|||
|
|
|||
|
# 断开客户端连接
|
|||
|
for client in clients:
|
|||
|
await client.disconnect()
|
|||
|
|
|||
|
logger.info("✅ WebSocket服务演示完成")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"❌ WebSocket服务演示失败: {e}")
|
|||
|
raise
|
|||
|
|
|||
|
finally:
|
|||
|
logger.info("清理测试环境")
|
|||
|
|
|||
|
|
|||
|
def main():
|
|||
|
"""主函数"""
|
|||
|
logger.info("WebSocket服务演示和测试开始")
|
|||
|
|
|||
|
try:
|
|||
|
# 运行异步演示
|
|||
|
asyncio.run(run_websocket_demo())
|
|||
|
|
|||
|
except KeyboardInterrupt:
|
|||
|
logger.info("用户中断演示")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"演示过程中发生错误: {e}")
|
|||
|
raise
|
|||
|
|
|||
|
logger.info("WebSocket服务演示和测试结束")
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
main()
|