313 lines
10 KiB
Python
313 lines
10 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
WebSocket服务演示和测试
|
||
启动一个完整的WebSocket服务器,模拟客户端连接,测试内置函数功能
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import threading
|
||
import time
|
||
import sys
|
||
import os
|
||
from datetime import datetime
|
||
from typing import Dict, Set
|
||
|
||
# 添加项目根目录到Python路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
parent_dir = os.path.dirname(current_dir)
|
||
if parent_dir not in sys.path:
|
||
sys.path.insert(0, parent_dir)
|
||
|
||
# 导入FastAPI和相关模块
|
||
try:
|
||
import uvicorn
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
import websockets
|
||
except ImportError as e:
|
||
print(f"需要安装依赖包: {e}")
|
||
print("请运行: pip install fastapi uvicorn websockets")
|
||
sys.exit(1)
|
||
|
||
# 导入WebSocket模块
|
||
try:
|
||
from services.online_script.script_vwed_objects import create_vwed_object
|
||
from utils.logger import get_logger
|
||
except ImportError as e:
|
||
print(f"导入项目模块失败: {e}")
|
||
print("请确保从项目根目录运行此脚本")
|
||
sys.exit(1)
|
||
|
||
logger = get_logger("tests.websocket_service_demo")
|
||
|
||
|
||
class WebSocketTestServer:
|
||
"""WebSocket测试服务器"""
|
||
|
||
def __init__(self, port: int = 8899):
|
||
self.port = port
|
||
self.app = FastAPI(title="WebSocket测试服务器")
|
||
self.connections: Dict[str, WebSocket] = {}
|
||
self.client_ips: Set[str] = set()
|
||
|
||
# 添加CORS支持
|
||
self.app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"],
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
self.setup_routes()
|
||
|
||
def setup_routes(self):
|
||
"""设置路由"""
|
||
|
||
@self.app.websocket("/ws/{client_id}")
|
||
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
||
await websocket.accept()
|
||
|
||
# 获取客户端IP
|
||
client_ip = websocket.client.host if hasattr(websocket, 'client') else "unknown"
|
||
self.connections[client_id] = websocket
|
||
self.client_ips.add(client_ip)
|
||
|
||
logger.info(f"客户端连接: {client_id} (IP: {client_ip})")
|
||
|
||
try:
|
||
# 发送欢迎消息
|
||
await websocket.send_text(json.dumps({
|
||
"type": "welcome",
|
||
"message": f"欢迎 {client_id}",
|
||
"timestamp": datetime.now().isoformat()
|
||
}))
|
||
|
||
# 保持连接并监听消息
|
||
while True:
|
||
try:
|
||
data = await websocket.receive_text()
|
||
logger.info(f"收到来自 {client_id} 的消息: {data}")
|
||
|
||
# 回显消息
|
||
await websocket.send_text(json.dumps({
|
||
"type": "echo",
|
||
"original_message": data,
|
||
"timestamp": datetime.now().isoformat()
|
||
}))
|
||
|
||
except WebSocketDisconnect:
|
||
break
|
||
|
||
except Exception as e:
|
||
logger.error(f"WebSocket处理错误: {e}")
|
||
finally:
|
||
# 清理连接
|
||
if client_id in self.connections:
|
||
del self.connections[client_id]
|
||
logger.info(f"客户端断开连接: {client_id}")
|
||
|
||
@self.app.get("/status")
|
||
async def get_status():
|
||
"""获取服务器状态"""
|
||
return {
|
||
"active_connections": len(self.connections),
|
||
"client_ids": list(self.connections.keys()),
|
||
"client_ips": list(self.client_ips),
|
||
"timestamp": datetime.now().isoformat()
|
||
}
|
||
|
||
def start_server(self):
|
||
"""启动服务器"""
|
||
logger.info(f"启动WebSocket测试服务器,端口: {self.port}")
|
||
|
||
# 在单独线程中运行服务器
|
||
def run_server():
|
||
uvicorn.run(self.app, host="0.0.0.0", port=self.port, log_level="warning")
|
||
|
||
server_thread = threading.Thread(target=run_server, daemon=True)
|
||
server_thread.start()
|
||
|
||
# 等待服务器启动
|
||
time.sleep(2)
|
||
logger.info("WebSocket测试服务器已启动")
|
||
|
||
return server_thread
|
||
|
||
|
||
class WebSocketTestClient:
|
||
"""WebSocket测试客户端"""
|
||
|
||
def __init__(self, client_id: str, server_url: str = "ws://localhost:8899"):
|
||
self.client_id = client_id
|
||
self.server_url = f"{server_url}/ws/{client_id}"
|
||
self.websocket = None
|
||
self.messages = []
|
||
self.connected = False
|
||
|
||
async def connect(self):
|
||
"""连接到服务器"""
|
||
try:
|
||
self.websocket = await websockets.connect(self.server_url)
|
||
self.connected = True
|
||
logger.info(f"客户端 {self.client_id} 已连接到服务器")
|
||
|
||
# 启动消息监听
|
||
asyncio.create_task(self.listen_messages())
|
||
|
||
except Exception as e:
|
||
logger.error(f"客户端 {self.client_id} 连接失败: {e}")
|
||
raise
|
||
|
||
async def listen_messages(self):
|
||
"""监听服务器消息"""
|
||
try:
|
||
async for message in self.websocket:
|
||
data = json.loads(message)
|
||
self.messages.append(data)
|
||
logger.info(f"客户端 {self.client_id} 收到消息: {data}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"客户端 {self.client_id} 消息监听错误: {e}")
|
||
finally:
|
||
self.connected = False
|
||
|
||
async def send_message(self, message: str):
|
||
"""发送消息到服务器"""
|
||
if self.websocket and self.connected:
|
||
await self.websocket.send(message)
|
||
logger.info(f"客户端 {self.client_id} 发送消息: {message}")
|
||
|
||
async def disconnect(self):
|
||
"""断开连接"""
|
||
if self.websocket:
|
||
await self.websocket.close()
|
||
self.connected = False
|
||
logger.info(f"客户端 {self.client_id} 已断开连接")
|
||
|
||
|
||
async def test_websocket_builtin_functions():
|
||
"""测试WebSocket内置函数"""
|
||
logger.info("=== 测试WebSocket内置函数 ===")
|
||
|
||
try:
|
||
# 创建VWED对象
|
||
vwed = create_vwed_object("test_script_websocket")
|
||
|
||
# 等待客户端连接稳定
|
||
await asyncio.sleep(1)
|
||
|
||
# 测试获取客户端信息
|
||
logger.info("--- 测试获取客户端信息 ---")
|
||
try:
|
||
client_ips = vwed.websocket.get_websocket_client_ip()
|
||
client_names = vwed.websocket.get_websocket_client_name()
|
||
|
||
logger.info(f"获取到的客户端IP: {client_ips}")
|
||
logger.info(f"获取到的客户端名称: {client_names}")
|
||
except Exception as e:
|
||
logger.warning(f"获取客户端信息失败 (可能是因为使用了不同的连接管理器): {e}")
|
||
|
||
# 测试发送消息功能
|
||
logger.info("--- 测试发送消息功能 ---")
|
||
|
||
# 注意:由于我们使用的是独立的测试服务器,这些测试可能会失败
|
||
# 但可以展示如何使用这些函数
|
||
|
||
try:
|
||
# 测试根据IP发送消息
|
||
vwed.websocket.send_msg_to_wsc_by_client_ip("来自VWED脚本的消息", "127.0.0.1")
|
||
logger.info("尝试根据IP发送消息")
|
||
except Exception as e:
|
||
logger.info(f"根据IP发送消息测试 (预期可能失败): {e}")
|
||
|
||
try:
|
||
# 测试根据客户端名称发送消息
|
||
vwed.websocket.send_msg_to_wsc_by_client_name("来自VWED脚本的消息", "test_client_1")
|
||
logger.info("尝试根据客户端名称发送消息")
|
||
except Exception as e:
|
||
logger.info(f"根据客户端名称发送消息测试 (预期可能失败): {e}")
|
||
|
||
logger.info("WebSocket内置函数测试完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"WebSocket内置函数测试失败: {e}")
|
||
|
||
|
||
async def run_websocket_demo():
|
||
"""运行WebSocket演示"""
|
||
logger.info("开始WebSocket服务演示")
|
||
|
||
# 启动测试服务器
|
||
server = WebSocketTestServer()
|
||
server_thread = server.start_server()
|
||
|
||
try:
|
||
# 创建测试客户端
|
||
clients = []
|
||
client_tasks = []
|
||
|
||
# 创建多个客户端
|
||
for i in range(3):
|
||
client_id = f"test_client_{i+1}"
|
||
client = WebSocketTestClient(client_id)
|
||
clients.append(client)
|
||
|
||
# 连接客户端
|
||
await client.connect()
|
||
await asyncio.sleep(0.5) # 错开连接时间
|
||
|
||
# 客户端发送测试消息
|
||
for i, client in enumerate(clients):
|
||
await client.send_message(f"测试消息来自客户端 {i+1}")
|
||
await asyncio.sleep(0.2)
|
||
|
||
# 等待消息处理
|
||
await asyncio.sleep(2)
|
||
|
||
# 测试WebSocket内置函数
|
||
await test_websocket_builtin_functions()
|
||
|
||
# 展示客户端收到的消息
|
||
logger.info("=== 客户端消息汇总 ===")
|
||
for client in clients:
|
||
logger.info(f"客户端 {client.client_id} 收到 {len(client.messages)} 条消息")
|
||
for msg in client.messages:
|
||
logger.info(f" - {msg}")
|
||
|
||
# 断开客户端连接
|
||
for client in clients:
|
||
await client.disconnect()
|
||
|
||
logger.info("✅ WebSocket服务演示完成")
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ WebSocket服务演示失败: {e}")
|
||
raise
|
||
|
||
finally:
|
||
logger.info("清理测试环境")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
logger.info("WebSocket服务演示和测试开始")
|
||
|
||
try:
|
||
# 运行异步演示
|
||
asyncio.run(run_websocket_demo())
|
||
|
||
except KeyboardInterrupt:
|
||
logger.info("用户中断演示")
|
||
except Exception as e:
|
||
logger.error(f"演示过程中发生错误: {e}")
|
||
raise
|
||
|
||
logger.info("WebSocket服务演示和测试结束")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |