354 lines
14 KiB
Python
354 lines
14 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
异步MQTT客户端服务模块
|
|||
|
提供高性能的异步MQTT连接、发布、订阅等功能
|
|||
|
支持自动重连和异步消息处理
|
|||
|
"""
|
|||
|
|
|||
|
import asyncio
|
|||
|
import json
|
|||
|
import time
|
|||
|
import sys
|
|||
|
from typing import Dict, Any, Optional, Callable, List
|
|||
|
from utils.logger import get_logger
|
|||
|
|
|||
|
# Windows compatibility: the default (Proactor) event loop on Windows does not
# implement add_reader/add_writer, so install the selector-based loop policy,
# which does.
_IS_WINDOWS = sys.platform == 'win32'
if _IS_WINDOWS:
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|||
|
|
|||
|
# try:
|
|||
|
# import aiomqtt
|
|||
|
# except ImportError:
|
|||
|
# try:
|
|||
|
# # 尝试导入旧版本的asyncio-mqtt
|
|||
|
#
|
|||
|
# except ImportError:
|
|||
|
# aiomqtt = None
|
|||
|
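# NOTE: no fallback is active. aiomqtt is imported unconditionally below, so a missing aiomqtt raises ImportError at import time (there is no paho-mqtt fallback).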
|
|||
|
import aiomqtt
|
|||
|
|
|||
|
from config.tf_api_config import MQTT_CONFIG
|
|||
|
|
|||
|
# Module-level logger shared by every method of this service module.
logger = get_logger("services.async_mqtt_service")
|
|||
|
|
|||
|
|
|||
|
class AsyncMQTTService:
    """Asynchronous MQTT client service built on aiomqtt.

    A background task (``_maintain_connection``) owns the aiomqtt connection,
    reconnects with a linearly growing delay, and re-subscribes every topic
    filter that has registered handlers after each (re)connect. Incoming
    messages are decoded as UTF-8 and dispatched to every handler whose topic
    filter matches the message topic (MQTT ``+`` / ``#`` wildcards supported).

    Config keys read: ``host`` and ``port`` (required); ``username``,
    ``password``, ``max_retries`` (default 3), ``reconnect_delay``
    (default 5 s) and ``connect_timeout`` (default 10 s).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create the service. No network activity happens until ``connect()``.

        Args:
            config: MQTT settings. Falls back to ``MQTT_CONFIG`` when falsy.
        """
        self.config = config or MQTT_CONFIG
        self.client = None
        self.connected = False
        self.reconnect_attempts = 0
        # topic filter -> handlers, each called as handler(topic, payload)
        self.message_handlers: Dict[str, List[Callable]] = {}
        self._running = False
        self._connection_task = None
        self._message_task = None
        self._use_aiomqtt = aiomqtt is not None

        if not self._use_aiomqtt:
            logger.warning("aiomqtt未安装,将使用paho-mqtt的异步包装")

    async def connect(self) -> bool:
        """Connect to the MQTT broker.

        Returns:
            True once the connection is established. False on error, or when
            the connection is not up within ``connect_timeout`` seconds. In the
            timeout case the background task keeps retrying.
        """
        try:
            if self._use_aiomqtt:
                return await self._connect_aiomqtt()
            # NOTE(review): _connect_paho_async is not defined in this class;
            # this branch is only reachable if aiomqtt were ever None.
            await self._connect_paho_async()
            return True
        except Exception as e:
            logger.error(f"连接MQTT服务器失败: {e}")
            return False

    async def _connect_aiomqtt(self) -> bool:
        """Create the aiomqtt client, start the connection task and wait for it.

        Returns:
            Whether the connection was established before the timeout.
        """
        # A repeated connect() must not leave an older maintain loop running
        # next to the new one (duplicate connections / duplicate deliveries).
        await self._stop_background_tasks()

        client_config = {
            "hostname": self.config['host'],
            "port": self.config['port'],
            # NOTE(review): 'keepalive' is not passed (it was commented out,
            # apparently for aiomqtt version compatibility). Confirm before enabling.
        }
        if self.config.get('username'):
            client_config['username'] = self.config['username']
            client_config['password'] = self.config.get('password', '')

        self.client = aiomqtt.Client(**client_config)

        # Give every fresh connect() a full retry budget, even if a previous
        # maintain loop gave up after exhausting max_retries.
        self.reconnect_attempts = 0
        self._running = True
        self._connection_task = asyncio.create_task(self._maintain_connection())

        logger.info(f"正在连接MQTT服务器 {self.config['host']}:{self.config['port']}...")

        # Poll until connected, timed out, or the maintain loop has given up.
        timeout = self.config.get('connect_timeout', 10)
        deadline = time.monotonic() + timeout
        while (not self.connected
               and not self._connection_task.done()
               and time.monotonic() < deadline):
            await asyncio.sleep(0.1)

        if self.connected:
            logger.info("异步MQTT连接建立成功")
        else:
            logger.error("异步MQTT连接建立超时")
        return self.connected

    async def _maintain_connection(self):
        """Keep the aiomqtt connection alive while ``_running`` is set.

        Each iteration connects, starts the message loop, re-subscribes all
        registered topic filters, then waits for the message loop to end.
        Failures are retried up to ``max_retries`` times, sleeping
        ``reconnect_delay * attempt`` seconds between attempts.
        """
        try:
            while self._running:
                try:
                    async with self.client:
                        self.connected = True
                        self.reconnect_attempts = 0
                        logger.info(f"MQTT连接成功: {self.config['host']}:{self.config['port']}")

                        self._message_task = asyncio.create_task(self._handle_messages())
                        await self._resubscribe_all_topics()
                        await self._message_task

                        # If disconnect() was not requested, the message loop
                        # ended because the connection broke. Route that through
                        # the retry/backoff path instead of reconnecting
                        # immediately with `connected` still True.
                        if self._running:
                            raise ConnectionError("MQTT消息循环意外终止")
                except Exception as e:
                    self.connected = False
                    if self._message_task:
                        self._message_task.cancel()

                    if not self._running:
                        break
                    self.reconnect_attempts += 1
                    if self.reconnect_attempts > self.config.get('max_retries', 3):
                        logger.error("MQTT重连次数已达上限,停止重连")
                        break
                    delay = self.config.get('reconnect_delay', 5) * self.reconnect_attempts
                    logger.warning(f"MQTT连接断开: {e}, 将在 {delay} 秒后重连(第{self.reconnect_attempts}次尝试)")
                    await asyncio.sleep(delay)
        finally:
            # However the loop ends (stop, retry limit, cancellation), we are no
            # longer connected.
            self.connected = False

    async def _handle_messages(self):
        """Receive messages from the client and dispatch them one at a time.

        Returns normally when cancelled or when the message iterator fails
        (the failure is logged).
        """
        try:
            async for message in self.client.messages:
                topic = str(message.topic)
                try:
                    payload = message.payload.decode('utf-8')
                except UnicodeDecodeError as e:
                    # One undecodable payload must not end the whole loop.
                    logger.warning(f"MQTT消息解码失败,已丢弃 {topic}: {e}")
                    continue

                await self._process_message(topic, payload)

        except asyncio.CancelledError:
            logger.debug("消息处理任务被取消")
        except Exception as e:
            logger.error(f"处理MQTT消息失败: {e}", exc_info=True)

    @staticmethod
    def _topic_matches(topic_filter: str, topic: str) -> bool:
        """Return whether an MQTT topic filter matches a concrete topic name.

        ``+`` matches exactly one level. A trailing ``#`` matches the parent
        level and any number of child levels. Filters that start with a
        wildcard do not match ``$``-prefixed topics (MQTT 3.1.1 section 4.7).
        """
        if topic_filter == topic:
            return True
        filter_levels = topic_filter.split('/')
        topic_levels = topic.split('/')
        if topic.startswith('$') and filter_levels[0] in ('+', '#'):
            return False
        for index, level in enumerate(filter_levels):
            if level == '#':
                return True
            if index >= len(topic_levels):
                return False
            if level != '+' and level != topic_levels[index]:
                return False
        return len(filter_levels) == len(topic_levels)

    async def _process_message(self, topic: str, payload: str):
        """Invoke every handler whose topic filter matches ``topic``.

        Coroutine handlers are awaited. Plain callables run in the loop's
        default executor. A failing handler is logged and does not stop the
        remaining handlers.
        """
        import inspect

        try:
            # Snapshot the matching handlers before awaiting anything. Handlers
            # may (un)register handlers while running, which would otherwise
            # mutate the lists being iterated and skip entries.
            handlers = [
                handler
                for topic_filter, filter_handlers in list(self.message_handlers.items())
                if self._topic_matches(topic_filter, topic)
                for handler in filter_handlers
            ]
            loop = asyncio.get_running_loop()
            for handler in handlers:
                try:
                    if inspect.iscoroutinefunction(handler):
                        await handler(topic, payload)
                    else:
                        # Synchronous handlers run in the thread pool so they
                        # do not block the event loop.
                        await loop.run_in_executor(None, handler, topic, payload)
                except Exception as e:
                    logger.error(f"MQTT消息处理器执行失败: {e}", exc_info=True)
        except Exception as e:
            logger.error(f"处理MQTT消息异常: {e}", exc_info=True)

    async def _resubscribe_all_topics(self):
        """Subscribe again to every topic filter that has handlers registered."""
        # Iterate a snapshot: subscribe()/unsubscribe() may change the dict
        # while we await the broker.
        for topic in list(self.message_handlers):
            try:
                await self.client.subscribe(topic)
                logger.debug(f"重新订阅MQTT主题: {topic}")
            except Exception as e:
                logger.error(f"重新订阅主题失败 {topic}: {e}")

    async def _stop_background_tasks(self):
        """Stop the reconnect loop, then cancel and await both background tasks."""
        self._running = False
        self.connected = False
        current = asyncio.current_task()
        # The message task goes first. With _running cleared, the maintain loop
        # then leaves its `async with` (disconnecting the client) and exits.
        for task in (self._message_task, self._connection_task):
            if task is None or task is current:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._message_task = None
        self._connection_task = None

    async def disconnect(self):
        """Disconnect from the broker and stop all background tasks."""
        await self._stop_background_tasks()
        logger.info("异步MQTT连接已断开")

    async def subscribe(self, topic: str, handler: Optional[Callable] = None) -> bool:
        """Subscribe to ``topic`` and optionally register ``handler`` for it.

        Only topics with a registered handler are re-subscribed after a
        reconnect.

        Returns:
            True on success. False when not connected, when the client is
            missing, or on error.
        """
        try:
            if not self.connected:
                logger.warning(f"MQTT未连接,无法订阅主题: {topic}")
                return False

            if not (self._use_aiomqtt and self.client):
                logger.error("MQTT客户端未初始化")
                return False

            await self.client.subscribe(topic)
            logger.info(f"成功订阅MQTT主题: {topic}")
            if handler:
                self.add_message_handler(topic, handler)
            return True

        except Exception as e:
            logger.error(f"订阅MQTT主题异常: {e}")
            return False

    async def unsubscribe(self, topic: str) -> bool:
        """Unsubscribe from ``topic`` and drop all of its handlers.

        Returns:
            True when the broker-side unsubscribe succeeded, False otherwise.
        """
        try:
            if not self.connected:
                # Drop the handlers anyway. Otherwise the reconnect path would
                # re-subscribe a topic the caller asked to leave.
                self.message_handlers.pop(topic, None)
                logger.debug(f"MQTT未连接,跳过取消订阅主题: {topic}")
                return False

            if not (self._use_aiomqtt and self.client):
                logger.warning(f"MQTT客户端未初始化,无法取消订阅主题: {topic}")
                return False

            await self.client.unsubscribe(topic)
            logger.info(f"取消订阅MQTT主题: {topic}")
            self.message_handlers.pop(topic, None)
            return True

        except Exception as e:
            logger.error(f"取消订阅MQTT主题异常: {e}")
            return False

    async def publish(self, topic: str, payload: Any, qos: int = 0, retain: bool = False) -> bool:
        """Serialize ``payload`` (see ``_serialize_payload``) and publish it.

        Returns:
            True on success. False when not connected, when the client is
            missing, or on error.
        """
        try:
            if not self.connected:
                logger.warning(f"MQTT未连接,无法发布消息到: {topic}")
                return False

            payload_data = self._serialize_payload(payload)

            if not (self._use_aiomqtt and self.client):
                logger.error("MQTT客户端未初始化")
                return False

            await self.client.publish(topic, payload_data, qos=qos, retain=retain)
            logger.debug(f"MQTT消息发布成功: {topic}")
            return True

        except Exception as e:
            logger.error(f"发布MQTT消息异常: {e}")
            return False

    def _serialize_payload(self, payload: Any) -> "str | bytes":
        """Turn ``payload`` into something the MQTT client can send.

        Resolution order:
        1. ``bytes`` / ``bytearray`` are returned unchanged.
        2. Scalars (str/int/float/bool) are passed through ``str()``.
        3. dict/list are JSON-encoded.
        4. Objects with ``to_dict()``, ``model_dump()``, ``dict`` or
           ``__dict__`` are JSON-encoded via that mapping.
        5. Anything else is passed through ``str()``.

        Any serialization error falls back to ``str(payload)``.
        """
        def dump(obj):
            return json.dumps(obj, ensure_ascii=False, default=self._json_default)

        try:
            if isinstance(payload, (bytes, bytearray)):
                # Binary payloads go out as-is; str() would send "b'...'".
                return payload
            if isinstance(payload, (str, int, float, bool)):
                return str(payload)
            if isinstance(payload, (dict, list)):
                return dump(payload)
            if hasattr(payload, 'to_dict') and callable(payload.to_dict):
                return dump(payload.to_dict())
            if hasattr(payload, 'model_dump') and callable(payload.model_dump):
                # Pydantic v2: avoids the deprecated .dict() alias.
                return dump(payload.model_dump())
            if hasattr(payload, 'dict'):
                # Pydantic v1 models (callable .dict) or objects exposing a
                # plain `dict` attribute.
                return dump(payload.dict() if callable(payload.dict) else payload.dict)
            if hasattr(payload, '__dict__'):
                return dump(payload.__dict__)
            return str(payload)
        except Exception as e:
            logger.warning(f"序列化载荷失败,使用字符串表示: {e}")
            return str(payload)

    def _json_default(self, obj):
        """``json.dumps`` fallback for values JSON cannot encode natively.

        Enums become their ``value``. Objects with a callable ``to_dict`` use
        it, then objects with ``__dict__`` use that. Everything else becomes
        ``str(obj)``.
        """
        from enum import Enum

        if isinstance(obj, Enum):
            return obj.value
        elif hasattr(obj, 'to_dict') and callable(obj.to_dict):
            return obj.to_dict()
        elif hasattr(obj, '__dict__'):
            return obj.__dict__
        else:
            return str(obj)

    def add_message_handler(self, topic: str, handler: Callable):
        """Register ``handler`` for ``topic`` (a topic name or filter), ignoring duplicates."""
        handlers = self.message_handlers.setdefault(topic, [])
        if handler not in handlers:
            handlers.append(handler)

    def remove_message_handler(self, topic: str, handler: Callable):
        """Remove ``handler`` from ``topic``; drop the topic once it has no handlers."""
        if topic in self.message_handlers:
            if handler in self.message_handlers[topic]:
                self.message_handlers[topic].remove(handler)
            if not self.message_handlers[topic]:
                del self.message_handlers[topic]

    def is_connected(self) -> bool:
        """Return the current connection flag."""
        return self.connected

    def get_connection_info(self) -> Dict[str, Any]:
        """Return a snapshot of connection settings and state."""
        return {
            "host": self.config['host'],
            "port": self.config['port'],
            "connected": self.connected,
            "reconnect_attempts": self.reconnect_attempts,
            "running": self._running,
            "subscribed_topics": list(self.message_handlers.keys()),
            "async_mode": self._use_aiomqtt
        }
|
|||
|
|
|||
|
|
|||
|
def create_async_mqtt_service(config: Dict[str, Any] = None) -> AsyncMQTTService:
    """Factory helper that builds a new ``AsyncMQTTService``.

    Args:
        config: Forwarded unchanged to the ``AsyncMQTTService`` constructor.

    Returns:
        A fresh, not-yet-connected service instance.
    """
    service = AsyncMQTTService(config)
    return service
|