354 lines
14 KiB
Python
354 lines
14 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
异步MQTT客户端服务模块
|
||
提供高性能的异步MQTT连接、发布、订阅等功能
|
||
支持自动重连和异步消息处理
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import time
|
||
import sys
|
||
from typing import Dict, Any, Optional, Callable, List
|
||
from utils.logger import get_logger
|
||
|
||
# Windows compatibility fix.
if sys.platform == 'win32':
    # Use the selector-based event loop on Windows so that
    # add_reader/add_writer are supported.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||
|
||
# try:
|
||
# import aiomqtt
|
||
# except ImportError:
|
||
# try:
|
||
# # 尝试导入旧版本的asyncio-mqtt
|
||
#
|
||
# except ImportError:
|
||
# aiomqtt = None
|
||
# 如果都没有安装,回退到paho-mqtt
|
||
import aiomqtt
|
||
|
||
from config.tf_api_config import MQTT_CONFIG
|
||
|
||
# Module-level logger used by AsyncMQTTService for all of its log output.
logger = get_logger("services.async_mqtt_service")
|
||
|
||
|
||
class AsyncMQTTService:
    """Asynchronous MQTT client service built on ``aiomqtt``.

    The service owns one ``aiomqtt.Client`` and a background task that keeps
    the connection alive, re-subscribes to every known topic after a
    reconnect, and dispatches incoming messages to registered handlers.

    Configuration keys read from ``config``:
        host, port: broker address (required).
        username, password: optional credentials.
        max_retries: consecutive reconnect attempts before giving up
            (default 3).
        reconnect_delay: base delay in seconds, multiplied by the attempt
            number (default 5).
        connect_timeout: seconds ``connect()`` waits for the first
            connection (default 10).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Initialise the service state without opening any connection.

        Args:
            config: MQTT settings; falls back to ``MQTT_CONFIG`` when omitted
                or empty.
        """
        self.config = config or MQTT_CONFIG
        self.client = None
        self.connected = False
        self.reconnect_attempts = 0
        # Maps a topic filter to the handlers registered for it.
        self.message_handlers: Dict[str, List[Callable]] = {}
        self._running = False
        self._connection_task = None
        self._message_task = None
        self._use_aiomqtt = aiomqtt is not None

        if not self._use_aiomqtt:
            logger.warning("aiomqtt未安装,将使用paho-mqtt的异步包装")

    async def connect(self):
        """Connect to the MQTT broker.

        Returns:
            True when the connection is established within the timeout,
            False on timeout or on any error (errors are logged, not raised).
        """
        try:
            if not self._use_aiomqtt:
                # No paho-mqtt fallback exists in this class, so fail with a
                # clear message instead of calling an undefined method.
                raise RuntimeError(
                    "aiomqtt is not available and no paho-mqtt fallback is implemented"
                )
            await self._connect_aiomqtt()
            # Report the real state: a timed-out attempt is not a success.
            return self.connected
        except Exception as e:
            logger.error(f"连接MQTT服务器失败: {e}")
            return False

    async def _connect_aiomqtt(self):
        """Create the aiomqtt client and wait for the first connection."""
        # A previous session may still be retrying in the background; stop it
        # so two tasks never drive the same client.
        await self._stop_tasks()

        client_config = {
            "hostname": self.config['host'],
            "port": self.config['port'],
        }
        # NOTE(review): config['keepalive'] is currently not forwarded to
        # aiomqtt.Client.
        username = self.config.get('username')
        if username:
            client_config['username'] = username
            client_config['password'] = self.config.get('password', '')

        self.client = aiomqtt.Client(**client_config)
        self._running = True
        self._connection_task = asyncio.create_task(self._maintain_connection())

        logger.info(f"正在连接MQTT服务器 {self.config['host']}:{self.config['port']}...")

        # Poll until connected, until the background task gives up, or until
        # the timeout expires. A monotonic clock is immune to wall-clock jumps.
        deadline = time.monotonic() + self.config.get('connect_timeout', 10)
        while (
            not self.connected
            and not self._connection_task.done()
            and time.monotonic() < deadline
        ):
            await asyncio.sleep(0.1)

        if self.connected:
            logger.info("异步MQTT连接建立成功")
        else:
            logger.error("异步MQTT连接建立超时")

    async def _maintain_connection(self):
        """Keep the connection alive, reconnecting with a growing delay."""
        max_retries = self.config.get('max_retries', 3)
        base_delay = self.config.get('reconnect_delay', 5)

        while self._running:
            try:
                async with self.client:
                    self.connected = True
                    self.reconnect_attempts = 0
                    logger.info(f"MQTT连接成功: {self.config['host']}:{self.config['port']}")

                    # Start consuming messages, then restore subscriptions.
                    self._message_task = asyncio.create_task(self._handle_messages())
                    await self._resubscribe_all_topics()

                    # Blocks until the message stream ends or fails.
                    await self._message_task
                # The stream ended without an error; the session is over.
                self.connected = False
            except asyncio.CancelledError:
                self.connected = False
                if self._message_task:
                    self._message_task.cancel()
                raise
            except Exception as e:
                self.connected = False
                if self._message_task:
                    self._message_task.cancel()

                if not self._running:
                    break

                self.reconnect_attempts += 1
                if self.reconnect_attempts > max_retries:
                    logger.error("MQTT重连次数已达上限,停止重连")
                    # The loop is no longer active; keep the flag truthful.
                    self._running = False
                    break

                delay = base_delay * self.reconnect_attempts
                logger.warning(f"MQTT连接断开: {e}, 将在 {delay} 秒后重连(第{self.reconnect_attempts}次尝试)")
                await asyncio.sleep(delay)

    async def _handle_messages(self):
        """Consume incoming messages and dispatch them to handlers.

        Raises:
            asyncio.CancelledError: re-raised so cancellation is not lost.
            Exception: stream errors are logged and re-raised so that
                ``_maintain_connection`` applies its reconnect backoff.
        """
        try:
            async for message in self.client.messages:
                topic = str(message.topic)
                payload = self._decode_payload(message.payload)
                await self._process_message(topic, payload)
        except asyncio.CancelledError:
            logger.debug("消息处理任务被取消")
            raise
        except Exception as e:
            logger.error(f"处理MQTT消息失败: {e}", exc_info=True)
            raise

    @staticmethod
    def _decode_payload(raw: Any) -> str:
        """Convert a raw message payload to text without ever raising.

        Bytes are decoded as UTF-8 with invalid sequences replaced, ``None``
        becomes an empty string, anything else goes through ``str()``.
        """
        if isinstance(raw, (bytes, bytearray)):
            return raw.decode('utf-8', errors='replace')
        if raw is None:
            return ''
        return str(raw)

    @staticmethod
    def _topic_matches(pattern: str, topic: str) -> bool:
        """Return True when ``topic`` matches the MQTT topic filter ``pattern``.

        Supports the single-level ``+`` and multi-level ``#`` wildcards and a
        ``$share/<group>/`` prefix on the filter. Topics whose first level
        starts with ``$`` are not matched by a leading wildcard.
        """
        if pattern == topic:
            return True

        filter_levels = pattern.split('/')
        topic_levels = topic.split('/')

        # Shared subscriptions: the real filter follows "$share/<group>/".
        if filter_levels[0] == '$share' and len(filter_levels) > 2:
            filter_levels = filter_levels[2:]

        if topic_levels[0].startswith('$') and filter_levels[0] in ('+', '#'):
            return False

        for index, level in enumerate(filter_levels):
            if level == '#':
                # Matches the parent level and everything below it.
                return True
            if index >= len(topic_levels):
                return False
            if level != '+' and level != topic_levels[index]:
                return False
        return len(filter_levels) == len(topic_levels)

    async def _process_message(self, topic: str, payload: str):
        """Run every handler whose topic filter matches ``topic``.

        Coroutine handlers are awaited; plain callables run in the default
        executor. A failing handler is logged and does not stop the others.
        """
        try:
            # Snapshot the registry: handlers may (un)register while we await.
            matching: List[Callable] = []
            for pattern, registered in list(self.message_handlers.items()):
                if self._topic_matches(pattern, topic):
                    matching.extend(registered)

            for handler in matching:
                try:
                    if asyncio.iscoroutinefunction(handler):
                        await handler(topic, payload)
                    else:
                        # Keep blocking handlers off the event loop.
                        loop = asyncio.get_running_loop()
                        await loop.run_in_executor(None, handler, topic, payload)
                except Exception as e:
                    logger.error(f"MQTT消息处理器执行失败: {e}", exc_info=True)
        except Exception as e:
            logger.error(f"处理MQTT消息异常: {e}", exc_info=True)

    async def _resubscribe_all_topics(self):
        """Subscribe again to every topic present in the handler registry."""
        # Iterate over a snapshot; the registry can change during the awaits.
        for topic in list(self.message_handlers):
            try:
                await self.client.subscribe(topic)
                logger.debug(f"重新订阅MQTT主题: {topic}")
            except Exception as e:
                logger.error(f"重新订阅主题失败 {topic}: {e}")

    @staticmethod
    async def _cancel_task(task) -> None:
        """Cancel ``task`` (if any) and wait for it to finish."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task already failed on its own; do not fail the teardown.
            logger.debug(f"后台任务结束时出现异常: {e}")

    async def _stop_tasks(self):
        """Stop the background connection and message tasks."""
        self._running = False
        self.connected = False

        connection_task, self._connection_task = self._connection_task, None
        message_task, self._message_task = self._message_task, None

        # Cancel the owner first; it cancels the message task it is awaiting.
        await self._cancel_task(connection_task)
        await self._cancel_task(message_task)

    async def disconnect(self):
        """Disconnect from the broker and stop all background tasks."""
        await self._stop_tasks()
        logger.info("异步MQTT连接已断开")

    async def subscribe(self, topic: str, handler: Optional[Callable] = None):
        """Subscribe to ``topic`` and optionally register ``handler`` for it.

        Returns:
            True on success; False when not connected, when the client is
            missing, or when the subscription raised (errors are logged).
        """
        try:
            if not self.connected:
                logger.warning(f"MQTT未连接,无法订阅主题: {topic}")
                return False

            if not (self._use_aiomqtt and self.client):
                logger.error("MQTT客户端未初始化")
                return False

            await self.client.subscribe(topic)
            logger.info(f"成功订阅MQTT主题: {topic}")

            # Record the topic even without a handler so that it is
            # re-subscribed after a reconnect.
            self.message_handlers.setdefault(topic, [])
            if handler:
                self.add_message_handler(topic, handler)
            return True
        except Exception as e:
            logger.error(f"订阅MQTT主题异常: {e}")
            return False

    async def unsubscribe(self, topic: str):
        """Unsubscribe from ``topic`` and drop all handlers registered for it.

        Returns:
            True on success; False when not connected, when the client is
            missing, or when the call raised (errors are logged).
        """
        try:
            if not self.connected:
                logger.debug(f"MQTT未连接,跳过取消订阅主题: {topic}")
                return False

            if not (self._use_aiomqtt and self.client):
                logger.warning(f"MQTT客户端未初始化,无法取消订阅主题: {topic}")
                return False

            await self.client.unsubscribe(topic)
            logger.info(f"取消订阅MQTT主题: {topic}")
            self.message_handlers.pop(topic, None)
            return True
        except Exception as e:
            logger.error(f"取消订阅MQTT主题异常: {e}")
            return False

    async def publish(self, topic: str, payload: Any, qos: int = 0, retain: bool = False):
        """Publish ``payload`` to ``topic``.

        ``bytes``/``bytearray`` payloads are sent unchanged; everything else
        is converted to text by ``_serialize_payload``.

        Returns:
            True on success; False when not connected, when the client is
            missing, or when publishing raised (errors are logged).
        """
        try:
            if not self.connected:
                logger.warning(f"MQTT未连接,无法发布消息到: {topic}")
                return False

            if isinstance(payload, (bytes, bytearray)):
                # Binary data must not be turned into its "b'...'" repr.
                data = bytes(payload)
            else:
                data = self._serialize_payload(payload)

            if not (self._use_aiomqtt and self.client):
                logger.error("MQTT客户端未初始化")
                return False

            await self.client.publish(topic, data, qos=qos, retain=retain)
            logger.debug(f"MQTT消息发布成功: {topic}")
            return True
        except Exception as e:
            logger.error(f"发布MQTT消息异常: {e}")
            return False

    def _serialize_payload(self, payload: Any) -> str:
        """Serialise ``payload`` to the text that is put on the wire.

        Scalars use ``str()``; dicts and lists become JSON; objects are
        converted through ``to_dict()``, ``dict`` (method or attribute) or
        ``__dict__``, in that order. On any failure the ``str()`` form is
        returned and a warning is logged.
        """
        def dump(value: Any) -> str:
            return json.dumps(value, ensure_ascii=False, default=self._json_default)

        try:
            if isinstance(payload, (str, int, float, bool)):
                return str(payload)
            if isinstance(payload, (dict, list)):
                return dump(payload)

            to_dict = getattr(payload, 'to_dict', None)
            if callable(to_dict):
                return dump(to_dict())

            if hasattr(payload, 'dict'):
                # e.g. Pydantic models expose dict() as a method.
                as_dict = payload.dict
                return dump(as_dict() if callable(as_dict) else as_dict)

            if hasattr(payload, '__dict__'):
                return dump(payload.__dict__)

            return str(payload)
        except Exception as e:
            logger.warning(f"序列化载荷失败,使用字符串表示: {e}")
            return str(payload)

    def _json_default(self, obj):
        """``json.dumps`` fallback for values that are not JSON-native.

        Enums map to their value, objects with ``to_dict()`` or ``__dict__``
        map to that dict, and everything else maps to ``str(obj)``.
        """
        from enum import Enum

        if isinstance(obj, Enum):
            return obj.value
        if hasattr(obj, 'to_dict') and callable(obj.to_dict):
            return obj.to_dict()
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        return str(obj)

    def add_message_handler(self, topic: str, handler: Callable):
        """Register ``handler`` for ``topic``; duplicates are ignored."""
        handlers = self.message_handlers.setdefault(topic, [])
        if handler not in handlers:
            handlers.append(handler)

    def remove_message_handler(self, topic: str, handler: Callable):
        """Remove ``handler`` from ``topic``; drop the topic once it is empty."""
        handlers = self.message_handlers.get(topic)
        if handlers is None:
            return
        if handler in handlers:
            handlers.remove(handler)
        if not handlers:
            del self.message_handlers[topic]

    def is_connected(self) -> bool:
        """Return the current connection flag."""
        return self.connected

    def get_connection_info(self) -> Dict[str, Any]:
        """Return a snapshot of the connection settings and state."""
        return {
            "host": self.config['host'],
            "port": self.config['port'],
            "connected": self.connected,
            "reconnect_attempts": self.reconnect_attempts,
            "running": self._running,
            "subscribed_topics": list(self.message_handlers.keys()),
            "async_mode": self._use_aiomqtt,
        }
|
||
|
||
|
||
def create_async_mqtt_service(config: Dict[str, Any] = None) -> AsyncMQTTService:
    """Build and return a new ``AsyncMQTTService``.

    Args:
        config: MQTT settings passed straight to the service constructor;
            ``None`` is forwarded unchanged.

    Returns:
        The freshly constructed service instance.
    """
    service = AsyncMQTTService(config)
    return service