From b2f04bc0f4732e6448c404c93cdb67b099664c90 Mon Sep 17 00:00:00 2001 From: martsforever Date: Wed, 1 Apr 2026 21:55:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20LlmDemoModel=E7=A4=BA=E4=BE=8B=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=A2=9E=E5=88=A0=E6=94=B9=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/model/LlmDemoMdel.py | 30 ++++++---- app/model/LlmUserModel.py | 118 -------------------------------------- app/routes.py | 1 - app/utils/model_utils.py | 102 ++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 131 deletions(-) delete mode 100644 app/model/LlmUserModel.py create mode 100644 app/utils/model_utils.py diff --git a/app/model/LlmDemoMdel.py b/app/model/LlmDemoMdel.py index dc9d583..8874c0d 100644 --- a/app/model/LlmDemoMdel.py +++ b/app/model/LlmDemoMdel.py @@ -4,31 +4,37 @@ from decimal import Decimal from typing import List from fastapi import FastAPI, APIRouter, HTTPException -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from sqlmodel import SQLModel, Field, select +from app.utils.model_utils import to_camel, format_datetime_to_string, format_date_to_string, FormattedDatetime, FormattedDate, current_datetime from app.utils.mysql_utils import AsyncSessionDep -# 定义北京时区(UTC+8) -beijing_timezone = timezone(timedelta(hours=8)) - -# 定义获取当前北京时区时间的匿名函数,用于默认值生成 -current_datetime = lambda: datetime.now(beijing_timezone) - class LlmDemoModel(SQLModel, table=True): + # Pydantic V2 的模型配置 + model_config = ConfigDict( + alias_generator=to_camel, # 使用 to_camel 函数生成别名 + populate_by_name=True, # 允许通过原始字段名(snake_case)赋值 + extra='ignore', # 忽略模型中未定义的额外字段,避免验证失败 + json_encoders={ + datetime: format_datetime_to_string, # 为 datetime 类型指定自定义的 JSON 编码器 + date: format_date_to_string # 为 date 类型指定自定义的 JSON 编码器 + } + ) + __tablename__ = "llm_demo" id: str = Field(default=None, primary_key=True, description="唯一标识,编号") - created_at: datetime = Field(default_factory=current_datetime, description="创建时间") - updated_at: datetime = Field(default_factory=current_datetime, description="更新时间") + created_at: FormattedDatetime = Field(default_factory=current_datetime, description="创建时间") + updated_at: FormattedDatetime = Field(default_factory=current_datetime, description="更新时间") created_by: str | None = Field(default=None, description="创建人id") updated_by: str | None = Field(default=None, description="更新人id") full_name: str = Field(default=None, description="用户名称") - datetime_start: datetime = Field(default=None, description="开通会员时间") - datetime_end: datetime = Field(default=None, description="会员截止到期时间") - birthday: date = Field(default=None, description="生日") + datetime_start: FormattedDatetime = Field(default=None, description="开通会员时间") + datetime_end: FormattedDatetime = Field(default=None, description="会员截止到期时间") + birthday: FormattedDate = Field(default=None, description="生日") amount: Decimal = Field(default=0, description="金额") diff --git a/app/model/LlmUserModel.py b/app/model/LlmUserModel.py deleted file mode 100644 index 22b9b7c..0000000 --- a/app/model/LlmUserModel.py +++ /dev/null @@ -1,118 +0,0 @@ -import uuid -from datetime import datetime, timezone, timedelta, date -from typing import List - -from fastapi import FastAPI, APIRouter, HTTPException -from pydantic import BaseModel -from sqlmodel import SQLModel, Field, select - -from app.utils.mysql_utils import AsyncSessionDep - -# 定义北京时区(UTC+8) -beijing_timezone = timezone(timedelta(hours=8)) - -# 定义获取当前北京时区时间的匿名函数,用于默认值生成 -current_datetime = lambda: datetime.now(beijing_timezone) - - -class LlmUserModel(SQLModel, table=True): - __tablename__ = "llm_user" - - # 唯一标识字段,主键,默认为None(通常由系统生成),描述为“唯一标识,编号” - id: str = Field(default=None, primary_key=True, description="唯一标识,编号") - # 创建时间字段,默认值为当前北京时区时间,描述为“创建时间” - created_at: datetime = Field(default_factory=current_datetime, description="创建时间") - # 更新时间字段,默认值为当前北京时区时间,描述为“更新时间” - updated_at: datetime = Field(default_factory=current_datetime, description="更新时间") - # 创建人ID字段,默认为None,描述为“创建人id” - created_by: str | None = Field(default=None, description="创建人id") - # 更新人ID字段,默认为None,描述为“更新人id” - updated_by: str | None = Field(default=None, description="更新人id") - - full_name: str = Field(default=None, description="用户名称") - username: str = Field(default=None, description="用户名") - password: str = Field(default=None, description="用户密码") - member_start: datetime = Field(default=None, description="开通会员时间") - member_end: datetime = Field(default=None, description="会员截止到期时间") - - -def add_llm_user_route(app: FastAPI): - route_prefix = "/llm_user" - router = APIRouter(prefix=route_prefix, tags=[route_prefix]) - - @router.post('/list') - async def _list(body: ModelQuerySchema, session: AsyncSessionDep): - """ - 获取所有用户列表 - """ - - # 构建SQL查询执行对象 - query = select(LlmUserModel) - # 计算偏移量(跳过前N条),并查询比一页多1条的记录(用于判断是否有下一页) - query = query.offset(body.page * body.page_size).limit(body.page_size + 1) - result = await session.execute(query) - query_cls_list: List[LlmUserModel] = result.scalars().all() - has_next = len(query_cls_list) == body.page_size + 1 - # 若有下一页,移除多查询的那一条记录 - if has_next: - query_cls_list.pop() - return { - "list": query_cls_list, - "has_next": has_next - } - - @router.post('/insert') - async def _insert(body: LlmUserModel, session: AsyncSessionDep): - """ - 插入用户 - """ - if not body.id: - body.id = str(uuid.uuid4()) - session.add(body) - await session.commit() - return body - - @router.post('/update') - async def _update(body: LlmUserModel, session: AsyncSessionDep): - """ - 更新用户 - """ - query = select(LlmUserModel).where(LlmUserModel.id == body.id) - result = await session.execute(query) - query_cls = result.scalars().first() - if not query_cls: - raise HTTPException(status_code=500, detail=f"Update row with id:{body.id} not found") - - update_data_exclude_true = body.model_dump(exclude_unset=True, exclude={'id'}) - update_data_exclude_false = body.model_dump(exclude={'id'}) - - print(f"update_data_exclude_true: {update_data_exclude_true}") - print(f"update_data_exclude_false: {update_data_exclude_false}") - - update_data = body.model_dump(exclude_unset=True, exclude={'id'}) - for field, value in update_data.items(): - setattr(query_cls, field, value) - session.add(query_cls) - await session.commit() - return query_cls - - @router.post('/delete') - async def _delete(body: LlmUserModel, session: AsyncSessionDep): - """ - 更新用户 - """ - query = select(LlmUserModel).where(LlmUserModel.id == body.id) - result = await session.execute(query) - query_cls = result.scalars().first() - if not query_cls: - return {"affect_row_count": 0} - await session.delete(query_cls) - await session.commit() - return {"affect_row_count": 1} - - app.include_router(router) - - -class ModelQuerySchema(BaseModel): - page: int = Field(default=0, description="页码") - page_size: int = Field(default=5, description="每页数量") diff --git a/app/routes.py b/app/routes.py index 0920fce..eeaa384 100644 --- a/app/routes.py +++ b/app/routes.py @@ -3,7 +3,6 @@ from app.controller.add_docs_route import add_docs_route from app.controller.add_graph_proxy_route import add_graph_proxy_route from app.controller.add_test_route import add_test_route from app.model.LlmDemoMdel import add_llm_demo_route -from app.model.LlmUserModel import add_llm_user_route # /*@formatter:off*/ routes = [ diff --git a/app/utils/model_utils.py b/app/utils/model_utils.py new file mode 100644 index 0000000..15341eb --- /dev/null +++ b/app/utils/model_utils.py @@ -0,0 +1,102 @@ +import uuid +from datetime import datetime, timezone, timedelta, date +from typing import List, Annotated, Any + +from fastapi import FastAPI, APIRouter, HTTPException +from pydantic import BaseModel, ConfigDict, BeforeValidator +from sqlmodel import SQLModel, Field, select + +from app.utils.mysql_utils import AsyncSessionDep + +# 定义北京时区(UTC+8) +beijing_timezone = timezone(timedelta(hours=8)) + +# 定义获取当前北京时区时间的匿名函数,用于默认值生成 +current_datetime = lambda: datetime.now(beijing_timezone) + +# /*---------------------------------------datetime-------------------------------------------*/ + +# 日期时间格式 +DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" + + +# 辅助函数:将 datetime 格式化为字符串 +def format_datetime_to_string(dt: datetime) -> str: + if isinstance(dt, str): + return dt + if dt is None: + return None + # 确保 datetime 对象是带时区的,如果不是则假设为北京时间 + if dt.tzinfo is None: + dt = dt.replace(tzinfo=beijing_timezone) + return dt.astimezone(beijing_timezone).strftime(DATETIME_FORMAT) + + +# 辅助函数:将字符串解析为 datetime +def parse_datetime_from_string(dt_str: Any) -> datetime | None: + if dt_str is None or dt_str == "": + return None + if isinstance(dt_str, datetime): # 如果已经是datetime对象,直接返回 + return dt_str + try: + # 尝试解析,并明确设置为北京时区 + return datetime.strptime(str(dt_str), DATETIME_FORMAT).replace(tzinfo=beijing_timezone) + except ValueError: + # 如果解析失败,Pydantic 会处理验证错误 + raise ValueError(f"Invalid datetime format. Expected '{DATETIME_FORMAT}'") + + +# 定义一个 Annotated 类型,用于在 Pydantic 字段中应用解析器 +# 当从输入数据(如JSON字符串)转换为 datetime 对象时,会先经过这个解析器 +# 这里使用 BeforeValidator,因为它在 Pydantic 自己的验证之前运行 +# 对于 SQLModel (基于Pydantic),这在从数据库加载数据或从请求体解析数据时都适用 +FormattedDatetime = Annotated[ + datetime, + BeforeValidator(parse_datetime_from_string) +] + +# /*---------------------------------------date-------------------------------------------*/ + +# 日期格式 +DATE_FORMAT = "%Y-%m-%d" + + +# 辅助函数:将 date 格式化为字符串 +def format_date_to_string(d: date) -> str: + if isinstance(d, str): + return d + if d is None: + return None + return d.strftime(DATE_FORMAT) + + +# 辅助函数:将字符串解析为 date +def parse_date_from_string(d_str: Any) -> date | None: + if d_str is None or d_str == "": + return None + if isinstance(d_str, date): # 如果已经是 date 对象,直接返回 + return d_str + try: + # 尝试解析 + return datetime.strptime(str(d_str), DATE_FORMAT).date() + except ValueError: + # 如果解析失败,Pydantic 会处理验证错误 + raise ValueError(f"Invalid date format. Expected '{DATE_FORMAT}'") + + +# 定义一个 Annotated 类型,用于在 Pydantic 字段中应用解析器 +# 当从输入数据(如 JSON 字符串)转换为 date 对象时,会先经过这个解析器 +FormattedDate = Annotated[ + date, + BeforeValidator(parse_date_from_string) +] + + +# /*---------------------------------------other-------------------------------------------*/ + +def to_camel(snake_str: str) -> str: + """Converts a snake_case string to camelCase.""" + components = snake_str.split('_') + # We capitalize the first letter of each component except the first one + # and join them to form the camelCase string. + return components[0] + ''.join(x.title() for x in components[1:])