179 lines
4.4 KiB
Python
179 lines
4.4 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
数据库会话管理模块
|
|||
|
提供数据库会话获取和管理功能
|
|||
|
包含数据库引擎创建、会话获取、连接池管理等功能
|
|||
|
"""
|
|||
|
|
|||
|
import traceback
|
|||
|
import logging
|
|||
|
from contextlib import contextmanager, asynccontextmanager
|
|||
|
from sqlalchemy import create_engine, text
|
|||
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
|||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|||
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|||
|
|
|||
|
from config.database_config import Base, DBConfig
|
|||
|
from config.settings import settings
|
|||
|
|
|||
|
# 获取日志记录器
|
|||
|
from utils.logger import get_logger
|
|||
|
logger = get_logger("data.session")
|
|||
|
|
|||
|
# 创建数据库引擎
|
|||
|
engine = create_engine(
|
|||
|
DBConfig.DATABASE_URL,
|
|||
|
**DBConfig.DATABASE_ARGS
|
|||
|
)
|
|||
|
|
|||
|
# 创建会话工厂
|
|||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|||
|
|
|||
|
# 创建 scoped_session,确保线程安全
|
|||
|
db_session = scoped_session(SessionLocal)
|
|||
|
|
|||
|
# 创建异步数据库引擎
|
|||
|
# 将mysql+pymysql替换为mysql+aiomysql以支持异步
|
|||
|
async_engine = create_async_engine(
|
|||
|
DBConfig.DATABASE_URL.replace('mysql+pymysql', 'mysql+aiomysql'),
|
|||
|
**DBConfig.ASYNC_DATABASE_ARGS
|
|||
|
)
|
|||
|
|
|||
|
# 创建异步会话工厂
|
|||
|
AsyncSessionLocal = async_sessionmaker(
|
|||
|
autocommit=False,
|
|||
|
autoflush=False,
|
|||
|
expire_on_commit=False,
|
|||
|
class_=AsyncSession,
|
|||
|
bind=async_engine
|
|||
|
)
|
|||
|
|
|||
|
|
|||
|
def get_session():
|
|||
|
"""
|
|||
|
获取数据库会话
|
|||
|
|
|||
|
Returns:
|
|||
|
Session: 数据库会话对象
|
|||
|
"""
|
|||
|
return db_session()
|
|||
|
|
|||
|
|
|||
|
def get_db():
|
|||
|
"""
|
|||
|
获取数据库会话生成器,用于FastAPI依赖注入
|
|||
|
|
|||
|
Yields:
|
|||
|
Session: 数据库会话对象
|
|||
|
"""
|
|||
|
db = get_session()
|
|||
|
try:
|
|||
|
yield db
|
|||
|
finally:
|
|||
|
db.close()
|
|||
|
|
|||
|
|
|||
|
@asynccontextmanager
|
|||
|
async def get_async_session():
|
|||
|
"""
|
|||
|
获取异步数据库会话,使用异步上下文管理器
|
|||
|
|
|||
|
Yields:
|
|||
|
AsyncSession: 异步数据库会话对象
|
|||
|
"""
|
|||
|
session = AsyncSessionLocal()
|
|||
|
try:
|
|||
|
yield session
|
|||
|
await session.commit()
|
|||
|
except Exception as e:
|
|||
|
await session.rollback()
|
|||
|
raise e
|
|||
|
finally:
|
|||
|
await session.close()
|
|||
|
|
|||
|
|
|||
|
async def get_async_db():
|
|||
|
"""
|
|||
|
获取异步数据库会话生成器,用于FastAPI依赖注入
|
|||
|
|
|||
|
Yields:
|
|||
|
AsyncSession: 异步数据库会话对象
|
|||
|
"""
|
|||
|
async with AsyncSessionLocal() as session:
|
|||
|
yield session
|
|||
|
|
|||
|
|
|||
|
@contextmanager
|
|||
|
def session_scope():
|
|||
|
"""
|
|||
|
会话上下文管理器
|
|||
|
提供自动提交和异常回滚功能
|
|||
|
|
|||
|
Yields:
|
|||
|
Session: 数据库会话对象
|
|||
|
"""
|
|||
|
session = get_session()
|
|||
|
try:
|
|||
|
yield session
|
|||
|
session.commit()
|
|||
|
except Exception as e:
|
|||
|
session.rollback()
|
|||
|
raise e
|
|||
|
finally:
|
|||
|
session.close()
|
|||
|
|
|||
|
|
|||
|
def init_database():
|
|||
|
"""
|
|||
|
初始化数据库和表结构
|
|||
|
在应用启动时调用
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 初始化是否成功
|
|||
|
"""
|
|||
|
logger.info("正在初始化数据库...")
|
|||
|
try:
|
|||
|
# 尝试连接到MySQL服务器并创建数据库(如果不存在)
|
|||
|
root_engine = create_engine(DBConfig.create_db_url(include_db_name=False))
|
|||
|
with root_engine.connect() as conn:
|
|||
|
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {settings.DB_NAME} CHARACTER SET {settings.DB_CHARSET} COLLATE {settings.DB_CHARSET}_unicode_ci"))
|
|||
|
logger.info(f"数据库 {settings.DB_NAME} 已创建或已存在")
|
|||
|
|
|||
|
# 连接到数据库并创建表
|
|||
|
Base.metadata.create_all(bind=engine)
|
|||
|
logger.info("数据库表初始化完成")
|
|||
|
return True
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"数据库初始化失败: {str(e)}")
|
|||
|
logger.error(traceback.format_exc())
|
|||
|
return False
|
|||
|
|
|||
|
|
|||
|
def close_database_connections():
|
|||
|
"""
|
|||
|
关闭数据库连接池
|
|||
|
在应用关闭时调用
|
|||
|
"""
|
|||
|
logger.info("正在关闭数据库连接...")
|
|||
|
try:
|
|||
|
engine.dispose()
|
|||
|
logger.info("数据库连接已关闭")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"关闭数据库连接时出错: {str(e)}")
|
|||
|
|
|||
|
|
|||
|
async def close_async_database_connections():
|
|||
|
"""
|
|||
|
关闭异步数据库连接池
|
|||
|
在应用关闭时调用
|
|||
|
"""
|
|||
|
logger.info("正在关闭异步数据库连接...")
|
|||
|
try:
|
|||
|
await async_engine.dispose()
|
|||
|
logger.info("异步数据库连接已关闭")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"关闭异步数据库连接时出错: {str(e)}")
|
|||
|
|