VWED_server/data/session.py
2025-04-30 16:57:46 +08:00

179 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库会话管理模块
提供数据库会话获取和管理功能
包含数据库引擎创建、会话获取、连接池管理等功能
"""
import traceback
import logging
from contextlib import contextmanager, asynccontextmanager
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
from config.database_config import Base, DBConfig
from config.settings import settings
# 获取日志记录器
from utils.logger import get_logger
logger = get_logger("data.session")
# 创建数据库引擎
engine = create_engine(
DBConfig.DATABASE_URL,
**DBConfig.DATABASE_ARGS
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建 scoped_session确保线程安全
db_session = scoped_session(SessionLocal)
# 创建异步数据库引擎
# 将mysql+pymysql替换为mysql+aiomysql以支持异步
async_engine = create_async_engine(
DBConfig.DATABASE_URL.replace('mysql+pymysql', 'mysql+aiomysql'),
**DBConfig.ASYNC_DATABASE_ARGS
)
# 创建异步会话工厂
AsyncSessionLocal = async_sessionmaker(
autocommit=False,
autoflush=False,
expire_on_commit=False,
class_=AsyncSession,
bind=async_engine
)
def get_session():
"""
获取数据库会话
Returns:
Session: 数据库会话对象
"""
return db_session()
def get_db():
"""
获取数据库会话生成器用于FastAPI依赖注入
Yields:
Session: 数据库会话对象
"""
db = get_session()
try:
yield db
finally:
db.close()
@asynccontextmanager
async def get_async_session():
"""
获取异步数据库会话,使用异步上下文管理器
Yields:
AsyncSession: 异步数据库会话对象
"""
session = AsyncSessionLocal()
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
raise e
finally:
await session.close()
async def get_async_db():
"""
获取异步数据库会话生成器用于FastAPI依赖注入
Yields:
AsyncSession: 异步数据库会话对象
"""
async with AsyncSessionLocal() as session:
yield session
@contextmanager
def session_scope():
"""
会话上下文管理器
提供自动提交和异常回滚功能
Yields:
Session: 数据库会话对象
"""
session = get_session()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
def init_database():
"""
初始化数据库和表结构
在应用启动时调用
Returns:
bool: 初始化是否成功
"""
logger.info("正在初始化数据库...")
try:
# 尝试连接到MySQL服务器并创建数据库如果不存在
root_engine = create_engine(DBConfig.create_db_url(include_db_name=False))
with root_engine.connect() as conn:
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {settings.DB_NAME} CHARACTER SET {settings.DB_CHARSET} COLLATE {settings.DB_CHARSET}_unicode_ci"))
logger.info(f"数据库 {settings.DB_NAME} 已创建或已存在")
# 连接到数据库并创建表
Base.metadata.create_all(bind=engine)
logger.info("数据库表初始化完成")
return True
except Exception as e:
logger.error(f"数据库初始化失败: {str(e)}")
logger.error(traceback.format_exc())
return False
def close_database_connections():
"""
关闭数据库连接池
在应用关闭时调用
"""
logger.info("正在关闭数据库连接...")
try:
engine.dispose()
logger.info("数据库连接已关闭")
except Exception as e:
logger.error(f"关闭数据库连接时出错: {str(e)}")
async def close_async_database_connections():
"""
关闭异步数据库连接池
在应用关闭时调用
"""
logger.info("正在关闭异步数据库连接...")
try:
await async_engine.dispose()
logger.info("异步数据库连接已关闭")
except Exception as e:
logger.error(f"关闭异步数据库连接时出错: {str(e)}")