127 lines
4.0 KiB
Python
Raw Permalink Normal View History

2025-03-17 14:58:05 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
基础模型模块
包含所有模型共用的字段和方法
"""
import datetime
from sqlalchemy import Column, Integer, DateTime, Boolean
from config.database import Base
class BaseModel(Base):
"""
基础模型类
包含所有模型共用的字段和方法
"""
__abstract__ = True # 声明为抽象类,不会创建表
id = Column(Integer, primary_key=True, autoincrement=True, comment='主键ID')
created_at = Column(DateTime, default=datetime.datetime.now, comment='创建时间')
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment='更新时间')
is_deleted = Column(Boolean, default=False, comment='是否删除(软删除标记)')
def to_dict(self):
"""
将模型转换为字典
用于API响应
"""
result = {}
for column in self.__table__.columns:
value = getattr(self, column.name)
# 处理日期时间类型
if isinstance(value, datetime.datetime):
value = value.strftime('%Y-%m-%d %H:%M:%S')
result[column.name] = value
return result
def to_json(self, fields=None, exclude=None, timestamp_format='ms'):
"""
将模型转换为JSON友好的字典
Args:
fields (list, optional): 需要包含的字段列表为None则包含所有字段
exclude (list, optional): 需要排除的字段列表
timestamp_format (str, optional): 时间戳格式可选值'ms'毫秒时间戳, 'iso'ISO格式, 'str'字符串格式
Returns:
dict: 包含指定字段的字典
"""
result = {}
# 获取所有列名
columns = [column.name for column in self.__table__.columns]
# 如果指定了fields则只包含这些字段
if fields:
columns = [col for col in columns if col in fields]
# 排除指定的字段
if exclude:
columns = [col for col in columns if col not in exclude]
# 获取字段值
for column in columns:
value = getattr(self, column)
# 处理日期时间类型
if isinstance(value, datetime.datetime):
if timestamp_format == 'ms':
# 转换为毫秒时间戳
value = int(value.timestamp() * 1000)
elif timestamp_format == 'iso':
# 转换为ISO格式
value = value.isoformat()
else:
# 转换为字符串格式
value = value.strftime('%Y-%m-%d %H:%M:%S')
# 处理枚举类型
elif hasattr(value, 'name') and hasattr(value, 'value'):
# 如果是枚举类型,返回其值
value = value.value
result[column] = value
return result
@classmethod
def get_by_id(cls, id):
"""
根据ID获取记录
"""
return cls.query.filter(cls.id == id, cls.is_deleted == False).first()
@classmethod
def get_all(cls, page=1, per_page=20):
"""
获取所有记录分页
"""
return cls.query.filter(cls.is_deleted == False).paginate(page=page, per_page=per_page)
def save(self):
"""
保存记录
"""
from config.database import db_session
db_session.add(self)
db_session.commit()
return self
def delete(self):
"""
删除记录软删除
"""
self.is_deleted = True
self.save()
return self
def hard_delete(self):
"""
硬删除记录
"""
from config.database import db_session
db_session.delete(self)
db_session.commit()
return self