179 lines
4.4 KiB
Python
179 lines
4.4 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
数据库会话管理模块
|
||
提供数据库会话获取和管理功能
|
||
包含数据库引擎创建、会话获取、连接池管理等功能
|
||
"""
|
||
|
||
import traceback
|
||
import logging
|
||
from contextlib import contextmanager, asynccontextmanager
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||
|
||
from config.database_config import Base, DBConfig
|
||
from config.settings import settings
|
||
|
||
# 获取日志记录器
|
||
from utils.logger import get_logger
|
||
logger = get_logger("data.session")
|
||
|
||
# 创建数据库引擎
|
||
engine = create_engine(
|
||
DBConfig.DATABASE_URL,
|
||
**DBConfig.DATABASE_ARGS
|
||
)
|
||
|
||
# 创建会话工厂
|
||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||
|
||
# 创建 scoped_session,确保线程安全
|
||
db_session = scoped_session(SessionLocal)
|
||
|
||
# 创建异步数据库引擎
|
||
# 将mysql+pymysql替换为mysql+aiomysql以支持异步
|
||
async_engine = create_async_engine(
|
||
DBConfig.DATABASE_URL.replace('mysql+pymysql', 'mysql+aiomysql'),
|
||
**DBConfig.ASYNC_DATABASE_ARGS
|
||
)
|
||
|
||
# 创建异步会话工厂
|
||
AsyncSessionLocal = async_sessionmaker(
|
||
autocommit=False,
|
||
autoflush=False,
|
||
expire_on_commit=False,
|
||
class_=AsyncSession,
|
||
bind=async_engine
|
||
)
|
||
|
||
|
||
def get_session():
|
||
"""
|
||
获取数据库会话
|
||
|
||
Returns:
|
||
Session: 数据库会话对象
|
||
"""
|
||
return db_session()
|
||
|
||
|
||
def get_db():
|
||
"""
|
||
获取数据库会话生成器,用于FastAPI依赖注入
|
||
|
||
Yields:
|
||
Session: 数据库会话对象
|
||
"""
|
||
db = get_session()
|
||
try:
|
||
yield db
|
||
finally:
|
||
db.close()
|
||
|
||
|
||
@asynccontextmanager
|
||
async def get_async_session():
|
||
"""
|
||
获取异步数据库会话,使用异步上下文管理器
|
||
|
||
Yields:
|
||
AsyncSession: 异步数据库会话对象
|
||
"""
|
||
session = AsyncSessionLocal()
|
||
try:
|
||
yield session
|
||
await session.commit()
|
||
except Exception as e:
|
||
await session.rollback()
|
||
raise e
|
||
finally:
|
||
await session.close()
|
||
|
||
|
||
async def get_async_db():
|
||
"""
|
||
获取异步数据库会话生成器,用于FastAPI依赖注入
|
||
|
||
Yields:
|
||
AsyncSession: 异步数据库会话对象
|
||
"""
|
||
async with AsyncSessionLocal() as session:
|
||
yield session
|
||
|
||
|
||
@contextmanager
|
||
def session_scope():
|
||
"""
|
||
会话上下文管理器
|
||
提供自动提交和异常回滚功能
|
||
|
||
Yields:
|
||
Session: 数据库会话对象
|
||
"""
|
||
session = get_session()
|
||
try:
|
||
yield session
|
||
session.commit()
|
||
except Exception as e:
|
||
session.rollback()
|
||
raise e
|
||
finally:
|
||
session.close()
|
||
|
||
|
||
def init_database():
|
||
"""
|
||
初始化数据库和表结构
|
||
在应用启动时调用
|
||
|
||
Returns:
|
||
bool: 初始化是否成功
|
||
"""
|
||
logger.info("正在初始化数据库...")
|
||
try:
|
||
# 尝试连接到MySQL服务器并创建数据库(如果不存在)
|
||
root_engine = create_engine(DBConfig.create_db_url(include_db_name=False))
|
||
with root_engine.connect() as conn:
|
||
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {settings.DB_NAME} CHARACTER SET {settings.DB_CHARSET} COLLATE {settings.DB_CHARSET}_unicode_ci"))
|
||
logger.info(f"数据库 {settings.DB_NAME} 已创建或已存在")
|
||
|
||
# 连接到数据库并创建表
|
||
Base.metadata.create_all(bind=engine)
|
||
logger.info("数据库表初始化完成")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"数据库初始化失败: {str(e)}")
|
||
logger.error(traceback.format_exc())
|
||
return False
|
||
|
||
|
||
def close_database_connections():
|
||
"""
|
||
关闭数据库连接池
|
||
在应用关闭时调用
|
||
"""
|
||
logger.info("正在关闭数据库连接...")
|
||
try:
|
||
engine.dispose()
|
||
logger.info("数据库连接已关闭")
|
||
except Exception as e:
|
||
logger.error(f"关闭数据库连接时出错: {str(e)}")
|
||
|
||
|
||
async def close_async_database_connections():
|
||
"""
|
||
关闭异步数据库连接池
|
||
在应用关闭时调用
|
||
"""
|
||
logger.info("正在关闭异步数据库连接...")
|
||
try:
|
||
await async_engine.dispose()
|
||
logger.info("异步数据库连接已关闭")
|
||
except Exception as e:
|
||
logger.error(f"关闭异步数据库连接时出错: {str(e)}")
|
||
|