#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Task persistence module.
Responsible for persisting and recovering the task queue.
"""

import asyncio
import logging
import json
import os
import time
from typing import Dict, List, Any, Optional, Set, Tuple
from datetime import datetime, timedelta
import aiofiles
from sqlalchemy import select, update, insert
from sqlalchemy.ext.asyncio import AsyncSession

from config.settings import settings
from data.models.taskrecord import VWEDTaskRecord
from data.session import get_async_session
from utils.logger import get_logger
from data.enum.task_record_enum import TaskStatus

# Module-level logger
logger = get_logger("services.enhanced_scheduler.task_persistence")


class TaskPersistenceManager:
    """
    Task persistence manager.
    Responsible for persisting and recovering the task queue.
    """

    def __init__(self,
                 backup_interval: int = 300,
                 backup_dir: Optional[str] = None,
                 max_backups: int = 5):
        """
        Initialize the task persistence manager.

        Args:
            backup_interval: Backup interval in seconds
            backup_dir: Backup directory
            max_backups: Maximum number of backups to keep
        """
        self.backup_interval = backup_interval
        self.backup_dir = backup_dir or os.path.join(settings.DATA_DIR, "task_backups")
        self.max_backups = max_backups

        # Make sure the backup directory exists
        os.makedirs(self.backup_dir, exist_ok=True)

        self.is_running = False
        self.backup_task = None
        self.last_backup_time = datetime.now()

        # Tasks pending persistence {task_id: task_info}
        self.pending_tasks = {}

        logger.info(f"Initialized task persistence manager: interval={backup_interval}s, dir={self.backup_dir}, max backups={max_backups}")

    async def start(self) -> None:
        """
        Start the task persistence manager.
        """
        if self.is_running:
            logger.warning("Task persistence manager is already running")
            return

        self.is_running = True

        # Clean up expired backup files on startup
        await self.clean_expired_backups(max_age_hours=24)

        # Start the periodic backup task
        self.backup_task = asyncio.create_task(self._backup_worker())

        logger.info("Task persistence manager started")

    async def stop(self) -> None:
        """
        Stop the task persistence manager.
        """
        if not self.is_running:
            logger.warning("Task persistence manager is not running")
            return

        self.is_running = False

        # Cancel the backup task
        if self.backup_task:
            self.backup_task.cancel()
            try:
                await self.backup_task
            except asyncio.CancelledError:
                pass
            self.backup_task = None

        # Perform one final backup
        await self._backup_pending_tasks()

        logger.info("Task persistence manager stopped")

    async def add_task(self, task_id: str, priority: int, task_info: Dict[str, Any]) -> None:
        """
        Add a task to the persistence queue.

        Args:
            task_id: Task ID
            priority: Task priority
            task_info: Task information
        """
        self.pending_tasks[task_id] = {
            "id": task_id,
            "priority": priority,
            "info": task_info,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    async def remove_task(self, task_id: str) -> bool:
        """
        Remove a task from the persistence queue.

        Args:
            task_id: Task ID

        Returns:
            bool: Whether the task was removed
        """
        if task_id in self.pending_tasks:
            self.pending_tasks.pop(task_id)
            return True
        return False

    async def load_pending_tasks(self) -> Dict[str, Dict[str, Any]]:
        """
        Load pending tasks.
        Recovers tasks that were left unfinished before a system crash.

        Returns:
            Dict[str, Dict[str, Any]]: Pending tasks keyed by task ID {task_id: task_info}
        """
        # Load pending tasks from the database
        result = {}

        try:
            async with get_async_session() as session:
                # Query all unfinished tasks (status not completed, failed, or canceled)
                query = select(VWEDTaskRecord).where(
                    VWEDTaskRecord.status.notin_([TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELED])
                )

                db_result = await session.execute(query)
                task_records = db_result.scalars().all()

                for task in task_records:
                    # Try to parse the input parameters
                    try:
                        input_params = json.loads(task.input_params) if task.input_params else {}
                    except Exception:
                        input_params = {}

                    # Build the task info
                    task_info = {
                        "id": task.id,
                        "def_id": task.def_id,
                        "def_label": task.def_label,
                        "def_version": task.def_version,
                        "status": task.status,
                        "created_on": task.created_at.strftime("%Y-%m-%d %H:%M:%S") if task.created_at else None,
                        "priority": task.priority or 1,
                        "periodic_task": task.periodic_task,
                        "input_params": input_params
                    }

                    # Add to the result
                    result[task.id] = {
                        "id": task.id,
                        "priority": task.priority or 1,
                        "info": task_info,
                        "timestamp": task.created_at.strftime("%Y-%m-%d %H:%M:%S") if task.created_at else datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        "from_database": True  # Mark the record as coming from the database
                    }

            # If backup files exist, try to merge the latest backup information
            if result:
                await self._merge_latest_backup(result)
            logger.info(f"Loaded {len(result)} pending tasks")

            return result

        except Exception as e:
            logger.error(f"Failed to load pending tasks: {str(e)}")
            return {}

    async def _merge_latest_backup(self, result: Dict[str, Dict[str, Any]]) -> None:
        """
        Merge the latest backup information.

        Args:
            result: Existing task dict to merge into
        """
        try:
            backup_files = []

            # Collect all backup files
            for file in os.listdir(self.backup_dir):
                if file.startswith("tasks_backup_") and file.endswith(".json"):
                    backup_files.append(os.path.join(self.backup_dir, file))

            if not backup_files:
                return

            # Sort by modification time (newest first)
            backup_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)

            # Load the latest backup
            latest_backup = backup_files[0]

            # Get the backup file's modification time (used to decide whether the backup is stale)
            backup_time = datetime.fromtimestamp(os.path.getmtime(latest_backup))
            current_time = datetime.now()
            time_difference = (current_time - backup_time).total_seconds()

            # If the backup file is older than 12 hours, do not use it to recover tasks
            if time_difference > 12 * 3600:
                logger.warning(f"Backup file {latest_backup} is stale ({time_difference/3600:.1f} hours old), not using it to recover tasks")
                return

            async with aiofiles.open(latest_backup, "r", encoding="utf-8") as f:
                content = await f.read()
                backup_tasks = json.loads(content)

            # Merge tasks
            for task_id, task_info in backup_tasks.items():
                # Do not overwrite tasks that already came from the database
                if task_id in result and result[task_id].get("from_database", False):
                    continue

                # Check the task status
                status = task_info.get("info", {}).get("status")

                # Only load tasks whose status is 1001 (running) or 1002 (queued)
                if status in [TaskStatus.RUNNING, TaskStatus.QUEUED]:
                    result[task_id] = task_info
                    result[task_id]["from_backup"] = True  # Mark the record as coming from a backup
                else:
                    logger.debug(f"Skipped task {task_id} with status {status} from backup")

            logger.info(f"Loaded tasks from backup {latest_backup}; {len(result)} tasks after merge")

        except Exception as e:
            logger.error(f"Failed to merge backup information: {str(e)}")

    async def _load_latest_backup(self, result: Dict[str, Dict[str, Any]]) -> None:
        """
        [Deprecated] Load the latest backup.
        Use _merge_latest_backup instead.

        Args:
            result: Existing task dict to merge into
        """
        # Delegate to the new method
        await self._merge_latest_backup(result)

    async def clean_expired_backups(self, max_age_hours: int = 24) -> int:
        """
        Clean up expired backup files.

        Args:
            max_age_hours: Maximum retention time in hours

        Returns:
            int: Number of files removed
        """
        try:
            backup_files = []
            current_time = datetime.now()
            cleaned_count = 0

            # Collect all backup files
            for file in os.listdir(self.backup_dir):
                if file.startswith("tasks_backup_") and file.endswith(".json"):
                    file_path = os.path.join(self.backup_dir, file)
                    backup_files.append(file_path)

            # Sort by modification time (oldest first)
            backup_files.sort(key=lambda x: os.path.getmtime(x))

            # Keep the newest max_backups backups; the rest are candidates for age-based cleanup
            files_to_preserve = backup_files[-self.max_backups:] if len(backup_files) > self.max_backups else backup_files

            for file_path in backup_files:
                # Skip files that must be preserved
                if file_path in files_to_preserve:
                    continue

                # Get the file's modification time
                file_time = datetime.fromtimestamp(os.path.getmtime(file_path))
                age_hours = (current_time - file_time).total_seconds() / 3600

                # Delete the file if it exceeds the maximum retention time
                if age_hours > max_age_hours:
                    os.remove(file_path)
                    cleaned_count += 1
                    logger.debug(f"Removed expired backup file: {file_path}, modified: {file_time}, age: {age_hours:.1f} hours")

            if cleaned_count > 0:
                logger.info(f"Removed {cleaned_count} expired backup files, maximum retention: {max_age_hours} hours")

            return cleaned_count

        except Exception as e:
            logger.error(f"Failed to clean expired backup files: {str(e)}")
            return 0

    async def _backup_worker(self) -> None:
        """
        Backup worker.
        Periodically backs up the task queue.
        """
        logger.info("Backup worker started")

        while self.is_running:
            try:
                # Check whether a backup is due
                now = datetime.now()
                if (now - self.last_backup_time).total_seconds() > self.backup_interval:
                    await self._backup_pending_tasks()
                    self.last_backup_time = now

                # Sleep for a while before checking again
                await asyncio.sleep(self.backup_interval / 4)

            except asyncio.CancelledError:
                # Task was cancelled, exit the loop
                logger.info("Backup worker cancelled")
                break
            except Exception as e:
                logger.error(f"Backup worker error: {str(e)}")
                # Sleep briefly after an error to avoid a tight error loop
                await asyncio.sleep(5.0)

        logger.info("Backup worker stopped")

    async def _backup_pending_tasks(self) -> None:
        """
        Back up the pending tasks.
        """
        if not self.pending_tasks:
            return

        try:
            # Build the backup file name
            backup_file = os.path.join(
                self.backup_dir,
                f"tasks_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            )

            # Write the backup
            async with aiofiles.open(backup_file, "w", encoding="utf-8") as f:
                await f.write(json.dumps(self.pending_tasks, ensure_ascii=False, indent=2))

            logger.info(f"Backed up {len(self.pending_tasks)} tasks to {backup_file}")

            # Clean up old backups
            await self._cleanup_old_backups()

        except Exception as e:
            logger.error(f"Failed to back up tasks: {str(e)}")

    async def _cleanup_old_backups(self) -> None:
        """
        Clean up old backups.
        Keeps only the newest max_backups backups.
        """
        try:
            backup_files = []

            # Collect all backup files
            for file in os.listdir(self.backup_dir):
                if file.startswith("tasks_backup_") and file.endswith(".json"):
                    backup_files.append(os.path.join(self.backup_dir, file))

            # Sort by modification time (newest first)
            backup_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)

            # Remove the surplus backups
            for file in backup_files[self.max_backups:]:
                os.remove(file)
                logger.debug(f"Removed old backup: {file}")

        except Exception as e:
            logger.error(f"Failed to clean up old backups: {str(e)}")

    def get_backup_status(self) -> Dict[str, Any]:
        """
        Get backup status information.

        Returns:
            Dict[str, Any]: Backup status
        """
        return {
            "pending_tasks": len(self.pending_tasks),
            "backup_interval": self.backup_interval,
            "backup_dir": self.backup_dir,
            "max_backups": self.max_backups,
            "last_backup": self.last_backup_time.strftime("%Y-%m-%d %H:%M:%S")
        }
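

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how a scheduler might drive TaskPersistenceManager.
# The backup directory, interval, and task payload below are made-up values;
# load_pending_tasks() is left commented out because it needs the VWEDTaskRecord
# table and a configured database session.
if __name__ == "__main__":
    async def _demo() -> None:
        manager = TaskPersistenceManager(
            backup_interval=60,
            backup_dir="/tmp/task_backups",  # illustrative path
            max_backups=3,
        )
        await manager.start()

        # Register an in-memory task so the next backup cycle (or stop()) persists it
        await manager.add_task("demo-task-1", priority=5, task_info={"demo": True})

        # After a restart, unfinished tasks would be recovered like this:
        # pending = await manager.load_pending_tasks()

        print(manager.get_backup_status())
        await manager.stop()

    asyncio.run(_demo())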