193 lines
5.1 KiB
Python
193 lines
5.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
数据库会话管理模块
|
|
提供数据库会话获取和管理功能
|
|
"""
|
|
|
|
from contextlib import contextmanager
|
|
from config.database import DBConfig, CacheConfig
|
|
from config.component_config import (
|
|
COMPONENT_CATEGORIES,
|
|
get_component_types,
|
|
get_system_components,
|
|
CACHE_EXPIRE_TIME,
|
|
CACHE_KEYS
|
|
)
|
|
|
|
def get_session():
    """Obtain a database session from ``DBConfig``.

    This is a thin indirection over ``DBConfig.get_session()`` so the rest of
    the module (``get_db``, ``session_scope``) has a single place to fetch
    sessions from.

    Returns:
        Session: whatever ``DBConfig.get_session()`` hands back.
    """
    session = DBConfig.get_session()
    return session
|
|
|
|
def get_db():
    """Yield one database session and close it when the consumer is finished.

    The session is obtained via ``get_session()`` before the ``try`` block, so
    a failure to obtain it propagates without attempting a close. No commit or
    rollback is performed here; only ``close()`` is guaranteed.

    Yields:
        Session: the object returned by ``get_session()``.
    """
    db_session = get_session()
    try:
        yield db_session
    finally:
        # Runs on normal exhaustion, on an exception thrown into the
        # generator, and when the generator is closed early.
        db_session.close()
|
|
|
|
@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations.

    The session is committed when the ``with`` body completes normally. If the
    body (or the commit itself) raises anything, including
    ``KeyboardInterrupt`` / ``SystemExit``, the session is rolled back and the
    original exception is re-raised unchanged. The session is always closed.

    Yields:
        Session: the object returned by ``get_session()``.

    Raises:
        Whatever the ``with`` body or ``session.commit()`` raised.
    """
    session = get_session()
    try:
        yield session
        session.commit()
    except BaseException:
        # Roll back on *any* interruption, not only Exception subclasses, so a
        # Ctrl-C mid-transaction does not leave work half-applied. A bare
        # ``raise`` preserves the original traceback exactly.
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
|
def initialize_database():
    """Initialize the database: create all tables, then seed base data.

    Steps, in order:

    1. ``DBConfig.init_db()``, which creates all tables.
    2. Seed component categories.
    3. Seed component types; these look up categories by code, so categories
       must be seeded first.
    4. Seed system components; these look up types by code, so types must be
       seeded first.

    A completion message is printed at the end.
    """
    DBConfig.init_db()

    # Order matters: each seeder resolves ids produced by the previous one.
    seeders = (
        _initialize_component_categories,
        _initialize_component_types,
        _initialize_system_components,
    )
    for seed in seeders:
        seed()

    print("数据库初始化完成")
|
|
|
|
def _initialize_component_categories():
    """Seed the component-category table from ``COMPONENT_CATEGORIES``.

    Does nothing if the table already contains at least one row. Otherwise,
    inside a single ``session_scope()`` transaction:

    - each entry of ``COMPONENT_CATEGORIES`` is passed as keyword arguments to
      ``ComponentCategory`` and added to the session;
    - after the transaction commits, the categories are written to the cache.

    NOTE(review): the early exit also skips the cache write, so an
    already-seeded database does not get its cache warmed here. Confirm this
    is intended.

    NOTE(review): the existence check uses ``ComponentCategory.query`` while
    the inserts go through ``session_scope()``. Presumably both are bound to
    the same database; verify.
    """
    from data.models.component import ComponentCategory

    already_seeded = ComponentCategory.query.count() > 0
    if not already_seeded:
        with session_scope() as session:
            for fields in COMPONENT_CATEGORIES:
                session.add(ComponentCategory(**fields))

        _cache_component_categories()
|
|
|
|
def _initialize_component_types():
    """Seed the component-type table from the configured type definitions.

    Does nothing if the table already contains at least one row. Otherwise:

    1. Build a ``{category.code: category.id}`` mapping from existing
       categories.
    2. Hand that mapping to ``get_component_types``.
    3. Add one ``ComponentType(**fields)`` per returned mapping inside a
       single ``session_scope()`` transaction.
    4. Refresh the component-type cache.

    NOTE(review): as with the other seeders, the early exit also skips the
    cache refresh. Confirm this is intended.
    """
    from data.models.component import ComponentType, ComponentCategory

    if ComponentType.query.count() > 0:
        return  # already seeded

    category_ids_by_code = {}
    for cat in ComponentCategory.query.all():
        category_ids_by_code[cat.code] = cat.id

    # Resolved before the session is opened, matching the original ordering.
    type_rows = get_component_types(category_ids_by_code)

    with session_scope() as session:
        for fields in type_rows:
            session.add(ComponentType(**fields))

    _cache_component_types()
|
|
|
|
def _initialize_system_components():
    """Seed the component table with the configured system components.

    Does nothing if the table already contains at least one row. Otherwise:

    1. Build a ``{type.code: type.id}`` mapping from existing component types.
    2. Hand that mapping to ``get_system_components``.
    3. Add one ``Component(**fields)`` per returned mapping inside a single
       ``session_scope()`` transaction.
    4. Refresh the system-component cache.

    NOTE(review): the early exit checks for *any* component row, not just
    system ones, and also skips the cache refresh. Confirm both are intended.
    """
    from data.models.component import Component, ComponentType

    if Component.query.count() > 0:
        return  # already seeded

    type_ids_by_code = {}
    for ctype in ComponentType.query.all():
        type_ids_by_code[ctype.code] = ctype.id

    # Resolved before the session is opened, matching the original ordering.
    component_rows = get_system_components(type_ids_by_code)

    with session_scope() as session:
        for fields in component_rows:
            session.add(Component(**fields))

    _cache_system_components()
|
|
|
|
def _cache_component_categories():
    """Write every component category to the cache.

    The entry is stored under ``CACHE_KEYS["COMPONENT_CATEGORIES"]``:

    - the value is a dict keyed by category id;
    - each value holds the category's id, name, code, description and
      sort_order;
    - the entry is stored with ``expire=CACHE_EXPIRE_TIME``.

    NOTE(review): the original comment says the cache is Redis. If
    ``CacheConfig.set`` serializes to JSON, the integer keys will read back as
    strings; confirm readers handle that.
    """
    from data.models.component import ComponentCategory

    payload = {}
    for row in ComponentCategory.query.all():
        payload[row.id] = {
            "id": row.id,
            "name": row.name,
            "code": row.code,
            "description": row.description,
            "sort_order": row.sort_order,
        }

    CacheConfig.set(
        CACHE_KEYS["COMPONENT_CATEGORIES"],
        payload,
        expire=CACHE_EXPIRE_TIME,
    )
|
|
|
|
def _cache_component_types():
    """Write every component type to the cache.

    The entry is stored under ``CACHE_KEYS["COMPONENT_TYPES"]``:

    - the value is a dict keyed by type id;
    - each value holds the type's id, name, code, category_id, description,
      icon and sort_order;
    - the entry is stored with ``expire=CACHE_EXPIRE_TIME``.

    NOTE(review): the original comment says the cache is Redis. If
    ``CacheConfig.set`` serializes to JSON, the integer keys will read back as
    strings; confirm readers handle that.
    """
    from data.models.component import ComponentType

    component_types = ComponentType.query.all()

    # Loop variable is ``component_type`` rather than ``type`` so the builtin
    # ``type`` is not shadowed.
    types_dict = {
        component_type.id: {
            "id": component_type.id,
            "name": component_type.name,
            "code": component_type.code,
            "category_id": component_type.category_id,
            "description": component_type.description,
            "icon": component_type.icon,
            "sort_order": component_type.sort_order,
        }
        for component_type in component_types
    }

    CacheConfig.set(CACHE_KEYS["COMPONENT_TYPES"], types_dict, expire=CACHE_EXPIRE_TIME)
|
|
|
|
def _cache_system_components():
    """Write all system components (``is_system`` true) to the cache.

    The entry is stored under ``CACHE_KEYS["SYSTEM_COMPONENTS"]``:

    - the value is a dict keyed by component id;
    - each value holds id, name, code, type_id, description, config_schema,
      input_schema and output_schema;
    - the entry is stored with ``expire=CACHE_EXPIRE_TIME``.

    NOTE(review): the original comment says the cache is Redis. If
    ``CacheConfig.set`` serializes to JSON, the integer keys will read back as
    strings; confirm readers handle that.
    """
    from data.models.component import Component

    # ``filter_by(is_system=True)`` renders the same ``is_system = true``
    # predicate as ``Component.is_system == True`` without the
    # comparison-to-True lint smell (E712).
    components = Component.query.filter_by(is_system=True).all()

    components_dict = {
        component.id: {
            "id": component.id,
            "name": component.name,
            "code": component.code,
            "type_id": component.type_id,
            "description": component.description,
            "config_schema": component.config_schema,
            "input_schema": component.input_schema,
            "output_schema": component.output_schema,
        }
        for component in components
    }

    CacheConfig.set(CACHE_KEYS["SYSTEM_COMPONENTS"], components_dict, expire=CACHE_EXPIRE_TIME)
|