#!/usr/bin/env python # -*- coding: utf-8 -*- """ 数据库连接配置模块 包含数据库连接参数和SQLAlchemy配置,以及Redis缓存配置 """ import os import json from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, scoped_session import traceback import sys class ConfigDict: """配置字典类,支持通过点号访问配置项""" def __init__(self, **kwargs): for key, value in kwargs.items(): if isinstance(value, dict): setattr(self, key, ConfigDict(**value)) else: setattr(self, key, value) def get(self, key, default=None): return getattr(self, key, default) def to_dict(self): result = {} for key, value in self.__dict__.items(): if isinstance(value, ConfigDict): result[key] = value.to_dict() else: result[key] = value return result # 数据库连接配置 DB_CONFIG = ConfigDict( default=dict( dialect='mysql', driver='pymysql', username='root', password='root', host='localhost', port=3306, database='tianfeng_task', charset='utf8mb4' ), test=dict( dialect='sqlite', database=':memory:' ) ) # Redis缓存配置 REDIS_CONFIG = ConfigDict( default=dict( host='localhost', port=6379, db=0, password=None, prefix='tianfeng:', socket_timeout=5, socket_connect_timeout=5, decode_responses=True ), test=dict( host='localhost', port=6379, db=1, password=None, prefix='tianfeng_test:', decode_responses=True ) ) # 当前环境,可通过环境变量设置 ENV = os.environ.get('TIANFENG_ENV', 'default') # 根据环境获取数据库配置 db_conf = getattr(DB_CONFIG, ENV) # 构建数据库连接URL if db_conf.dialect == 'sqlite': DATABASE_URL = f"{db_conf.dialect}:///{db_conf.database}" else: DATABASE_URL = ( f"{db_conf.dialect}+{db_conf.driver}://" f"{db_conf.username}:{db_conf.password}@" f"{db_conf.host}:{db_conf.port}/{db_conf.database}?" f"charset={db_conf.charset}" ) # 创建数据库引擎 engine = create_engine( DATABASE_URL, pool_size=20, max_overflow=0, pool_recycle=3600, pool_pre_ping=True, echo=False # 设置为True可以显示SQL语句,用于调试 ) # 创建会话工厂 SessionFactory = sessionmaker(bind=engine) # 创建线程安全的会话 db_session = scoped_session(SessionFactory) # 创建基类 Base = declarative_base() Base.query = db_session.query_property() # 数据库配置类 class DBConfig: """数据库配置类,提供数据库相关的配置和方法""" config = DB_CONFIG env = ENV url = DATABASE_URL engine = engine session = db_session base = Base @classmethod def get_config(cls): """获取当前环境的数据库配置""" return getattr(cls.config, cls.env) @classmethod def get_session(cls): """获取数据库会话""" return cls.session @classmethod def init_db(cls): """ 初始化数据库 创建所有表 """ # 测试数据库连接 try: print(f"尝试连接数据库: {cls.url}") connection = cls.engine.connect() print("数据库连接成功!") connection.close() except Exception as e: print(f"数据库连接失败: {str(e)}") print("详细错误信息:") traceback.print_exc(file=sys.stdout) raise # 导入所有模型,确保它们已注册到Base import data.models # 首先尝试创建数据库(如果不存在) if cls.get_config().dialect != 'sqlite': from sqlalchemy import text # 创建一个不指定数据库的连接 db_conf = cls.get_config() try: print(f"尝试创建数据库 {db_conf.database} (如果不存在)") temp_url = ( f"{db_conf.dialect}+{db_conf.driver}://" f"{db_conf.username}:{db_conf.password}@" f"{db_conf.host}:{db_conf.port}/" f"?charset={db_conf.charset}" ) print(f"临时连接URL: {temp_url}") temp_engine = create_engine(temp_url) with temp_engine.connect() as conn: conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_conf.database} CHARACTER SET {db_conf.charset} COLLATE {db_conf.charset}_unicode_ci;")) conn.commit() temp_engine.dispose() print(f"数据库 {db_conf.database} 创建或已存在") except Exception as e: print(f"创建数据库失败: {str(e)}") print("详细错误信息:") traceback.print_exc(file=sys.stdout) raise # 创建所有表 try: print("开始创建所有表...") cls.base.metadata.create_all(bind=cls.engine) print("所有表创建成功") except Exception as e: print(f"创建表失败: {str(e)}") print("详细错误信息:") traceback.print_exc(file=sys.stdout) raise @classmethod def shutdown_session(cls, exception=None): """ 关闭会话 在应用程序关闭时调用 """ cls.session.remove() # 缓存配置类 class CacheConfig: """缓存配置类,提供Redis缓存相关的配置和方法""" config = REDIS_CONFIG env = ENV _redis_client = None @classmethod def get_config(cls): """获取当前环境的Redis配置""" return getattr(cls.config, cls.env) @classmethod def get_redis_client(cls): """获取Redis客户端实例""" if cls._redis_client is None: try: import redis redis_conf = cls.get_config() cls._redis_client = redis.Redis( host=redis_conf.host, port=redis_conf.port, db=redis_conf.db, password=redis_conf.password, socket_timeout=getattr(redis_conf, 'socket_timeout', 5), socket_connect_timeout=getattr(redis_conf, 'socket_connect_timeout', 5), decode_responses=getattr(redis_conf, 'decode_responses', True) ) except ImportError: raise ImportError("Redis package is not installed. Please install it with 'pip install redis'") except Exception as e: print(f"Error connecting to Redis: {e}") return None return cls._redis_client @classmethod def get_key(cls, key): """获取带前缀的缓存键""" prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:') return f"{prefix}{key}" @classmethod def set(cls, key, value, expire=None): """ 设置缓存 Args: key (str): 缓存键 value (any): 缓存值,非字符串类型会被JSON序列化 expire (int, optional): 过期时间(秒) Returns: bool: 是否设置成功 """ redis_client = cls.get_redis_client() if not redis_client: return False if not isinstance(value, (str, int, float, bool)): value = json.dumps(value) full_key = cls.get_key(key) if expire: return redis_client.setex(full_key, expire, value) else: return redis_client.set(full_key, value) @classmethod def get(cls, key, default=None): """ 获取缓存 Args: key (str): 缓存键 default (any, optional): 默认值 Returns: any: 缓存值或默认值 """ redis_client = cls.get_redis_client() if not redis_client: return default full_key = cls.get_key(key) value = redis_client.get(full_key) if value is None: return default # 尝试解析JSON try: if value.startswith('{') or value.startswith('['): return json.loads(value) except (json.JSONDecodeError, AttributeError): pass return value @classmethod def delete(cls, key): """ 删除缓存 Args: key (str): 缓存键 Returns: bool: 是否删除成功 """ redis_client = cls.get_redis_client() if not redis_client: return False full_key = cls.get_key(key) return redis_client.delete(full_key) > 0 @classmethod def exists(cls, key): """ 检查缓存是否存在 Args: key (str): 缓存键 Returns: bool: 是否存在 """ redis_client = cls.get_redis_client() if not redis_client: return False full_key = cls.get_key(key) return redis_client.exists(full_key) > 0 @classmethod def ttl(cls, key): """ 获取缓存剩余过期时间 Args: key (str): 缓存键 Returns: int: 剩余秒数,-1表示永不过期,-2表示不存在 """ redis_client = cls.get_redis_client() if not redis_client: return -2 full_key = cls.get_key(key) return redis_client.ttl(full_key) @classmethod def clear_all(cls): """ 清除当前环境下的所有缓存 Returns: bool: 是否清除成功 """ redis_client = cls.get_redis_client() if not redis_client: return False prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:') keys = redis_client.keys(f"{prefix}*") if keys: return redis_client.delete(*keys) > 0 return True # 兼容旧代码的函数 def init_db(): """初始化数据库(兼容旧代码)""" DBConfig.init_db() def shutdown_session(exception=None): """关闭会话(兼容旧代码)""" DBConfig.shutdown_session(exception)