#!/usr/bin/env python # -*- coding: utf-8 -*- """ 数据库会话管理模块 提供数据库会话获取和管理功能 包含数据库引擎创建、会话获取、连接池管理等功能 """ import traceback import logging from contextlib import contextmanager, asynccontextmanager from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.ext.asyncio import async_sessionmaker from config.database_config import Base, DBConfig from config.settings import settings # 获取日志记录器 from utils.logger import get_logger logger = get_logger("data.session") # 创建数据库引擎 engine = create_engine( DBConfig.DATABASE_URL, **DBConfig.DATABASE_ARGS ) # 创建会话工厂 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # 创建 scoped_session,确保线程安全 db_session = scoped_session(SessionLocal) # 创建异步数据库引擎 # 将mysql+pymysql替换为mysql+aiomysql以支持异步 async_engine = create_async_engine( DBConfig.DATABASE_URL.replace('mysql+pymysql', 'mysql+aiomysql'), **DBConfig.ASYNC_DATABASE_ARGS ) # 创建异步会话工厂 AsyncSessionLocal = async_sessionmaker( autocommit=False, autoflush=False, expire_on_commit=False, class_=AsyncSession, bind=async_engine ) def get_session(): """ 获取数据库会话 Returns: Session: 数据库会话对象 """ return db_session() def get_db(): """ 获取数据库会话生成器,用于FastAPI依赖注入 Yields: Session: 数据库会话对象 """ db = get_session() try: yield db finally: db.close() @asynccontextmanager async def get_async_session(): """ 获取异步数据库会话,使用异步上下文管理器 Yields: AsyncSession: 异步数据库会话对象 """ session = AsyncSessionLocal() try: yield session await session.commit() except Exception as e: await session.rollback() raise e finally: await session.close() async def get_async_db(): """ 获取异步数据库会话生成器,用于FastAPI依赖注入 Yields: AsyncSession: 异步数据库会话对象 """ async with AsyncSessionLocal() as session: yield session @contextmanager def session_scope(): """ 会话上下文管理器 提供自动提交和异常回滚功能 Yields: Session: 数据库会话对象 """ session = get_session() try: yield session session.commit() except Exception as e: session.rollback() raise e finally: session.close() def init_database(): """ 初始化数据库和表结构 在应用启动时调用 Returns: bool: 初始化是否成功 """ logger.info("正在初始化数据库...") try: # 尝试连接到MySQL服务器并创建数据库(如果不存在) root_engine = create_engine(DBConfig.create_db_url(include_db_name=False)) with root_engine.connect() as conn: conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {settings.DB_NAME} CHARACTER SET {settings.DB_CHARSET} COLLATE {settings.DB_CHARSET}_unicode_ci")) logger.info(f"数据库 {settings.DB_NAME} 已创建或已存在") # 连接到数据库并创建表 Base.metadata.create_all(bind=engine) logger.info("数据库表初始化完成") return True except Exception as e: logger.error(f"数据库初始化失败: {str(e)}") logger.error(traceback.format_exc()) return False def close_database_connections(): """ 关闭数据库连接池 在应用关闭时调用 """ logger.info("正在关闭数据库连接...") try: engine.dispose() logger.info("数据库连接已关闭") except Exception as e: logger.error(f"关闭数据库连接时出错: {str(e)}") async def close_async_database_connections(): """ 关闭异步数据库连接池 在应用关闭时调用 """ logger.info("正在关闭异步数据库连接...") try: await async_engine.dispose() logger.info("异步数据库连接已关闭") except Exception as e: logger.error(f"关闭异步数据库连接时出错: {str(e)}")