378 lines
11 KiB
Python
378 lines
11 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
数据库连接配置模块
|
||
包含数据库连接参数和SQLAlchemy配置,以及Redis缓存配置
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||
import traceback
|
||
import sys
|
||
|
||
class ConfigDict:
|
||
"""配置字典类,支持通过点号访问配置项"""
|
||
def __init__(self, **kwargs):
|
||
for key, value in kwargs.items():
|
||
if isinstance(value, dict):
|
||
setattr(self, key, ConfigDict(**value))
|
||
else:
|
||
setattr(self, key, value)
|
||
|
||
def get(self, key, default=None):
|
||
return getattr(self, key, default)
|
||
|
||
def to_dict(self):
|
||
result = {}
|
||
for key, value in self.__dict__.items():
|
||
if isinstance(value, ConfigDict):
|
||
result[key] = value.to_dict()
|
||
else:
|
||
result[key] = value
|
||
return result
|
||
|
||
# 数据库连接配置
|
||
DB_CONFIG = ConfigDict(
|
||
default=dict(
|
||
dialect='mysql',
|
||
driver='pymysql',
|
||
username='root',
|
||
password='root',
|
||
host='localhost',
|
||
port=3306,
|
||
database='tianfeng_task',
|
||
charset='utf8mb4'
|
||
),
|
||
test=dict(
|
||
dialect='sqlite',
|
||
database=':memory:'
|
||
)
|
||
)
|
||
|
||
# Redis缓存配置
|
||
REDIS_CONFIG = ConfigDict(
|
||
default=dict(
|
||
host='localhost',
|
||
port=6379,
|
||
db=0,
|
||
password=None,
|
||
prefix='tianfeng:',
|
||
socket_timeout=5,
|
||
socket_connect_timeout=5,
|
||
decode_responses=True
|
||
),
|
||
test=dict(
|
||
host='localhost',
|
||
port=6379,
|
||
db=1,
|
||
password=None,
|
||
prefix='tianfeng_test:',
|
||
decode_responses=True
|
||
)
|
||
)
|
||
|
||
# 当前环境,可通过环境变量设置
|
||
ENV = os.environ.get('TIANFENG_ENV', 'default')
|
||
|
||
# 根据环境获取数据库配置
|
||
db_conf = getattr(DB_CONFIG, ENV)
|
||
|
||
# 构建数据库连接URL
|
||
if db_conf.dialect == 'sqlite':
|
||
DATABASE_URL = f"{db_conf.dialect}:///{db_conf.database}"
|
||
else:
|
||
DATABASE_URL = (
|
||
f"{db_conf.dialect}+{db_conf.driver}://"
|
||
f"{db_conf.username}:{db_conf.password}@"
|
||
f"{db_conf.host}:{db_conf.port}/{db_conf.database}?"
|
||
f"charset={db_conf.charset}"
|
||
)
|
||
|
||
# 创建数据库引擎
|
||
engine = create_engine(
|
||
DATABASE_URL,
|
||
pool_size=20,
|
||
max_overflow=0,
|
||
pool_recycle=3600,
|
||
pool_pre_ping=True,
|
||
echo=False # 设置为True可以显示SQL语句,用于调试
|
||
)
|
||
|
||
# 创建会话工厂
|
||
SessionFactory = sessionmaker(bind=engine)
|
||
|
||
# 创建线程安全的会话
|
||
db_session = scoped_session(SessionFactory)
|
||
|
||
# 创建基类
|
||
Base = declarative_base()
|
||
Base.query = db_session.query_property()
|
||
|
||
# 数据库配置类
|
||
class DBConfig:
|
||
"""数据库配置类,提供数据库相关的配置和方法"""
|
||
config = DB_CONFIG
|
||
env = ENV
|
||
url = DATABASE_URL
|
||
engine = engine
|
||
session = db_session
|
||
base = Base
|
||
|
||
@classmethod
|
||
def get_config(cls):
|
||
"""获取当前环境的数据库配置"""
|
||
return getattr(cls.config, cls.env)
|
||
|
||
@classmethod
|
||
def get_session(cls):
|
||
"""获取数据库会话"""
|
||
return cls.session
|
||
|
||
@classmethod
|
||
def init_db(cls):
|
||
"""
|
||
初始化数据库
|
||
创建所有表
|
||
"""
|
||
# 测试数据库连接
|
||
try:
|
||
print(f"尝试连接数据库: {cls.url}")
|
||
connection = cls.engine.connect()
|
||
print("数据库连接成功!")
|
||
connection.close()
|
||
except Exception as e:
|
||
print(f"数据库连接失败: {str(e)}")
|
||
print("详细错误信息:")
|
||
traceback.print_exc(file=sys.stdout)
|
||
raise
|
||
|
||
# 导入所有模型,确保它们已注册到Base
|
||
import data.models
|
||
|
||
# 首先尝试创建数据库(如果不存在)
|
||
if cls.get_config().dialect != 'sqlite':
|
||
from sqlalchemy import text
|
||
# 创建一个不指定数据库的连接
|
||
db_conf = cls.get_config()
|
||
try:
|
||
print(f"尝试创建数据库 {db_conf.database} (如果不存在)")
|
||
temp_url = (
|
||
f"{db_conf.dialect}+{db_conf.driver}://"
|
||
f"{db_conf.username}:{db_conf.password}@"
|
||
f"{db_conf.host}:{db_conf.port}/"
|
||
f"?charset={db_conf.charset}"
|
||
)
|
||
print(f"临时连接URL: {temp_url}")
|
||
temp_engine = create_engine(temp_url)
|
||
with temp_engine.connect() as conn:
|
||
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_conf.database} CHARACTER SET {db_conf.charset} COLLATE {db_conf.charset}_unicode_ci;"))
|
||
conn.commit()
|
||
temp_engine.dispose()
|
||
print(f"数据库 {db_conf.database} 创建或已存在")
|
||
except Exception as e:
|
||
print(f"创建数据库失败: {str(e)}")
|
||
print("详细错误信息:")
|
||
traceback.print_exc(file=sys.stdout)
|
||
raise
|
||
|
||
# 创建所有表
|
||
try:
|
||
print("开始创建所有表...")
|
||
cls.base.metadata.create_all(bind=cls.engine)
|
||
print("所有表创建成功")
|
||
except Exception as e:
|
||
print(f"创建表失败: {str(e)}")
|
||
print("详细错误信息:")
|
||
traceback.print_exc(file=sys.stdout)
|
||
raise
|
||
|
||
@classmethod
|
||
def shutdown_session(cls, exception=None):
|
||
"""
|
||
关闭会话
|
||
在应用程序关闭时调用
|
||
"""
|
||
cls.session.remove()
|
||
|
||
# 缓存配置类
|
||
class CacheConfig:
|
||
"""缓存配置类,提供Redis缓存相关的配置和方法"""
|
||
config = REDIS_CONFIG
|
||
env = ENV
|
||
_redis_client = None
|
||
|
||
@classmethod
|
||
def get_config(cls):
|
||
"""获取当前环境的Redis配置"""
|
||
return getattr(cls.config, cls.env)
|
||
|
||
@classmethod
|
||
def get_redis_client(cls):
|
||
"""获取Redis客户端实例"""
|
||
if cls._redis_client is None:
|
||
try:
|
||
import redis
|
||
redis_conf = cls.get_config()
|
||
cls._redis_client = redis.Redis(
|
||
host=redis_conf.host,
|
||
port=redis_conf.port,
|
||
db=redis_conf.db,
|
||
password=redis_conf.password,
|
||
socket_timeout=getattr(redis_conf, 'socket_timeout', 5),
|
||
socket_connect_timeout=getattr(redis_conf, 'socket_connect_timeout', 5),
|
||
decode_responses=getattr(redis_conf, 'decode_responses', True)
|
||
)
|
||
except ImportError:
|
||
raise ImportError("Redis package is not installed. Please install it with 'pip install redis'")
|
||
except Exception as e:
|
||
print(f"Error connecting to Redis: {e}")
|
||
return None
|
||
return cls._redis_client
|
||
|
||
@classmethod
|
||
def get_key(cls, key):
|
||
"""获取带前缀的缓存键"""
|
||
prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:')
|
||
return f"{prefix}{key}"
|
||
|
||
@classmethod
|
||
def set(cls, key, value, expire=None):
|
||
"""
|
||
设置缓存
|
||
|
||
Args:
|
||
key (str): 缓存键
|
||
value (any): 缓存值,非字符串类型会被JSON序列化
|
||
expire (int, optional): 过期时间(秒)
|
||
|
||
Returns:
|
||
bool: 是否设置成功
|
||
"""
|
||
redis_client = cls.get_redis_client()
|
||
if not redis_client:
|
||
return False
|
||
|
||
if not isinstance(value, (str, int, float, bool)):
|
||
value = json.dumps(value)
|
||
|
||
full_key = cls.get_key(key)
|
||
if expire:
|
||
return redis_client.setex(full_key, expire, value)
|
||
else:
|
||
return redis_client.set(full_key, value)
|
||
|
||
@classmethod
|
||
def get(cls, key, default=None):
|
||
"""
|
||
获取缓存
|
||
|
||
Args:
|
||
key (str): 缓存键
|
||
default (any, optional): 默认值
|
||
|
||
Returns:
|
||
any: 缓存值或默认值
|
||
"""
|
||
redis_client = cls.get_redis_client()
|
||
if not redis_client:
|
||
return default
|
||
|
||
full_key = cls.get_key(key)
|
||
value = redis_client.get(full_key)
|
||
|
||
if value is None:
|
||
return default
|
||
|
||
# 尝试解析JSON
|
||
try:
|
||
if value.startswith('{') or value.startswith('['):
|
||
return json.loads(value)
|
||
except (json.JSONDecodeError, AttributeError):
|
||
pass
|
||
|
||
return value
|
||
|
||
@classmethod
|
||
def delete(cls, key):
|
||
"""
|
||
删除缓存
|
||
|
||
Args:
|
||
key (str): 缓存键
|
||
|
||
Returns:
|
||
bool: 是否删除成功
|
||
"""
|
||
redis_client = cls.get_redis_client()
|
||
if not redis_client:
|
||
return False
|
||
|
||
full_key = cls.get_key(key)
|
||
return redis_client.delete(full_key) > 0
|
||
|
||
@classmethod
|
||
def exists(cls, key):
|
||
"""
|
||
检查缓存是否存在
|
||
|
||
Args:
|
||
key (str): 缓存键
|
||
|
||
Returns:
|
||
bool: 是否存在
|
||
"""
|
||
redis_client = cls.get_redis_client()
|
||
if not redis_client:
|
||
return False
|
||
|
||
full_key = cls.get_key(key)
|
||
return redis_client.exists(full_key) > 0
|
||
|
||
@classmethod
|
||
def ttl(cls, key):
|
||
"""
|
||
获取缓存剩余过期时间
|
||
|
||
Args:
|
||
key (str): 缓存键
|
||
|
||
Returns:
|
||
int: 剩余秒数,-1表示永不过期,-2表示不存在
|
||
"""
|
||
redis_client = cls.get_redis_client()
|
||
if not redis_client:
|
||
return -2
|
||
|
||
full_key = cls.get_key(key)
|
||
return redis_client.ttl(full_key)
|
||
|
||
@classmethod
|
||
def clear_all(cls):
|
||
"""
|
||
清除当前环境下的所有缓存
|
||
|
||
Returns:
|
||
bool: 是否清除成功
|
||
"""
|
||
redis_client = cls.get_redis_client()
|
||
if not redis_client:
|
||
return False
|
||
|
||
prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:')
|
||
keys = redis_client.keys(f"{prefix}*")
|
||
if keys:
|
||
return redis_client.delete(*keys) > 0
|
||
return True
|
||
|
||
# 兼容旧代码的函数
|
||
def init_db():
|
||
"""初始化数据库(兼容旧代码)"""
|
||
DBConfig.init_db()
|
||
|
||
def shutdown_session(exception=None):
|
||
"""关闭会话(兼容旧代码)"""
|
||
DBConfig.shutdown_session(exception) |