245 lines
8.0 KiB
Python
245 lines
8.0 KiB
Python
# app.py
|
||
from fastapi import FastAPI, Request, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.exceptions import RequestValidationError
|
||
import logging
|
||
import time
|
||
import traceback
|
||
from utils.logger import get_logger
|
||
from contextlib import asynccontextmanager
|
||
import uvicorn
|
||
import os
|
||
# 导入配置
|
||
from config.settings import settings
|
||
from config.error_messages import VALIDATION_ERROR_MESSAGES, HTTP_ERROR_MESSAGES
|
||
# 导入数据库相关
|
||
from data.session import init_database, close_database_connections, close_async_database_connections
|
||
from data.cache import redis_client
|
||
# 引入路由
|
||
from routes.database import router as db_router
|
||
from routes.template_api import router as template_router
|
||
from routes.task_api import router as task_router
|
||
from routes.common_api import router as common_router, format_response
|
||
from routes.task_edit_api import router as task_edit_router
|
||
from routes.script_api import router as script_router
|
||
from routes.task_record_api import router as task_record_router
|
||
# 设置日志
|
||
logger = get_logger("app")
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
"""
|
||
应用程序生命周期管理
|
||
启动时初始化数据库和任务调度器,关闭时清理资源
|
||
"""
|
||
# 启动前的初始化操作
|
||
# 初始化数据库
|
||
init_database()
|
||
# 初始化Redis连接
|
||
if redis_client.get_client() is None:
|
||
logger.warning("Redis连接失败,部分功能可能无法正常使用")
|
||
|
||
# 启动增强版任务调度器
|
||
from services.enhanced_scheduler import scheduler
|
||
await scheduler.start(worker_count=settings.TASK_SCHEDULER_MIN_WORKER_COUNT)
|
||
logger.info(f"增强版任务调度器已启动,最小工作线程数: {settings.TASK_SCHEDULER_MIN_WORKER_COUNT},最大工作线程数: {settings.TASK_SCHEDULER_MAX_WORKER_COUNT}")
|
||
|
||
yield
|
||
|
||
# 应用程序关闭前的清理操作
|
||
logger.info("应用程序关闭中...")
|
||
|
||
# 停止增强版任务调度器
|
||
from services.enhanced_scheduler import scheduler
|
||
await scheduler.stop()
|
||
logger.info("增强版任务调度器已停止")
|
||
|
||
await close_async_database_connections() # 关闭异步数据库连接
|
||
close_database_connections() # 关闭同步数据库连接
|
||
|
||
# 创建FastAPI应用
|
||
app = FastAPI(
|
||
title=settings.APP_NAME,
|
||
description=settings.APP_DESCRIPTION,
|
||
version=settings.APP_VERSION,
|
||
lifespan=lifespan,
|
||
debug=settings.DEBUG
|
||
)
|
||
|
||
# 添加CORS中间件
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=settings.CORS_ORIGINS,
|
||
allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
|
||
allow_methods=settings.CORS_ALLOW_METHODS,
|
||
allow_headers=settings.CORS_ALLOW_HEADERS,
|
||
)
|
||
|
||
# 请求日志中间件
|
||
@app.middleware("http")
|
||
async def log_requests(request: Request, call_next):
|
||
"""记录请求日志的中间件"""
|
||
start_time = time.time()
|
||
|
||
# 获取请求信息
|
||
method = request.method
|
||
url = request.url.path
|
||
client_host = request.client.host if request.client else "unknown"
|
||
|
||
# 记录请求
|
||
logger.info(f"请求开始: {method} {url} 来自 {client_host}")
|
||
|
||
try:
|
||
# 处理请求
|
||
response = await call_next(request)
|
||
|
||
# 计算处理时间
|
||
process_time = time.time() - start_time
|
||
logger.info(f"请求完成: {method} {url} 状态码: {response.status_code} 耗时: {process_time:.4f}秒")
|
||
|
||
return response
|
||
except Exception as e:
|
||
# 记录异常
|
||
process_time = time.time() - start_time
|
||
logger.error(f"请求异常: {method} {url} 耗时: {process_time:.4f}秒")
|
||
logger.error(f"异常详情: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
|
||
# 返回通用错误响应
|
||
return JSONResponse(
|
||
status_code=500,
|
||
content=format_response(
|
||
code=500,
|
||
message="服务器内部错误,请联系管理员",
|
||
data=None
|
||
)
|
||
)
|
||
|
||
# 全局验证错误处理器
|
||
@app.exception_handler(RequestValidationError)
|
||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||
"""
|
||
处理验证错误,将错误消息转换为中文,并提供更友好的错误提示
|
||
包括显示具体缺失的字段名称
|
||
"""
|
||
errors = exc.errors()
|
||
error_details = []
|
||
missing_fields = []
|
||
|
||
for error in errors:
|
||
error_type = error.get("type", "")
|
||
loc = error.get("loc", [])
|
||
|
||
# 获取完整的字段路径,排除body/query等
|
||
if len(loc) > 1 and loc[0] in ["body", "query", "path", "header"]:
|
||
field_path = ".".join(str(item) for item in loc[1:])
|
||
else:
|
||
field_path = ".".join(str(item) for item in loc)
|
||
|
||
# 获取中文错误消息
|
||
message = VALIDATION_ERROR_MESSAGES.get(error_type, error.get("msg", "验证错误"))
|
||
|
||
# 替换消息中的参数
|
||
context = error.get("ctx", {})
|
||
for key, value in context.items():
|
||
message = message.replace(f"{{{key}}}", str(value))
|
||
|
||
# 收集缺失字段
|
||
if error_type == "missing" or error_type == "value_error.missing":
|
||
missing_fields.append(field_path)
|
||
|
||
error_details.append({
|
||
"field": field_path,
|
||
"message": message,
|
||
"type": error_type
|
||
})
|
||
|
||
# 构建友好的错误响应
|
||
if missing_fields:
|
||
missing_fields_str = ", ".join(missing_fields)
|
||
error_message = f"缺少必填字段: {missing_fields_str}"
|
||
elif error_details:
|
||
# 提取第一个错误的字段和消息
|
||
first_error = error_details[0]
|
||
error_message = f"参数 '{first_error['field']}' 验证失败: {first_error['message']}"
|
||
else:
|
||
error_message = "参数验证失败"
|
||
|
||
return JSONResponse(
|
||
status_code=400,
|
||
content={
|
||
"code": 400,
|
||
"message": error_message,
|
||
"data": error_details if len(error_details) > 1 else None
|
||
}
|
||
)
|
||
|
||
# HTTP错误处理器
|
||
@app.exception_handler(HTTPException)
|
||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||
"""处理HTTP异常,转换为统一的响应格式"""
|
||
status_code = exc.status_code
|
||
# 获取错误消息,优先使用自定义消息,否则使用配置中的错误消息
|
||
message = exc.detail
|
||
if isinstance(message, str) and message == "Not Found":
|
||
message = HTTP_ERROR_MESSAGES.get(status_code, message)
|
||
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=format_response(
|
||
code=status_code,
|
||
message=message,
|
||
data=None
|
||
)
|
||
)
|
||
|
||
# 全局异常处理器
|
||
@app.exception_handler(Exception)
|
||
async def global_exception_handler(request: Request, exc: Exception):
|
||
"""处理所有未捕获的异常"""
|
||
logger.error(f"未捕获异常: {str(exc)}")
|
||
logger.error(traceback.format_exc())
|
||
|
||
return JSONResponse(
|
||
status_code=500,
|
||
content=format_response(
|
||
code=500,
|
||
message="服务器内部错误,请联系管理员",
|
||
data=None if not settings.DEBUG else str(exc)
|
||
)
|
||
)
|
||
|
||
# 注册路由
|
||
app.include_router(common_router)
|
||
app.include_router(db_router)
|
||
app.include_router(template_router)
|
||
app.include_router(task_router)
|
||
app.include_router(task_edit_router)
|
||
app.include_router(script_router)
|
||
app.include_router(task_record_router)
|
||
# 根路由
|
||
@app.get("/")
|
||
async def root():
|
||
"""API根路由,显示系统基本信息"""
|
||
return {
|
||
"app_name": settings.APP_NAME,
|
||
"version": settings.APP_VERSION,
|
||
"description": settings.APP_DESCRIPTION,
|
||
"status": "running"
|
||
}
|
||
|
||
# 主函数
|
||
if __name__ == "__main__":
|
||
# 从环境变量中获取端口,默认为8000
|
||
port = int(os.environ.get("PORT", settings.SERVER_PORT))
|
||
|
||
# 启动服务器
|
||
uvicorn.run(
|
||
"app:app",
|
||
host="0.0.0.0",
|
||
port=port,
|
||
reload=settings.DEBUG,
|
||
workers=settings.SERVER_WORKERS
|
||
)
|