VWED_server/utils/background_task_manager.py

241 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
后台任务管理器
统一管理系统中的所有 asyncio.create_task 创建的后台任务
防止内存泄漏和异常被静默吞噬
"""
import asyncio
import weakref
from typing import Set, Optional, Callable, Any
from utils.logger import get_logger
logger = get_logger("utils.background_task_manager")
class BackgroundTaskManager:
"""
后台任务管理器
功能:
1. 统一管理所有后台任务
2. 自动记录任务异常
3. 支持批量取消任务
4. 使用WeakSet自动清理已完成任务
"""
def __init__(self):
"""初始化任务管理器"""
# 使用普通set存储任务引用通过done_callback清理
self._tasks: Set[asyncio.Task] = set()
self._task_counter = 0
def create_task(
self,
coro,
name: Optional[str] = None,
context: Optional[str] = None
) -> asyncio.Task:
"""
创建并管理后台任务
Args:
coro: 协程对象
name: 任务名称(可选)
context: 任务上下文信息(用于日志)
Returns:
asyncio.Task: 创建的任务对象
"""
self._task_counter += 1
# 如果没有提供名称,使用自动编号
if name is None:
name = f"background_task_{self._task_counter}"
# 创建任务
task = asyncio.create_task(coro, name=name)
# 添加完成回调
task.add_done_callback(lambda t: self._task_done_callback(t, context))
# 添加到任务集合
self._tasks.add(task)
logger.debug(f"创建后台任务: {name}, 当前任务数: {len(self._tasks)}")
return task
def _task_done_callback(self, task: asyncio.Task, context: Optional[str] = None):
"""
任务完成回调
Args:
task: 完成的任务
context: 任务上下文信息
"""
task_name = task.get_name()
# 从任务集合中移除
self._tasks.discard(task)
try:
# 尝试获取任务结果,如果有异常会抛出
exception = task.exception()
if exception:
context_info = f" (context: {context})" if context else ""
logger.error(
f"后台任务异常: {task_name}{context_info}, "
f"异常类型: {type(exception).__name__}, "
f"异常信息: {str(exception)}"
)
else:
logger.debug(f"后台任务完成: {task_name}")
except asyncio.CancelledError:
logger.debug(f"后台任务被取消: {task_name}")
except Exception as e:
logger.error(f"获取任务结果失败: {task_name}, 错误: {str(e)}")
async def cancel_all(self, timeout: float = 5.0) -> int:
"""
取消所有运行中的任务
Args:
timeout: 等待任务取消的超时时间(秒)
Returns:
int: 被取消的任务数量
"""
if not self._tasks:
logger.debug("没有需要取消的后台任务")
return 0
# 获取所有未完成的任务
tasks_to_cancel = [t for t in self._tasks if not t.done()]
if not tasks_to_cancel:
logger.debug("所有后台任务已完成")
return 0
logger.info(f"开始取消 {len(tasks_to_cancel)} 个后台任务")
# 取消所有任务
for task in tasks_to_cancel:
task.cancel()
# 等待所有任务完成(带超时)
try:
await asyncio.wait_for(
asyncio.gather(*tasks_to_cancel, return_exceptions=True),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(f"取消任务超时 ({timeout}秒),部分任务可能仍在运行")
cancelled_count = len(tasks_to_cancel)
logger.info(f"已取消 {cancelled_count} 个后台任务")
return cancelled_count
def get_running_tasks_count(self) -> int:
"""
获取运行中的任务数量
Returns:
int: 运行中的任务数量
"""
return len([t for t in self._tasks if not t.done()])
def get_total_tasks_count(self) -> int:
"""
获取总任务数量(包括已完成的)
Returns:
int: 总任务数量
"""
return len(self._tasks)
def get_stats(self) -> dict:
"""
获取任务统计信息
Returns:
dict: 统计信息
"""
total = len(self._tasks)
running = len([t for t in self._tasks if not t.done()])
completed = total - running
return {
"total": total,
"running": running,
"completed": completed,
"created_count": self._task_counter
}
async def cleanup_done_tasks(self):
"""
清理已完成的任务(可选的定期清理方法)
注意:通过 done_callback 已经自动清理,此方法作为备用
"""
before_count = len(self._tasks)
self._tasks = {t for t in self._tasks if not t.done()}
after_count = len(self._tasks)
cleaned = before_count - after_count
if cleaned > 0:
logger.debug(f"清理了 {cleaned} 个已完成的任务")
return cleaned
# 全局单例
background_task_manager = BackgroundTaskManager()
# 便捷函数
def create_background_task(
coro,
name: Optional[str] = None,
context: Optional[str] = None
) -> asyncio.Task:
"""
创建后台任务的便捷函数
Args:
coro: 协程对象
name: 任务名称
context: 任务上下文
Returns:
asyncio.Task: 创建的任务
"""
return background_task_manager.create_task(coro, name=name, context=context)
async def cancel_all_background_tasks(timeout: float = 5.0) -> int:
"""
取消所有后台任务的便捷函数
Args:
timeout: 超时时间(秒)
Returns:
int: 被取消的任务数量
"""
return await background_task_manager.cancel_all(timeout=timeout)
def get_background_tasks_stats() -> dict:
"""
获取后台任务统计信息的便捷函数
Returns:
dict: 统计信息
"""
return background_task_manager.get_stats()