241 lines
6.4 KiB
Python
241 lines
6.4 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
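"""
Background task manager.

Manages, in one place, every background task in the system that is created
with asyncio.create_task, to prevent memory leaks and to keep task exceptions
from being silently swallowed.
"""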
|
||
|
||
import asyncio
|
||
import weakref
|
||
from typing import Set, Optional, Callable, Any
|
||
from utils.logger import get_logger
|
||
|
||
logger = get_logger("utils.background_task_manager")
|
||
|
||
|
||
class BackgroundTaskManager:
    """Registry for background tasks created with ``asyncio.create_task``.

    Features:
        1. Tracks every background task created through this manager.
        2. Logs the exception of any task that finishes with an error.
        3. Supports cancelling all tracked tasks in one call.
        4. Removes finished tasks automatically through a done-callback
           (a plain ``set`` of strong references is used, not a ``WeakSet``).
    """

    def __init__(self):
        """Create an empty task registry."""
        # Plain set of task references; entries are removed again by
        # ``_task_done_callback`` when the task finishes.
        self._tasks: Set[asyncio.Task] = set()
        # Number of tasks successfully created so far; also drives the
        # auto-generated default task names.
        self._task_counter = 0

    def create_task(
        self,
        coro,
        name: Optional[str] = None,
        context: Optional[str] = None
    ) -> asyncio.Task:
        """Schedule ``coro`` as a tracked background task.

        Args:
            coro: Coroutine object to run.
            name: Optional task name; defaults to ``background_task_<n>``.
            context: Optional free-form context string, included in the error
                log if the task fails.

        Returns:
            asyncio.Task: The newly created task.

        Raises:
            RuntimeError: If there is no running event loop. The coroutine is
                closed before the error propagates, so it does not trigger a
                "coroutine was never awaited" warning.
        """
        sequence = self._task_counter + 1

        # Fall back to an auto-numbered name when none was supplied.
        if name is None:
            name = f"background_task_{sequence}"

        try:
            task = asyncio.create_task(coro, name=name)
        except Exception:
            # Scheduling failed (e.g. no running loop): release the coroutine
            # so it is not left dangling, then let the caller see the error.
            if asyncio.iscoroutine(coro):
                coro.close()
            raise

        # Only count tasks that were actually created.
        self._task_counter = sequence

        self._tasks.add(task)
        task.add_done_callback(lambda t: self._task_done_callback(t, context))

        logger.debug(f"创建后台任务: {name}, 当前任务数: {len(self._tasks)}")

        return task

    def _task_done_callback(self, task: asyncio.Task, context: Optional[str] = None):
        """Untrack a finished task and log how it ended.

        Args:
            task: The task that has just finished.
            context: Context string given at creation time, if any.
        """
        task_name = task.get_name()

        # Drop our reference so the finished task can be collected.
        self._tasks.discard(task)

        if task.cancelled():
            logger.debug(f"后台任务被取消: {task_name}")
            return

        try:
            # Retrieving the exception also marks it as "retrieved" for asyncio.
            exception = task.exception()
        except Exception as e:
            logger.error(f"获取任务结果失败: {task_name}, 错误: {str(e)}")
            return

        # Compare against None explicitly: an exception type may define
        # ``__bool__``/``__len__`` and evaluate as falsy.
        if exception is not None:
            context_info = f" (context: {context})" if context else ""
            logger.error(
                f"后台任务异常: {task_name}{context_info}, "
                f"异常类型: {type(exception).__name__}, "
                f"异常信息: {str(exception)}"
            )
        else:
            logger.debug(f"后台任务完成: {task_name}")

    async def cancel_all(self, timeout: float = 5.0) -> int:
        """Request cancellation of every running tracked task.

        The task that is calling this method (if it is itself tracked) is
        skipped, so a managed shutdown task does not cancel itself.

        Args:
            timeout: Maximum number of seconds to wait for the cancelled tasks
                to finish. Tasks still running after that are left pending
                and a warning is logged.

        Returns:
            int: Number of tasks for which cancellation was requested.
        """
        if not self._tasks:
            logger.debug("没有需要取消的后台任务")
            return 0

        current = asyncio.current_task()
        tasks_to_cancel = [
            t for t in self._tasks
            if t is not current and not t.done()
        ]

        if not tasks_to_cancel:
            logger.debug("所有后台任务已完成")
            return 0

        logger.info(f"开始取消 {len(tasks_to_cancel)} 个后台任务")

        for task in tasks_to_cancel:
            task.cancel()

        # ``asyncio.wait`` returns once the timeout elapses even if some tasks
        # ignore or delay cancellation. ``wait_for(gather(...))`` would keep
        # waiting for them, which defeats the timeout.
        _, pending = await asyncio.wait(tasks_to_cancel, timeout=timeout)
        if pending:
            logger.warning(f"取消任务超时 ({timeout}秒),部分任务可能仍在运行")

        cancelled_count = len(tasks_to_cancel)
        logger.info(f"已取消 {cancelled_count} 个后台任务")

        return cancelled_count

    def get_running_tasks_count(self) -> int:
        """Return the number of tracked tasks that have not finished yet.

        Returns:
            int: Count of tracked tasks for which ``done()`` is false.
        """
        return sum(1 for t in self._tasks if not t.done())

    def get_total_tasks_count(self) -> int:
        """Return the number of tasks currently held in the registry.

        Finished tasks are removed by the done-callback, so this only includes
        a finished task until its callback has run.

        Returns:
            int: Size of the internal task set.
        """
        return len(self._tasks)

    def get_stats(self) -> dict:
        """Return a snapshot of task statistics.

        Returns:
            dict: ``total`` (tracked tasks), ``running`` (tracked and not
            done), ``completed`` (tracked but already done) and
            ``created_count`` (tasks created over the manager's lifetime).
        """
        total = len(self._tasks)
        running = sum(1 for t in self._tasks if not t.done())
        completed = total - running

        return {
            "total": total,
            "running": running,
            "completed": completed,
            "created_count": self._task_counter
        }

    async def cleanup_done_tasks(self) -> int:
        """Remove finished tasks from the registry.

        The done-callback already does this automatically; this method is a
        fallback that can be called periodically.

        Returns:
            int: Number of finished tasks that were removed.
        """
        finished = {t for t in self._tasks if t.done()}
        self._tasks.difference_update(finished)

        cleaned = len(finished)
        if cleaned > 0:
            logger.debug(f"清理了 {cleaned} 个已完成的任务")

        return cleaned
|
||
|
||
|
||
# Global singleton: the shared manager instance used by the helper functions
# in this module.
background_task_manager: BackgroundTaskManager = BackgroundTaskManager()
|
||
|
||
|
||
# Convenience functions
|
||
def create_background_task(
    coro,
    name: Optional[str] = None,
    context: Optional[str] = None
) -> asyncio.Task:
    """Create a background task through the module-level manager.

    Thin wrapper that forwards its arguments to
    ``background_task_manager.create_task``.

    Args:
        coro: Coroutine object to run.
        name: Optional task name.
        context: Optional task context string.

    Returns:
        asyncio.Task: The task returned by the manager.
    """
    spawn = background_task_manager.create_task
    task = spawn(coro, name=name, context=context)
    return task
|
||
|
||
|
||
async def cancel_all_background_tasks(timeout: float = 5.0) -> int:
    """Cancel all background tasks through the module-level manager.

    Thin wrapper that awaits ``background_task_manager.cancel_all``.

    Args:
        timeout: Timeout in seconds, passed through to the manager.

    Returns:
        int: The task count reported by the manager.
    """
    cancelled = await background_task_manager.cancel_all(timeout=timeout)
    return cancelled
|
||
|
||
|
||
def get_background_tasks_stats() -> dict:
    """Fetch background task statistics from the module-level manager.

    Thin wrapper around ``background_task_manager.get_stats``.

    Returns:
        dict: The statistics dictionary produced by the manager.
    """
    stats = background_task_manager.get_stats()
    return stats
|