#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Database session management module.

Provides helpers for obtaining and managing database sessions, and the
one-time seeding (plus caching) of component categories, component types
and system components.
"""
from contextlib import contextmanager

from config.database import DBConfig, CacheConfig
from config.component_config import (
    COMPONENT_CATEGORIES,
    get_component_types,
    get_system_components,
    CACHE_EXPIRE_TIME,
    CACHE_KEYS,
)


def get_session():
    """
    Return a new database session.

    Returns:
        Session: the session object produced by ``DBConfig.get_session()``.
    """
    return DBConfig.get_session()


def get_db():
    """
    Yield a database session and close it afterwards.

    The session is closed in ``finally`` once the generator is resumed or
    closed. It is never committed here; the consumer decides that.

    Yields:
        Session: a database session object.
    """
    db = get_session()
    try:
        yield db
    finally:
        db.close()


@contextmanager
def session_scope():
    """
    Provide a transactional scope around a series of operations.

    Commits when the ``with`` body completes without error. On any
    ``Exception`` it rolls back and re-raises. The session is always closed.

    Yields:
        Session: a database session object.
    """
    session = get_session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        # A bare ``raise`` re-raises the active exception with its original
        # traceback, without recording this line as a new raise site.
        raise
    finally:
        session.close()


def initialize_database():
    """
    Initialize the database.

    Creates all tables via ``DBConfig.init_db()``, then seeds the base data.
    The order matters:
    - component types are built from a category code -> id map, so
      categories must be seeded first;
    - system components are built from a type code -> id map, so types
      must be seeded before them.
    """
    # Create all tables.
    DBConfig.init_db()

    # Seed base data (each step reads what the previous one inserted).
    _initialize_component_categories()
    _initialize_component_types()
    _initialize_system_components()

    print("数据库初始化完成")


def _initialize_component_categories():
    """Seed component categories from ``COMPONENT_CATEGORIES`` if none exist."""
    # Function-scope import, presumably to avoid a circular import with the
    # models package -- TODO confirm.
    from data.models.component import ComponentCategory

    # Skip seeding when any category already exists.
    # NOTE(review): this early return also skips the cache refresh below, so
    # the category cache is only populated on first seeding; confirm readers
    # fall back to the database when the cache entry is missing or expired.
    if ComponentCategory.query.count() > 0:
        return

    # Insert all categories in a single transaction.
    with session_scope() as session:
        for category_data in COMPONENT_CATEGORIES:
            session.add(ComponentCategory(**category_data))

    # Cache only after the transaction above has committed.
    _cache_component_categories()


def _initialize_component_types():
    """Seed component types from configuration if none exist."""
    from data.models.component import ComponentType, ComponentCategory

    # Skip seeding when any type already exists (cache is not refreshed on
    # this path either -- see the review note in the categories initializer).
    if ComponentType.query.count() > 0:
        return

    # Category code -> id, passed to the config builder so each type can be
    # linked to its category.
    category_ids = {
        category.code: category.id for category in ComponentCategory.query.all()
    }
    types = get_component_types(category_ids)

    # Insert all types in a single transaction.
    with session_scope() as session:
        for type_data in types:
            session.add(ComponentType(**type_data))

    _cache_component_types()


def _initialize_system_components():
    """Seed system components from configuration if no components exist."""
    from data.models.component import Component, ComponentType

    # Skip seeding when any component (system or not) already exists.
    if Component.query.count() > 0:
        return

    # Type code -> id, passed to the config builder so each component can be
    # linked to its type.
    type_ids = {
        component_type.code: component_type.id
        for component_type in ComponentType.query.all()
    }
    components = get_system_components(type_ids)

    # Insert all components in a single transaction.
    with session_scope() as session:
        for component_data in components:
            session.add(Component(**component_data))

    _cache_system_components()


def _cache_component_categories():
    """Cache all component categories in Redis, keyed by category id."""
    from data.models.component import ComponentCategory

    categories_dict = {
        category.id: {
            "id": category.id,
            "name": category.name,
            "code": category.code,
            "description": category.description,
            "sort_order": category.sort_order,
        }
        for category in ComponentCategory.query.all()
    }

    # Cache in Redis.
    CacheConfig.set(
        CACHE_KEYS["COMPONENT_CATEGORIES"], categories_dict, expire=CACHE_EXPIRE_TIME
    )


def _cache_component_types():
    """Cache all component types in Redis, keyed by type id."""
    from data.models.component import ComponentType

    # ``component_type`` rather than ``type`` to avoid shadowing the builtin.
    types_dict = {
        component_type.id: {
            "id": component_type.id,
            "name": component_type.name,
            "code": component_type.code,
            "category_id": component_type.category_id,
            "description": component_type.description,
            "icon": component_type.icon,
            "sort_order": component_type.sort_order,
        }
        for component_type in ComponentType.query.all()
    }

    # Cache in Redis.
    CacheConfig.set(CACHE_KEYS["COMPONENT_TYPES"], types_dict, expire=CACHE_EXPIRE_TIME)


def _cache_system_components():
    """Cache all components flagged ``is_system`` in Redis, keyed by component id."""
    from data.models.component import Component

    # ``== True`` is intentional: on an SQLAlchemy column it builds a SQL
    # comparison expression, which ``is True`` would not.
    components = Component.query.filter(Component.is_system == True).all()  # noqa: E712
    components_dict = {
        component.id: {
            "id": component.id,
            "name": component.name,
            "code": component.code,
            "type_id": component.type_id,
            "description": component.description,
            "config_schema": component.config_schema,
            "input_schema": component.input_schema,
            "output_schema": component.output_schema,
        }
        for component in components
    }

    # Cache in Redis.
    CacheConfig.set(
        CACHE_KEYS["SYSTEM_COMPONENTS"], components_dict, expire=CACHE_EXPIRE_TIME
    )