111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
数据库管理路由
|
|
提供数据库状态检查和手动初始化接口
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import text
|
|
from typing import Dict, Any, List, Optional
|
|
import traceback
|
|
import os
|
|
|
|
from config.database_config import Base
|
|
from data.session import get_session, init_database, engine
|
|
from config.settings import settings
|
|
|
|
def _require_debug_mode() -> None:
    """Reject the request with HTTP 403 unless ``settings.DEBUG`` is enabled.

    Used as a router-level dependency so every route in this module is a
    development-only tool. FastAPI only aborts a request when a dependency
    *raises*; simply returning an ``HTTPException`` instance (as a lambda
    would) has no effect, which is why this is a real function.

    Raises:
        HTTPException: 403 when the application is not in debug mode.
    """
    if not settings.DEBUG:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="此功能仅在开发环境可用",
        )


router = APIRouter(
    prefix="/db",
    tags=["database"],
    responses={404: {"description": "未找到"}},
    # The guard checks settings.DEBUG per request, so it is always attached.
    dependencies=[Depends(_require_debug_mode)],
)
|
|
|
|
|
|
class DatabaseStatus(BaseModel):
    """Response schema for the database status endpoint."""

    # "ok" when the connectivity probe succeeded, "error" otherwise
    status: str
    # Whether the probe query could be executed against the database
    connection: bool
    # Table names declared on the ORM metadata (empty list on error)
    tables: List[str]
    # Configured database name
    db_name: str
    # Application environment name
    app_env: str
    # Error description; None unless something went wrong
    message: Optional[str] = None
|
|
|
|
|
|
@router.get("/status", response_model=DatabaseStatus)
async def get_database_status():
    """
    Report database connectivity and the table names declared on ``Base``.

    Connectivity is probed by executing ``SELECT 1`` through a session from
    ``get_session()``. The ``tables`` list is taken from
    ``Base.metadata.sorted_tables``, i.e. the tables the ORM models declare;
    it is not a live introspection of what exists in the database.

    Database errors are not raised: any exception is reported in the
    response body with ``status="error"``, ``connection=False`` and the
    exception text in ``message``.
    """
    # Fields shared by the success and the error response.
    common_fields = {
        "db_name": settings.DB_NAME,
        "app_env": os.getenv("APP_ENV", "development"),
    }

    # NOTE(review): the session calls below are synchronous and run on the
    # event loop inside this async handler; consider a threadpool if this
    # endpoint is ever called under load.
    session = None
    try:
        session = get_session()
        session.execute(text("SELECT 1"))

        table_names = [table.name for table in Base.metadata.sorted_tables]

        return {
            "status": "ok",
            "connection": True,
            "tables": table_names,
            **common_fields,
        }
    except Exception as e:
        return {
            "status": "error",
            "connection": False,
            "tables": [],
            **common_fields,
            "message": str(e),
        }
    finally:
        # Only close a session that was actually obtained.
        if session is not None:
            session.close()
|
|
|
|
|
|
@router.post("/init", response_model=Dict[str, Any])
async def reinit_database():
    """
    Manually (re)initialise the database table structure.

    Typical uses:
    1. Creating the tables on first deployment.
    2. Rebuilding the tables after model changes.

    Only allowed when ``settings.DEBUG`` is enabled.

    Returns:
        A dict with ``status``, ``message``, the table names declared on
        ``Base.metadata``, ``db_name`` and ``app_env``.

    Raises:
        HTTPException: 403 outside debug mode. 500 when ``init_database()``
            returns a falsy result, or when anything else fails; in that case
            the detail includes the traceback.
    """
    if not settings.DEBUG:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="此功能仅在开发环境可用",
        )

    try:
        # Delegate to the initialisation routine in data.session.
        result = init_database()

        if not result:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="数据库初始化失败,请查看日志获取详细信息",
            )

        return {
            "status": "success",
            "message": "数据库表初始化成功",
            "tables": [table.name for table in Base.metadata.sorted_tables],
            "db_name": settings.DB_NAME,
            "app_env": os.getenv("APP_ENV", "development"),
        }
    except HTTPException:
        # Our own HTTP errors must reach the client unchanged; without this
        # they would be caught below and re-wrapped with a traceback.
        raise
    except Exception as e:
        # Dev-only endpoint: include the traceback to ease debugging.
        error_detail = traceback.format_exc()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库初始化失败: {e}\n{error_detail}",
        ) from e