VWED_server/data/session.py

179 lines
4.4 KiB
Python
Raw Normal View History

2025-04-30 16:57:46 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库会话管理模块
提供数据库会话获取和管理功能
包含数据库引擎创建会话获取连接池管理等功能
"""
import traceback
import logging
from contextlib import contextmanager, asynccontextmanager
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
from config.database_config import Base, DBConfig
from config.settings import settings
# 获取日志记录器
from utils.logger import get_logger
logger = get_logger("data.session")
# 创建数据库引擎
engine = create_engine(
DBConfig.DATABASE_URL,
**DBConfig.DATABASE_ARGS
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建 scoped_session确保线程安全
db_session = scoped_session(SessionLocal)
# 创建异步数据库引擎
# 将mysql+pymysql替换为mysql+aiomysql以支持异步
async_engine = create_async_engine(
DBConfig.DATABASE_URL.replace('mysql+pymysql', 'mysql+aiomysql'),
**DBConfig.ASYNC_DATABASE_ARGS
)
# 创建异步会话工厂
AsyncSessionLocal = async_sessionmaker(
autocommit=False,
autoflush=False,
expire_on_commit=False,
class_=AsyncSession,
bind=async_engine
)
def get_session():
"""
获取数据库会话
Returns:
Session: 数据库会话对象
"""
return db_session()
def get_db():
"""
获取数据库会话生成器用于FastAPI依赖注入
Yields:
Session: 数据库会话对象
"""
db = get_session()
try:
yield db
finally:
db.close()
@asynccontextmanager
async def get_async_session():
"""
获取异步数据库会话使用异步上下文管理器
Yields:
AsyncSession: 异步数据库会话对象
"""
session = AsyncSessionLocal()
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
raise e
finally:
await session.close()
async def get_async_db():
"""
获取异步数据库会话生成器用于FastAPI依赖注入
Yields:
AsyncSession: 异步数据库会话对象
"""
async with AsyncSessionLocal() as session:
yield session
@contextmanager
def session_scope():
"""
会话上下文管理器
提供自动提交和异常回滚功能
Yields:
Session: 数据库会话对象
"""
session = get_session()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
def init_database():
"""
初始化数据库和表结构
在应用启动时调用
Returns:
bool: 初始化是否成功
"""
logger.info("正在初始化数据库...")
try:
# 尝试连接到MySQL服务器并创建数据库如果不存在
root_engine = create_engine(DBConfig.create_db_url(include_db_name=False))
with root_engine.connect() as conn:
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {settings.DB_NAME} CHARACTER SET {settings.DB_CHARSET} COLLATE {settings.DB_CHARSET}_unicode_ci"))
logger.info(f"数据库 {settings.DB_NAME} 已创建或已存在")
# 连接到数据库并创建表
Base.metadata.create_all(bind=engine)
logger.info("数据库表初始化完成")
return True
except Exception as e:
logger.error(f"数据库初始化失败: {str(e)}")
logger.error(traceback.format_exc())
return False
def close_database_connections():
"""
关闭数据库连接池
在应用关闭时调用
"""
logger.info("正在关闭数据库连接...")
try:
engine.dispose()
logger.info("数据库连接已关闭")
except Exception as e:
logger.error(f"关闭数据库连接时出错: {str(e)}")
async def close_async_database_connections():
"""
关闭异步数据库连接池
在应用关闭时调用
"""
logger.info("正在关闭异步数据库连接...")
try:
await async_engine.dispose()
logger.info("异步数据库连接已关闭")
except Exception as e:
logger.error(f"关闭异步数据库连接时出错: {str(e)}")