127 lines
4.0 KiB
Python
127 lines
4.0 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
基础模型模块
|
|||
|
包含所有模型共用的字段和方法
|
|||
|
"""
|
|||
|
|
|||
|
import datetime
|
|||
|
from sqlalchemy import Column, Integer, DateTime, Boolean
|
|||
|
from config.database import Base
|
|||
|
|
|||
|
class BaseModel(Base):
|
|||
|
"""
|
|||
|
基础模型类
|
|||
|
包含所有模型共用的字段和方法
|
|||
|
"""
|
|||
|
__abstract__ = True # 声明为抽象类,不会创建表
|
|||
|
|
|||
|
id = Column(Integer, primary_key=True, autoincrement=True, comment='主键ID')
|
|||
|
created_at = Column(DateTime, default=datetime.datetime.now, comment='创建时间')
|
|||
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment='更新时间')
|
|||
|
is_deleted = Column(Boolean, default=False, comment='是否删除(软删除标记)')
|
|||
|
|
|||
|
def to_dict(self):
|
|||
|
"""
|
|||
|
将模型转换为字典
|
|||
|
用于API响应
|
|||
|
"""
|
|||
|
result = {}
|
|||
|
for column in self.__table__.columns:
|
|||
|
value = getattr(self, column.name)
|
|||
|
# 处理日期时间类型
|
|||
|
if isinstance(value, datetime.datetime):
|
|||
|
value = value.strftime('%Y-%m-%d %H:%M:%S')
|
|||
|
result[column.name] = value
|
|||
|
return result
|
|||
|
|
|||
|
def to_json(self, fields=None, exclude=None, timestamp_format='ms'):
|
|||
|
"""
|
|||
|
将模型转换为JSON友好的字典
|
|||
|
|
|||
|
Args:
|
|||
|
fields (list, optional): 需要包含的字段列表,为None则包含所有字段
|
|||
|
exclude (list, optional): 需要排除的字段列表
|
|||
|
timestamp_format (str, optional): 时间戳格式,可选值:'ms'(毫秒时间戳), 'iso'(ISO格式), 'str'(字符串格式)
|
|||
|
|
|||
|
Returns:
|
|||
|
dict: 包含指定字段的字典
|
|||
|
"""
|
|||
|
result = {}
|
|||
|
|
|||
|
# 获取所有列名
|
|||
|
columns = [column.name for column in self.__table__.columns]
|
|||
|
|
|||
|
# 如果指定了fields,则只包含这些字段
|
|||
|
if fields:
|
|||
|
columns = [col for col in columns if col in fields]
|
|||
|
|
|||
|
# 排除指定的字段
|
|||
|
if exclude:
|
|||
|
columns = [col for col in columns if col not in exclude]
|
|||
|
|
|||
|
# 获取字段值
|
|||
|
for column in columns:
|
|||
|
value = getattr(self, column)
|
|||
|
|
|||
|
# 处理日期时间类型
|
|||
|
if isinstance(value, datetime.datetime):
|
|||
|
if timestamp_format == 'ms':
|
|||
|
# 转换为毫秒时间戳
|
|||
|
value = int(value.timestamp() * 1000)
|
|||
|
elif timestamp_format == 'iso':
|
|||
|
# 转换为ISO格式
|
|||
|
value = value.isoformat()
|
|||
|
else:
|
|||
|
# 转换为字符串格式
|
|||
|
value = value.strftime('%Y-%m-%d %H:%M:%S')
|
|||
|
|
|||
|
# 处理枚举类型
|
|||
|
elif hasattr(value, 'name') and hasattr(value, 'value'):
|
|||
|
# 如果是枚举类型,返回其值
|
|||
|
value = value.value
|
|||
|
|
|||
|
result[column] = value
|
|||
|
|
|||
|
return result
|
|||
|
|
|||
|
@classmethod
|
|||
|
def get_by_id(cls, id):
|
|||
|
"""
|
|||
|
根据ID获取记录
|
|||
|
"""
|
|||
|
return cls.query.filter(cls.id == id, cls.is_deleted == False).first()
|
|||
|
|
|||
|
@classmethod
|
|||
|
def get_all(cls, page=1, per_page=20):
|
|||
|
"""
|
|||
|
获取所有记录(分页)
|
|||
|
"""
|
|||
|
return cls.query.filter(cls.is_deleted == False).paginate(page=page, per_page=per_page)
|
|||
|
|
|||
|
def save(self):
|
|||
|
"""
|
|||
|
保存记录
|
|||
|
"""
|
|||
|
from config.database import db_session
|
|||
|
db_session.add(self)
|
|||
|
db_session.commit()
|
|||
|
return self
|
|||
|
|
|||
|
def delete(self):
|
|||
|
"""
|
|||
|
删除记录(软删除)
|
|||
|
"""
|
|||
|
self.is_deleted = True
|
|||
|
self.save()
|
|||
|
return self
|
|||
|
|
|||
|
def hard_delete(self):
|
|||
|
"""
|
|||
|
硬删除记录
|
|||
|
"""
|
|||
|
from config.database import db_session
|
|||
|
db_session.delete(self)
|
|||
|
db_session.commit()
|
|||
|
return self
|