feat: LlmDemoModel示例数据增删改查
This commit is contained in:
+18
-12
@@ -4,31 +4,37 @@ from decimal import Decimal
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import FastAPI, APIRouter, HTTPException
|
from fastapi import FastAPI, APIRouter, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlmodel import SQLModel, Field, select
|
from sqlmodel import SQLModel, Field, select
|
||||||
|
|
||||||
|
from app.utils.model_utils import to_camel, format_datetime_to_string, format_date_to_string, FormattedDatetime, FormattedDate, current_datetime
|
||||||
from app.utils.mysql_utils import AsyncSessionDep
|
from app.utils.mysql_utils import AsyncSessionDep
|
||||||
|
|
||||||
# 定义北京时区(UTC+8)
|
|
||||||
beijing_timezone = timezone(timedelta(hours=8))
|
|
||||||
|
|
||||||
# 定义获取当前北京时区时间的匿名函数,用于默认值生成
|
|
||||||
current_datetime = lambda: datetime.now(beijing_timezone)
|
|
||||||
|
|
||||||
|
|
||||||
class LlmDemoModel(SQLModel, table=True):
|
class LlmDemoModel(SQLModel, table=True):
|
||||||
|
# Pydantic V2 的模型配置
|
||||||
|
model_config = ConfigDict(
|
||||||
|
alias_generator=to_camel, # 使用 to_camel 函数生成别名
|
||||||
|
populate_by_name=True, # 允许通过原始字段名(snake_case)赋值
|
||||||
|
extra='ignore', # 忽略模型中未定义的额外字段,避免验证失败
|
||||||
|
json_encoders={
|
||||||
|
datetime: format_datetime_to_string, # 为 datetime 类型指定自定义的 JSON 编码器
|
||||||
|
date: format_date_to_string # 为 date 类型指定自定义的 JSON 编码器
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
__tablename__ = "llm_demo"
|
__tablename__ = "llm_demo"
|
||||||
|
|
||||||
id: str = Field(default=None, primary_key=True, description="唯一标识,编号")
|
id: str = Field(default=None, primary_key=True, description="唯一标识,编号")
|
||||||
created_at: datetime = Field(default_factory=current_datetime, description="创建时间")
|
created_at: FormattedDatetime = Field(default_factory=current_datetime, description="创建时间")
|
||||||
updated_at: datetime = Field(default_factory=current_datetime, description="更新时间")
|
updated_at: FormattedDatetime = Field(default_factory=current_datetime, description="更新时间")
|
||||||
created_by: str | None = Field(default=None, description="创建人id")
|
created_by: str | None = Field(default=None, description="创建人id")
|
||||||
updated_by: str | None = Field(default=None, description="更新人id")
|
updated_by: str | None = Field(default=None, description="更新人id")
|
||||||
|
|
||||||
full_name: str = Field(default=None, description="用户名称")
|
full_name: str = Field(default=None, description="用户名称")
|
||||||
datetime_start: datetime = Field(default=None, description="开通会员时间")
|
datetime_start: FormattedDatetime = Field(default=None, description="开通会员时间")
|
||||||
datetime_end: datetime = Field(default=None, description="会员截止到期时间")
|
datetime_end: FormattedDatetime = Field(default=None, description="会员截止到期时间")
|
||||||
birthday: date = Field(default=None, description="生日")
|
birthday: FormattedDate = Field(default=None, description="生日")
|
||||||
amount: Decimal = Field(default=0, description="金额")
|
amount: Decimal = Field(default=0, description="金额")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,118 +0,0 @@
|
|||||||
import uuid
|
|
||||||
from datetime import datetime, timezone, timedelta, date
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from fastapi import FastAPI, APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import SQLModel, Field, select
|
|
||||||
|
|
||||||
from app.utils.mysql_utils import AsyncSessionDep
|
|
||||||
|
|
||||||
# 定义北京时区(UTC+8)
|
|
||||||
beijing_timezone = timezone(timedelta(hours=8))
|
|
||||||
|
|
||||||
# 定义获取当前北京时区时间的匿名函数,用于默认值生成
|
|
||||||
current_datetime = lambda: datetime.now(beijing_timezone)
|
|
||||||
|
|
||||||
|
|
||||||
class LlmUserModel(SQLModel, table=True):
|
|
||||||
__tablename__ = "llm_user"
|
|
||||||
|
|
||||||
# 唯一标识字段,主键,默认为None(通常由系统生成),描述为“唯一标识,编号”
|
|
||||||
id: str = Field(default=None, primary_key=True, description="唯一标识,编号")
|
|
||||||
# 创建时间字段,默认值为当前北京时区时间,描述为“创建时间”
|
|
||||||
created_at: datetime = Field(default_factory=current_datetime, description="创建时间")
|
|
||||||
# 更新时间字段,默认值为当前北京时区时间,描述为“更新时间”
|
|
||||||
updated_at: datetime = Field(default_factory=current_datetime, description="更新时间")
|
|
||||||
# 创建人ID字段,默认为None,描述为“创建人id”
|
|
||||||
created_by: str | None = Field(default=None, description="创建人id")
|
|
||||||
# 更新人ID字段,默认为None,描述为“更新人id”
|
|
||||||
updated_by: str | None = Field(default=None, description="更新人id")
|
|
||||||
|
|
||||||
full_name: str = Field(default=None, description="用户名称")
|
|
||||||
username: str = Field(default=None, description="用户名")
|
|
||||||
password: str = Field(default=None, description="用户密码")
|
|
||||||
member_start: datetime = Field(default=None, description="开通会员时间")
|
|
||||||
member_end: datetime = Field(default=None, description="会员截止到期时间")
|
|
||||||
|
|
||||||
|
|
||||||
def add_llm_user_route(app: FastAPI):
|
|
||||||
route_prefix = "/llm_user"
|
|
||||||
router = APIRouter(prefix=route_prefix, tags=[route_prefix])
|
|
||||||
|
|
||||||
@router.post('/list')
|
|
||||||
async def _list(body: ModelQuerySchema, session: AsyncSessionDep):
|
|
||||||
"""
|
|
||||||
获取所有用户列表
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 构建SQL查询执行对象
|
|
||||||
query = select(LlmUserModel)
|
|
||||||
# 计算偏移量(跳过前N条),并查询比一页多1条的记录(用于判断是否有下一页)
|
|
||||||
query = query.offset(body.page * body.page_size).limit(body.page_size + 1)
|
|
||||||
result = await session.execute(query)
|
|
||||||
query_cls_list: List[LlmUserModel] = result.scalars().all()
|
|
||||||
has_next = len(query_cls_list) == body.page_size + 1
|
|
||||||
# 若有下一页,移除多查询的那一条记录
|
|
||||||
if has_next:
|
|
||||||
query_cls_list.pop()
|
|
||||||
return {
|
|
||||||
"list": query_cls_list,
|
|
||||||
"has_next": has_next
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.post('/insert')
|
|
||||||
async def _insert(body: LlmUserModel, session: AsyncSessionDep):
|
|
||||||
"""
|
|
||||||
插入用户
|
|
||||||
"""
|
|
||||||
if not body.id:
|
|
||||||
body.id = str(uuid.uuid4())
|
|
||||||
session.add(body)
|
|
||||||
await session.commit()
|
|
||||||
return body
|
|
||||||
|
|
||||||
@router.post('/update')
|
|
||||||
async def _update(body: LlmUserModel, session: AsyncSessionDep):
|
|
||||||
"""
|
|
||||||
更新用户
|
|
||||||
"""
|
|
||||||
query = select(LlmUserModel).where(LlmUserModel.id == body.id)
|
|
||||||
result = await session.execute(query)
|
|
||||||
query_cls = result.scalars().first()
|
|
||||||
if not query_cls:
|
|
||||||
raise HTTPException(status_code=500, detail=f"Update row with id:{body.id} not found")
|
|
||||||
|
|
||||||
update_data_exclude_true = body.model_dump(exclude_unset=True, exclude={'id'})
|
|
||||||
update_data_exclude_false = body.model_dump(exclude={'id'})
|
|
||||||
|
|
||||||
print(f"update_data_exclude_true: {update_data_exclude_true}")
|
|
||||||
print(f"update_data_exclude_false: {update_data_exclude_false}")
|
|
||||||
|
|
||||||
update_data = body.model_dump(exclude_unset=True, exclude={'id'})
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(query_cls, field, value)
|
|
||||||
session.add(query_cls)
|
|
||||||
await session.commit()
|
|
||||||
return query_cls
|
|
||||||
|
|
||||||
@router.post('/delete')
|
|
||||||
async def _delete(body: LlmUserModel, session: AsyncSessionDep):
|
|
||||||
"""
|
|
||||||
更新用户
|
|
||||||
"""
|
|
||||||
query = select(LlmUserModel).where(LlmUserModel.id == body.id)
|
|
||||||
result = await session.execute(query)
|
|
||||||
query_cls = result.scalars().first()
|
|
||||||
if not query_cls:
|
|
||||||
return {"affect_row_count": 0}
|
|
||||||
await session.delete(query_cls)
|
|
||||||
await session.commit()
|
|
||||||
return {"affect_row_count": 1}
|
|
||||||
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelQuerySchema(BaseModel):
|
|
||||||
page: int = Field(default=0, description="页码")
|
|
||||||
page_size: int = Field(default=5, description="每页数量")
|
|
||||||
@@ -3,7 +3,6 @@ from app.controller.add_docs_route import add_docs_route
|
|||||||
from app.controller.add_graph_proxy_route import add_graph_proxy_route
|
from app.controller.add_graph_proxy_route import add_graph_proxy_route
|
||||||
from app.controller.add_test_route import add_test_route
|
from app.controller.add_test_route import add_test_route
|
||||||
from app.model.LlmDemoMdel import add_llm_demo_route
|
from app.model.LlmDemoMdel import add_llm_demo_route
|
||||||
from app.model.LlmUserModel import add_llm_user_route
|
|
||||||
|
|
||||||
# /*@formatter:off*/
|
# /*@formatter:off*/
|
||||||
routes = [
|
routes = [
|
||||||
|
|||||||
@@ -0,0 +1,102 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone, timedelta, date
|
||||||
|
from typing import List, Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI, APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel, ConfigDict, BeforeValidator
|
||||||
|
from sqlmodel import SQLModel, Field, select
|
||||||
|
|
||||||
|
from app.utils.mysql_utils import AsyncSessionDep
|
||||||
|
|
||||||
|
# 定义北京时区(UTC+8)
|
||||||
|
beijing_timezone = timezone(timedelta(hours=8))
|
||||||
|
|
||||||
|
# 定义获取当前北京时区时间的匿名函数,用于默认值生成
|
||||||
|
current_datetime = lambda: datetime.now(beijing_timezone)
|
||||||
|
|
||||||
|
# /*---------------------------------------datetime-------------------------------------------*/
|
||||||
|
|
||||||
|
# 日期时间格式
|
||||||
|
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数:将 datetime 格式化为字符串
|
||||||
|
def format_datetime_to_string(dt: datetime) -> str:
|
||||||
|
if isinstance(dt, str):
|
||||||
|
return dt
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
# 确保 datetime 对象是带时区的,如果不是则假设为北京时间
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = dt.replace(tzinfo=beijing_timezone)
|
||||||
|
return dt.astimezone(beijing_timezone).strftime(DATETIME_FORMAT)
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数:将字符串解析为 datetime
|
||||||
|
def parse_datetime_from_string(dt_str: Any) -> datetime | None:
|
||||||
|
if dt_str is None or dt_str == "":
|
||||||
|
return None
|
||||||
|
if isinstance(dt_str, datetime): # 如果已经是datetime对象,直接返回
|
||||||
|
return dt_str
|
||||||
|
try:
|
||||||
|
# 尝试解析,并明确设置为北京时区
|
||||||
|
return datetime.strptime(str(dt_str), DATETIME_FORMAT).replace(tzinfo=beijing_timezone)
|
||||||
|
except ValueError:
|
||||||
|
# 如果解析失败,Pydantic 会处理验证错误
|
||||||
|
raise ValueError(f"Invalid datetime format. Expected '{DATETIME_FORMAT}'")
|
||||||
|
|
||||||
|
|
||||||
|
# 定义一个 Annotated 类型,用于在 Pydantic 字段中应用解析器
|
||||||
|
# 当从输入数据(如JSON字符串)转换为 datetime 对象时,会先经过这个解析器
|
||||||
|
# 这里使用 BeforeValidator,因为它在 Pydantic 自己的验证之前运行
|
||||||
|
# 对于 SQLModel (基于Pydantic),这在从数据库加载数据或从请求体解析数据时都适用
|
||||||
|
FormattedDatetime = Annotated[
|
||||||
|
datetime,
|
||||||
|
BeforeValidator(parse_datetime_from_string)
|
||||||
|
]
|
||||||
|
|
||||||
|
# /*---------------------------------------date-------------------------------------------*/
|
||||||
|
|
||||||
|
# 日期格式
|
||||||
|
DATE_FORMAT = "%Y-%m-%d"
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数:将 date 格式化为字符串
|
||||||
|
def format_date_to_string(d: date) -> str:
|
||||||
|
if isinstance(d, str):
|
||||||
|
return d
|
||||||
|
if d is None:
|
||||||
|
return None
|
||||||
|
return d.strftime(DATE_FORMAT)
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数:将字符串解析为 date
|
||||||
|
def parse_date_from_string(d_str: Any) -> date | None:
|
||||||
|
if d_str is None or d_str == "":
|
||||||
|
return None
|
||||||
|
if isinstance(d_str, date): # 如果已经是 date 对象,直接返回
|
||||||
|
return d_str
|
||||||
|
try:
|
||||||
|
# 尝试解析
|
||||||
|
return datetime.strptime(str(d_str), DATE_FORMAT).date()
|
||||||
|
except ValueError:
|
||||||
|
# 如果解析失败,Pydantic 会处理验证错误
|
||||||
|
raise ValueError(f"Invalid date format. Expected '{DATE_FORMAT}'")
|
||||||
|
|
||||||
|
|
||||||
|
# 定义一个 Annotated 类型,用于在 Pydantic 字段中应用解析器
|
||||||
|
# 当从输入数据(如 JSON 字符串)转换为 date 对象时,会先经过这个解析器
|
||||||
|
FormattedDate = Annotated[
|
||||||
|
date,
|
||||||
|
BeforeValidator(parse_date_from_string)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# /*---------------------------------------other-------------------------------------------*/
|
||||||
|
|
||||||
|
def to_camel(snake_str: str) -> str:
|
||||||
|
"""Converts a snake_case string to camelCase."""
|
||||||
|
components = snake_str.split('_')
|
||||||
|
# We capitalize the first letter of each component except the first one
|
||||||
|
# and join them to form the camelCase string.
|
||||||
|
return components[0] + ''.join(x.title() for x in components[1:])
|
||||||
Reference in New Issue
Block a user