diff --git a/app/controller/add_graph_proxy_route.py b/app/controller/add_graph_proxy_route.py new file mode 100644 index 0000000..9d1d20c --- /dev/null +++ b/app/controller/add_graph_proxy_route.py @@ -0,0 +1,100 @@ +import json +import random +import time +import uuid +from typing import Annotated, List +from operator import add + +from langchain_core.runnables import RunnableConfig +from typing_extensions import TypedDict +from fastapi import FastAPI +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.constants import START, END +from langgraph.func import task, entrypoint +from langgraph.graph import StateGraph +from langgraph.types import interrupt +from langserve import add_routes +from starlette.responses import StreamingResponse + +from app.utils.PER_REQ_CONFIG_MODIFIER import PER_REQ_CONFIG_MODIFIER + + +def add_graph_proxy_route(app: FastAPI): + """ + 将自定义工作流注册为 FastAPI 的流式接口 + """ + + add_routes( + app=app, + runnable=create_graph(), + path="/graph_proxy", + per_req_config_modifier=PER_REQ_CONFIG_MODIFIER, + ) + + +def create_graph(checkpointer=InMemorySaver()): + """ + 创建 LangGraph 工作流图 + + Args: + checkpointer: 状态检查点保存器,默认为内存保存器 + 支持分布式部署时可替换为 Redis、PostgreSQL 等持久化存储 + + Returns: + 编译后的工作流图实例 + """ + + # 定义工作流状态 schema + class StateSchema(TypedDict): + # name_list: 使用 Annotated 类型注解,指定 add 为归约函数 + # 每次节点返回的 name_list 会自动累加到全局状态中 + name_list: Annotated[List[str], add] + + builder = StateGraph(StateSchema) + + def node_1(state, config: RunnableConfig): + """ + 工作流第一个节点:生成 0-100 的随机数 + 返回格式必须匹配 StateSchema,name_list 会被自动累加到状态中 + """ + random_int = str(random.randint(0, 100)) + "-->>id:" + config.get('configurable').get('thread_id') + print(["🧠节点执行", "node_1", random_int]) + return {"name_list": [f"node_1:{random_int}"]} + + def node_2(state): + """ + 工作流第二个节点:生成 100-200 的随机数 + state 参数包含当前累积的状态(可通过 state['name_list'] 访问) + """ + random_int = random.randint(100, 200) + print(["🧠节点执行", "node_2", random_int]) + return {"name_list": [f"node_2:{random_int}"]} + + def node_3(state): + """ + 工作流第三个节点:生成 300-400 的随机数 + 节点执行完成后,所有 name_list 会通过 add 归约函数合并 + """ + random_int = random.randint(300, 400) + print(["🧠节点执行", "node_3", random_int]) + return {"name_list": [f"node_3:{random_int}"]} + + # 注册三个节点到工作流 + builder.add_node(node_1) + builder.add_node(node_2) + builder.add_node(node_3) + + # 定义节点执行顺序:START -> node_1 -> node_2 -> node_3 -> END + builder.add_edge(START, 'node_1') + builder.add_edge('node_1', 'node_2') + builder.add_edge('node_2', 'node_3') + builder.add_edge('node_3', END) + + # 编译工作流图,注入检查点保存器 + # checkpointer 支持: + # 1. 断点续传:工作流中断后可从最近检查点恢复 + # 2. 人机交互:在 interrupt 处暂停等待用户输入 + # 3. 状态持久化:跨请求保持工作流状态 + graph = builder.compile(checkpointer=checkpointer) + + return graph diff --git a/app/utils/PER_REQ_CONFIG_MODIFIER.py b/app/utils/PER_REQ_CONFIG_MODIFIER.py new file mode 100644 index 0000000..6f944db --- /dev/null +++ b/app/utils/PER_REQ_CONFIG_MODIFIER.py @@ -0,0 +1,8 @@ +async def PER_REQ_CONFIG_MODIFIER(config, request): + body = await request.json() + return { + "configurable": { + "thread_id": + body.get("config", {}).get("configurable", {}).get("thread_id", "default") + } + } diff --git a/main.py b/main.py index 2f1c2de..d601590 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from pydantic import BaseModel, Field from starlette.staticfiles import StaticFiles from app.config.env import env +from app.controller.add_graph_proxy_route import add_graph_proxy_route from app.controller.add_test_route import add_test_route from app.utils.get_local_ips import get_local_ips from app.utils.llm_utils import create_llm @@ -121,7 +122,7 @@ model = init_chat_model( add_routes(app=app, runnable=model, path="/doubao") -add_test_route(app) +add_graph_proxy_route(app) if __name__ == "__main__": import uvicorn