220 lines
7.4 KiB
Python
220 lines
7.4 KiB
Python
# app.py
|
||
from fastapi import FastAPI, Request, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
import logging
|
||
import time
|
||
import traceback
|
||
from utils.logger import setup_logger
|
||
from config.component_config import register_all_components
|
||
from config.settings import (
|
||
AppConfig, ServerConfig, ApiConfig, CorsConfig
|
||
)
|
||
from api.task_api import router as task_router
|
||
from api.workflow_api import router as workflow_router
|
||
from api.component_api import router as component_router
|
||
from api.task_param_api import router as task_params_router
|
||
from core.exceptions import TianfengTaskError
|
||
from config.api_config import ApiResponseCode, ApiResponseMessage
|
||
from config.component_config import ComponentCategoryConfig
|
||
from api.task_instance_api import router as task_instance_router
|
||
# 导入数据库相关模块
|
||
from config.database import DBConfig, CacheConfig, db_session
|
||
import data.models # 导入所有模型以确保它们被注册
|
||
from data.models.component import ComponentCategory, ComponentType, Component, ComponentCategoryEnum
|
||
# 导入Lifespan
|
||
from contextlib import asynccontextmanager
|
||
|
||
# 设置日志
|
||
logger = setup_logger()
|
||
|
||
# 定义Lifespan上下文管理器
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
"""
|
||
应用生命周期管理
|
||
在应用启动时执行初始化操作,在应用关闭时执行清理操作
|
||
"""
|
||
# 应用启动时执行
|
||
logger.info("应用启动")
|
||
|
||
yield # 应用运行期间
|
||
|
||
# 应用关闭时执行
|
||
logger.info("应用关闭")
|
||
DBConfig.shutdown_session()
|
||
|
||
# 创建FastAPI应用,使用Lifespan
|
||
app = FastAPI(
|
||
title=ApiConfig.TITLE,
|
||
description=ApiConfig.DESCRIPTION,
|
||
version=ApiConfig.VERSION,
|
||
lifespan=lifespan
|
||
)
|
||
|
||
# 添加CORS中间件
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=CorsConfig.ALLOW_ORIGINS,
|
||
allow_credentials=CorsConfig.ALLOW_CREDENTIALS,
|
||
allow_methods=CorsConfig.ALLOW_METHODS,
|
||
allow_headers=CorsConfig.ALLOW_HEADERS,
|
||
)
|
||
|
||
# 初始化数据库
|
||
def init_database():
|
||
"""初始化数据库,创建所有表"""
|
||
try:
|
||
logger.info("开始初始化数据库...")
|
||
|
||
# 初始化数据库表
|
||
logger.info("开始创建数据库表...")
|
||
try:
|
||
DBConfig.init_db()
|
||
logger.info("数据库表创建成功")
|
||
except Exception as table_err:
|
||
logger.error(f"数据库表创建失败: {str(table_err)}")
|
||
# 打印详细错误信息和堆栈跟踪
|
||
logger.error(traceback.format_exc())
|
||
raise
|
||
|
||
# 初始化基础数据
|
||
logger.info("开始初始化基础数据...")
|
||
try:
|
||
init_base_data(db_session)
|
||
logger.info("基础数据初始化成功")
|
||
except Exception as data_err:
|
||
logger.error(f"基础数据初始化失败: {str(data_err)}")
|
||
# 打印详细错误信息和堆栈跟踪
|
||
logger.error(traceback.format_exc())
|
||
raise
|
||
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"数据库初始化失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
# 继续执行程序,但记录错误
|
||
return False
|
||
|
||
def init_base_data(db_session):
|
||
"""初始化基础数据"""
|
||
try:
|
||
# 导入需要的模型
|
||
logger.info("检查基础数据...")
|
||
|
||
# 检查是否已存在组件分类
|
||
existing_categories = db_session.query(ComponentCategory).filter(ComponentCategory.is_deleted == False).all()
|
||
if existing_categories:
|
||
logger.info(f"已存在 {len(existing_categories)} 个组件分类,跳过初始化")
|
||
return
|
||
|
||
logger.info("开始创建组件分类...")
|
||
# 创建组件分类
|
||
categories = []
|
||
for category_enum in ComponentCategoryEnum:
|
||
logger.info(f"创建组件分类: {category_enum.value}")
|
||
category = ComponentCategory(
|
||
name=ComponentCategoryConfig.get_category_name(category_enum),
|
||
code=category_enum,
|
||
description=ComponentCategoryConfig.get_category_description(category_enum),
|
||
icon=f"icon-{category_enum.value}",
|
||
order=ComponentCategoryConfig.get_category_order(category_enum)
|
||
)
|
||
categories.append(category)
|
||
db_session.add(category)
|
||
|
||
logger.info("提交组件分类到数据库...")
|
||
db_session.commit()
|
||
logger.info(f"创建了 {len(categories)} 个组件分类")
|
||
|
||
# 这里可以添加更多基础数据的初始化,如组件类型、系统组件等
|
||
|
||
except Exception as e:
|
||
logger.error(f"基础数据初始化失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
db_session.rollback()
|
||
raise
|
||
|
||
# 初始化数据库
|
||
init_result = init_database()
|
||
if not init_result:
|
||
logger.warning("数据库初始化失败,但程序将继续执行。请检查日志获取详细错误信息。")
|
||
|
||
# 注册所有组件
|
||
try:
|
||
register_all_components()
|
||
except Exception as e:
|
||
logger.error(f"组件注册失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
|
||
# 注册API路由
|
||
app.include_router(task_router, prefix=ApiConfig.PREFIX)
|
||
app.include_router(workflow_router, prefix=ApiConfig.PREFIX)
|
||
app.include_router(component_router, prefix=ApiConfig.PREFIX)
|
||
app.include_router(task_instance_router, prefix=ApiConfig.PREFIX)
|
||
app.include_router(task_params_router, prefix=ApiConfig.PREFIX)
|
||
|
||
|
||
# 请求中间件
|
||
@app.middleware("http")
|
||
async def add_process_time_header(request: Request, call_next):
|
||
start_time = time.time()
|
||
response = await call_next(request)
|
||
process_time = time.time() - start_time
|
||
response.headers["X-Process-Time"] = str(process_time)
|
||
return response
|
||
|
||
# 数据库会话中间件
|
||
@app.middleware("http")
|
||
async def db_session_middleware(request: Request, call_next):
|
||
try:
|
||
response = await call_next(request)
|
||
return response
|
||
finally:
|
||
db_session.remove()
|
||
|
||
# 全局异常处理
|
||
@app.exception_handler(TianfengTaskError)
|
||
async def tianfeng_task_error_handler(request: Request, exc: TianfengTaskError):
|
||
"""处理天风任务模块异常"""
|
||
return JSONResponse(
|
||
status_code=ApiResponseCode.BAD_REQUEST,
|
||
content={
|
||
"code": ApiResponseCode.BAD_REQUEST,
|
||
"message": str(exc),
|
||
"data": None
|
||
}
|
||
)
|
||
|
||
@app.exception_handler(Exception)
|
||
async def global_exception_handler(request: Request, exc: Exception):
|
||
"""处理通用异常"""
|
||
logger.exception("未处理的异常")
|
||
return JSONResponse(
|
||
status_code=ApiResponseCode.SERVER_ERROR,
|
||
content={
|
||
"code": ApiResponseCode.SERVER_ERROR,
|
||
"message": f"{ApiResponseMessage.SERVER_ERROR}: {str(exc)}",
|
||
"data": None
|
||
}
|
||
)
|
||
|
||
# 健康检查接口
|
||
@app.get(f"{ApiConfig.PREFIX}/health")
|
||
async def health_check():
|
||
"""健康检查"""
|
||
return {
|
||
"code": ApiResponseCode.SUCCESS,
|
||
"message": "服务正常",
|
||
"data": {
|
||
"app_name": AppConfig.NAME,
|
||
"app_version": AppConfig.VERSION,
|
||
"status": "healthy"
|
||
}
|
||
}
|
||
|
||
# 主函数
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
logger.info(f"启动天风任务模块服务,调试模式: {AppConfig.DEBUG}")
|
||
uvicorn.run("app:app", host=ServerConfig.HOST, port=ServerConfig.PORT, reload=ServerConfig.RELOAD) |