diff --git a/app/controller/add_file_route.py b/app/controller/add_file_route.py new file mode 100644 index 0000000..92fa2ed --- /dev/null +++ b/app/controller/add_file_route.py @@ -0,0 +1,102 @@ +import asyncio +import os.path +from typing import List + +from fastapi import FastAPI, UploadFile, Form + +from app.config.env import env +from app.model.FileModel import FileSaveService +from app.utils.mysql_utils import AsyncSessionDep +from app.utils.path_join import path_join + + +def add_file_route(app: FastAPI): + async def _save_file( + session: AsyncSessionDep, + file: UploadFile, + body: dict, + ): + filename = file.filename + return await FileSaveService.saveFile( + session=session, + file=file, + filename=filename, + id=body.get('id', None), + file_record=body, + ) + + async def _save_file_list( + session: AsyncSessionDep, + file_list: List[UploadFile], + body: dict, + ): + async_task_list = [asyncio.create_task( + FileSaveService.saveFile( + session=session, + file=item, + filename=item.filename, + file_record=body, + ) + ) for item in file_list] + result_list = await asyncio.gather(*async_task_list) + return {"result": result_list} + + async def _delete_file(row_dict: dict): + file_public_path = row_dict.get('path') + path_list: List[str] = file_public_path.split('/') + original_name = path_list.pop() + file_id = path_list.pop() + save_file_path = path_join(env.file_save_path, file_id, original_name) + + # 删除文件 + try: + os.remove(save_file_path) + except FileNotFoundError: + print("文件不存在:" + save_file_path) + + # 删除文件夹 + save_dir_path = path_join(env.file_save_path, file_id) + try: + os.rmdir(save_dir_path) + except FileNotFoundError: + print("文件夹不存在:" + save_dir_path) + + return {"result": True} + + # 上传文件接口,文件会持久化保存 + @app.post('/save_file') + async def save_file( + file: UploadFile, + session: AsyncSessionDep, + head_id: str = Form(default=None, description="父对象id"), + attr1: str = Form(default=None, description="扩展属性1"), + attr2: str = Form(default=None, description="扩展属性2"), + attr3: str = Form(default=None, description="扩展属性3") + ): + return await _save_file(session, file, { + "head_id": head_id, + "attr1": attr1, + "attr2": attr2, + "attr3": attr3 + }) + + # 上传文件接口,文件不会持久化保存,仅用于临时保存,只是为了验证 + @app.post('/upload_file') + async def save_file( + file: UploadFile, + session: AsyncSessionDep, + head_id: str = Form(default=None, description="父对象id"), + attr1: str = Form(default=None, description="扩展属性1"), + attr2: str = Form(default=None, description="扩展属性2"), + attr3: str = Form(default=None, description="扩展属性3") + ): + result = await _save_file(session, file, { + "head_id": head_id, + "attr1": attr1, + "attr2": attr2, + "attr3": attr3 + }) + if 'result' in result: + print(f"upload_file:自动删除文件「{result['result']['path']}」") + await _delete_file(result['result']) + return result diff --git a/app/model/FileModel.py b/app/model/FileModel.py new file mode 100644 index 0000000..cb6ef4e --- /dev/null +++ b/app/model/FileModel.py @@ -0,0 +1,114 @@ +import datetime +from pathlib import Path +from typing import Optional, Callable, Awaitable + +import aiofiles +from anyio._core._fileio import ReadableBuffer +from fastapi import UploadFile +from sqlmodel import Field + +from app.config.env import env +from app.model.BasicModel import BasicModel +from app.utils.model_utils import to_obj +from app.utils.mysql_utils import AsyncSessionDep +from app.utils.next_id import next_id +from app.utils.path_join import path_join + + +class FileModel(BasicModel, table=True): + __tablename__ = "pl_upload" + + name: str | None = Field(default=None, description='文件名称') + path: str | None = Field(default=None, description='文件路径') + head_id: str | None = Field(default=None, description='父对象id') + attr1: str | None = Field(default=None, description='扩展属性1') + attr2: str | None = Field(default=None, description='扩展属性2') + attr3: str | None = Field(default=None, description='扩展属性3') + content: str | None = Field(default=None, description='扩展属性内容文本') + + +# 文件保存服务 +class FileSaveService: + # 将文件保存到服务本地目录 + # 并且往附件表中插入对应的文件记录 + @staticmethod + async def saveFile( + session: AsyncSessionDep, + file: UploadFile, + filename: str, + id: str = None, + file_record: dict = None, + ): + file_cls = await async_save_file( + filename=filename, + session=session, + aget_file_blob=file.read, + id=id, + file_record=file_record, + ) + return {"result": file_cls.model_dump()} + + +async def async_save_file( + # 保存的文件名 + filename: str, + # 数据库会话管理对象 + session: AsyncSessionDep, + # 异步获取文件buffer的方法 + aget_file_blob: Callable[[], Awaitable[ReadableBuffer]], + # 文件id,没有则自动生成 + id: str | None = None, + # 其他额外的要保存附件记录中的字段信息 + file_record: dict | None = None, +): + """ + 保存文件工具函数 + """ + + if not id: + id = await next_id() + + datetime_string = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + file_id = f"{datetime_string}_{id}" + + # 文件的保存路径,用来将文件写入到磁盘 + # save_path=/www/wwwroot/web/web/upload_file/时间+随机ID + save_path = path_join(env.server_file_save_path, file_id) + print("save_path", save_path) + + # 文件的访问路径,用来生成文件的访问路径,然后存到附件表中的path字段中 + # public_path=/web/upload_file/时间+随机ID + public_path = path_join(env.server_file_public_path, file_id) + print("public_path", public_path) + + # parents=True 表示创建所有不存在的父目录 + # exist_ok=True 表示如果目录已存在不抛出异常 + Path(save_path).mkdir(parents=True, exist_ok=True) + + file_save_path = path_join(save_path, filename) # 文件的保存路径 + file_public_path = path_join(public_path, filename) # 文件的访问路径 + + # with open(file_save_path, 'wb') as f: + # f.write(await file.read()) + file_blob = await aget_file_blob() + async with aiofiles.open(file_save_path, 'wb') as f: + await f.write(file_blob) + await f.flush() # 确保数据写入磁盘 + + file_dict = { + "id": id, + "name": filename, + "path": file_public_path, + **file_record, + } + new_file_obj = to_obj(FileModel, file_dict) + session.add(new_file_obj) + await session.commit() + return new_file_obj + + +def get_file_record_save_path(file_record_path: str): + """ + 获取文件的保存路径 + """ + return env.file_save_path + file_record_path[len(env.file_public_path):] diff --git a/app/routes.py b/app/routes.py index 4fb207d..76e4105 100644 --- a/app/routes.py +++ b/app/routes.py @@ -1,9 +1,11 @@ from app.controller.add_complex_search_route import add_complex_search_route from app.controller.add_custom_stream_route import add_custom_stream_route from app.controller.add_docs_route import add_docs_route +from app.controller.add_file_route import add_file_route from app.controller.add_graph_proxy_route import add_graph_proxy_route from app.controller.add_test_route import add_test_route from app.model.LlmDemoMdel import add_llm_demo_route +from app.utils.next_id import add_next_id_route # /*@formatter:off*/ routes = [ @@ -13,5 +15,7 @@ routes = [ add_custom_stream_route, # 自定义流式接口 add_llm_demo_route, # LlmDemo 测试用户模块 add_complex_search_route, # 多条件组合查询案例 + add_next_id_route, # 生成ID接口 + add_file_route, # 文件上传接口 ] # /*@formatter:on*/