2025-03-18 18:34:03 +08:00

378 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库连接配置模块
包含数据库连接参数和SQLAlchemy配置以及Redis缓存配置
"""
import os
import json
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
import traceback
import sys
class ConfigDict:
"""配置字典类,支持通过点号访问配置项"""
def __init__(self, **kwargs):
for key, value in kwargs.items():
if isinstance(value, dict):
setattr(self, key, ConfigDict(**value))
else:
setattr(self, key, value)
def get(self, key, default=None):
return getattr(self, key, default)
def to_dict(self):
result = {}
for key, value in self.__dict__.items():
if isinstance(value, ConfigDict):
result[key] = value.to_dict()
else:
result[key] = value
return result
# 数据库连接配置
DB_CONFIG = ConfigDict(
default=dict(
dialect='mysql',
driver='pymysql',
username='root',
password='root',
host='localhost',
port=3306,
database='tianfeng_task',
charset='utf8mb4'
),
test=dict(
dialect='sqlite',
database=':memory:'
)
)
# Redis缓存配置
REDIS_CONFIG = ConfigDict(
default=dict(
host='localhost',
port=6379,
db=0,
password=None,
prefix='tianfeng:',
socket_timeout=5,
socket_connect_timeout=5,
decode_responses=True
),
test=dict(
host='localhost',
port=6379,
db=1,
password=None,
prefix='tianfeng_test:',
decode_responses=True
)
)
# 当前环境,可通过环境变量设置
ENV = os.environ.get('TIANFENG_ENV', 'default')
# 根据环境获取数据库配置
db_conf = getattr(DB_CONFIG, ENV)
# 构建数据库连接URL
if db_conf.dialect == 'sqlite':
DATABASE_URL = f"{db_conf.dialect}:///{db_conf.database}"
else:
DATABASE_URL = (
f"{db_conf.dialect}+{db_conf.driver}://"
f"{db_conf.username}:{db_conf.password}@"
f"{db_conf.host}:{db_conf.port}/{db_conf.database}?"
f"charset={db_conf.charset}"
)
# 创建数据库引擎
engine = create_engine(
DATABASE_URL,
pool_size=20,
max_overflow=0,
pool_recycle=3600,
pool_pre_ping=True,
echo=False # 设置为True可以显示SQL语句用于调试
)
# 创建会话工厂
SessionFactory = sessionmaker(bind=engine)
# 创建线程安全的会话
db_session = scoped_session(SessionFactory)
# 创建基类
Base = declarative_base()
Base.query = db_session.query_property()
# 数据库配置类
class DBConfig:
"""数据库配置类,提供数据库相关的配置和方法"""
config = DB_CONFIG
env = ENV
url = DATABASE_URL
engine = engine
session = db_session
base = Base
@classmethod
def get_config(cls):
"""获取当前环境的数据库配置"""
return getattr(cls.config, cls.env)
@classmethod
def get_session(cls):
"""获取数据库会话"""
return cls.session
@classmethod
def init_db(cls):
"""
初始化数据库
创建所有表
"""
# 测试数据库连接
try:
print(f"尝试连接数据库: {cls.url}")
connection = cls.engine.connect()
print("数据库连接成功!")
connection.close()
except Exception as e:
print(f"数据库连接失败: {str(e)}")
print("详细错误信息:")
traceback.print_exc(file=sys.stdout)
raise
# 导入所有模型确保它们已注册到Base
import data.models
# 首先尝试创建数据库(如果不存在)
if cls.get_config().dialect != 'sqlite':
from sqlalchemy import text
# 创建一个不指定数据库的连接
db_conf = cls.get_config()
try:
print(f"尝试创建数据库 {db_conf.database} (如果不存在)")
temp_url = (
f"{db_conf.dialect}+{db_conf.driver}://"
f"{db_conf.username}:{db_conf.password}@"
f"{db_conf.host}:{db_conf.port}/"
f"?charset={db_conf.charset}"
)
print(f"临时连接URL: {temp_url}")
temp_engine = create_engine(temp_url)
with temp_engine.connect() as conn:
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_conf.database} CHARACTER SET {db_conf.charset} COLLATE {db_conf.charset}_unicode_ci;"))
conn.commit()
temp_engine.dispose()
print(f"数据库 {db_conf.database} 创建或已存在")
except Exception as e:
print(f"创建数据库失败: {str(e)}")
print("详细错误信息:")
traceback.print_exc(file=sys.stdout)
raise
# 创建所有表
try:
print("开始创建所有表...")
cls.base.metadata.create_all(bind=cls.engine)
print("所有表创建成功")
except Exception as e:
print(f"创建表失败: {str(e)}")
print("详细错误信息:")
traceback.print_exc(file=sys.stdout)
raise
@classmethod
def shutdown_session(cls, exception=None):
"""
关闭会话
在应用程序关闭时调用
"""
cls.session.remove()
# 缓存配置类
class CacheConfig:
"""缓存配置类提供Redis缓存相关的配置和方法"""
config = REDIS_CONFIG
env = ENV
_redis_client = None
@classmethod
def get_config(cls):
"""获取当前环境的Redis配置"""
return getattr(cls.config, cls.env)
@classmethod
def get_redis_client(cls):
"""获取Redis客户端实例"""
if cls._redis_client is None:
try:
import redis
redis_conf = cls.get_config()
cls._redis_client = redis.Redis(
host=redis_conf.host,
port=redis_conf.port,
db=redis_conf.db,
password=redis_conf.password,
socket_timeout=getattr(redis_conf, 'socket_timeout', 5),
socket_connect_timeout=getattr(redis_conf, 'socket_connect_timeout', 5),
decode_responses=getattr(redis_conf, 'decode_responses', True)
)
except ImportError:
raise ImportError("Redis package is not installed. Please install it with 'pip install redis'")
except Exception as e:
print(f"Error connecting to Redis: {e}")
return None
return cls._redis_client
@classmethod
def get_key(cls, key):
"""获取带前缀的缓存键"""
prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:')
return f"{prefix}{key}"
@classmethod
def set(cls, key, value, expire=None):
"""
设置缓存
Args:
key (str): 缓存键
value (any): 缓存值非字符串类型会被JSON序列化
expire (int, optional): 过期时间(秒)
Returns:
bool: 是否设置成功
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
if not isinstance(value, (str, int, float, bool)):
value = json.dumps(value)
full_key = cls.get_key(key)
if expire:
return redis_client.setex(full_key, expire, value)
else:
return redis_client.set(full_key, value)
@classmethod
def get(cls, key, default=None):
"""
获取缓存
Args:
key (str): 缓存键
default (any, optional): 默认值
Returns:
any: 缓存值或默认值
"""
redis_client = cls.get_redis_client()
if not redis_client:
return default
full_key = cls.get_key(key)
value = redis_client.get(full_key)
if value is None:
return default
# 尝试解析JSON
try:
if value.startswith('{') or value.startswith('['):
return json.loads(value)
except (json.JSONDecodeError, AttributeError):
pass
return value
@classmethod
def delete(cls, key):
"""
删除缓存
Args:
key (str): 缓存键
Returns:
bool: 是否删除成功
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
full_key = cls.get_key(key)
return redis_client.delete(full_key) > 0
@classmethod
def exists(cls, key):
"""
检查缓存是否存在
Args:
key (str): 缓存键
Returns:
bool: 是否存在
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
full_key = cls.get_key(key)
return redis_client.exists(full_key) > 0
@classmethod
def ttl(cls, key):
"""
获取缓存剩余过期时间
Args:
key (str): 缓存键
Returns:
int: 剩余秒数,-1表示永不过期-2表示不存在
"""
redis_client = cls.get_redis_client()
if not redis_client:
return -2
full_key = cls.get_key(key)
return redis_client.ttl(full_key)
@classmethod
def clear_all(cls):
"""
清除当前环境下的所有缓存
Returns:
bool: 是否清除成功
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:')
keys = redis_client.keys(f"{prefix}*")
if keys:
return redis_client.delete(*keys) > 0
return True
# 兼容旧代码的函数
def init_db():
"""初始化数据库(兼容旧代码)"""
DBConfig.init_db()
def shutdown_session(exception=None):
"""关闭会话(兼容旧代码)"""
DBConfig.shutdown_session(exception)