2025-03-17 18:31:20 +08:00

193 lines
5.1 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库会话管理模块
提供数据库会话获取和管理功能
"""
from contextlib import contextmanager
from config.database import DBConfig, CacheConfig
from config.component_config import (
COMPONENT_CATEGORIES,
get_component_types,
get_system_components,
CACHE_EXPIRE_TIME,
CACHE_KEYS
)
def get_session():
    """Obtain a database session from ``DBConfig``.

    Returns:
        Session: whatever ``DBConfig.get_session()`` hands back.
    """
    session = DBConfig.get_session()
    return session
def get_db():
    """Generator that yields one database session and then closes it.

    This helper only guarantees ``close()``; it never commits or rolls back,
    so transaction handling is left to the consumer.

    Yields:
        Session: the session obtained from ``get_session()``.
    """
    db_session = get_session()
    try:
        yield db_session
    finally:
        # Executed however the generator is left: exhausted, closed early,
        # or with an exception thrown into it.
        db_session.close()
@contextmanager
def session_scope():
    """Transactional context manager around a database session.

    Commits when the ``with`` body completes normally. If the body or the
    commit raises an ``Exception``, it rolls back and re-raises. The session
    is closed on every exit path.

    Yields:
        Session: the session obtained from ``get_session()``.

    Raises:
        Exception: any exception from the ``with`` body or from ``commit()``,
            re-raised unchanged after the rollback.
    """
    session = get_session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        # Bare ``raise`` re-raises the active exception with its original
        # traceback intact, instead of ``raise e`` which adds this line as a
        # new raise point.
        raise
    finally:
        session.close()
def initialize_database():
    """Create all tables, then seed and cache the base component data.

    Seeders run in dependency order:
    - categories first;
    - types next, because they are built from a category code -> id map;
    - system components last, because they are built from a type code -> id map.
    """
    DBConfig.init_db()

    seeders = (
        _initialize_component_categories,
        _initialize_component_types,
        _initialize_system_components,
    )
    for seed in seeders:
        seed()

    print("数据库初始化完成")
def _initialize_component_categories():
    """Seed ComponentCategory rows from ``COMPONENT_CATEGORIES`` and cache them.

    Each entry of ``COMPONENT_CATEGORIES`` is passed as keyword arguments to
    ``ComponentCategory``. The insert runs inside ``session_scope()``.
    """
    from data.models.component import ComponentCategory

    already_seeded = ComponentCategory.query.count() > 0
    if already_seeded:
        # NOTE(review): this early return also skips
        # _cache_component_categories(), so the cache is not refreshed on
        # this path — confirm that is intended.
        return

    with session_scope() as session:
        session.add_all([ComponentCategory(**row) for row in COMPONENT_CATEGORIES])

    _cache_component_categories()
def _initialize_component_types():
    """Seed ComponentType rows from configuration, then cache them.

    Category codes are resolved to ids using the ComponentCategory rows
    currently in the database. That ``{code: id}`` map is passed to
    ``get_component_types``, and each returned entry becomes the keyword
    arguments of one ``ComponentType``.
    """
    from data.models.component import ComponentType, ComponentCategory

    already_seeded = ComponentType.query.count() > 0
    if already_seeded:
        # NOTE(review): returning here also skips _cache_component_types().
        return

    category_ids_by_code = {}
    for category in ComponentCategory.query.all():
        category_ids_by_code[category.code] = category.id

    type_specs = get_component_types(category_ids_by_code)

    with session_scope() as session:
        session.add_all([ComponentType(**spec) for spec in type_specs])

    _cache_component_types()
def _initialize_system_components():
    """Seed system Component rows from configuration, then cache them.

    Type codes are resolved to ids using the ComponentType rows currently in
    the database. That ``{code: id}`` map is passed to
    ``get_system_components``, and each returned entry becomes the keyword
    arguments of one ``Component``.
    """
    from data.models.component import Component, ComponentType

    already_seeded = Component.query.count() > 0
    if already_seeded:
        # NOTE(review): returning here also skips _cache_system_components().
        return

    type_ids_by_code = {}
    for component_type in ComponentType.query.all():
        type_ids_by_code[component_type.code] = component_type.id

    component_specs = get_system_components(type_ids_by_code)

    with session_scope() as session:
        session.add_all([Component(**spec) for spec in component_specs])

    _cache_system_components()
def _cache_component_categories():
    """Write a summary of every ComponentCategory row to the cache.

    The cached value maps each category id to a dict holding its id, name,
    code, description and sort_order. It is stored under
    ``CACHE_KEYS["COMPONENT_CATEGORIES"]`` with a ``CACHE_EXPIRE_TIME``
    expiry.
    """
    from data.models.component import ComponentCategory

    cached = {}
    for row in ComponentCategory.query.all():
        cached[row.id] = {
            "id": row.id,
            "name": row.name,
            "code": row.code,
            "description": row.description,
            "sort_order": row.sort_order,
        }

    # Stored via CacheConfig (Redis-backed).
    CacheConfig.set(CACHE_KEYS["COMPONENT_CATEGORIES"], cached, expire=CACHE_EXPIRE_TIME)
def _cache_component_types():
    """Write a summary of every ComponentType row to the cache.

    The cached value maps each type id to a dict holding its id, name, code,
    category_id, description, icon and sort_order. It is stored under
    ``CACHE_KEYS["COMPONENT_TYPES"]`` with a ``CACHE_EXPIRE_TIME`` expiry.
    """
    from data.models.component import ComponentType

    component_types = ComponentType.query.all()
    # The loop variable is deliberately not called ``type``, which would
    # shadow the builtin inside the comprehension.
    types_dict = {
        component_type.id: {
            "id": component_type.id,
            "name": component_type.name,
            "code": component_type.code,
            "category_id": component_type.category_id,
            "description": component_type.description,
            "icon": component_type.icon,
            "sort_order": component_type.sort_order,
        }
        for component_type in component_types
    }

    # Stored via CacheConfig (Redis-backed).
    CacheConfig.set(CACHE_KEYS["COMPONENT_TYPES"], types_dict, expire=CACHE_EXPIRE_TIME)
def _cache_system_components():
    """Write a summary of every Component flagged ``is_system`` to the cache.

    The cached value maps each component id to a dict holding its id, name,
    code, type_id, description, config_schema, input_schema and
    output_schema. It is stored under ``CACHE_KEYS["SYSTEM_COMPONENTS"]``
    with a ``CACHE_EXPIRE_TIME`` expiry.
    """
    from data.models.component import Component

    # filter_by(is_system=True) builds the same ``is_system = true`` criterion
    # as ``Component.is_system == True``, without the ``== True`` comparison
    # that linters flag (E712).
    components = Component.query.filter_by(is_system=True).all()

    components_dict = {
        component.id: {
            "id": component.id,
            "name": component.name,
            "code": component.code,
            "type_id": component.type_id,
            "description": component.description,
            "config_schema": component.config_schema,
            "input_schema": component.input_schema,
            "output_schema": component.output_schema,
        }
        for component in components
    }

    # Stored via CacheConfig (Redis-backed).
    CacheConfig.set(CACHE_KEYS["SYSTEM_COMPONENTS"], components_dict, expire=CACHE_EXPIRE_TIME)