From 9ad6c69a358c21d8bcc34edf0dc802a7f7bcf591 Mon Sep 17 00:00:00 2001 From: martsforever Date: Tue, 31 Mar 2026 17:13:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=87=AA=E5=AE=9A=E4=B9=89=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E8=B0=83=E7=94=A8=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/controller/add_custom_stream_route.py | 47 +++++++++ app/controller/add_test_route.py | 112 ---------------------- main.py | 10 +- 3 files changed, 49 insertions(+), 120 deletions(-) create mode 100644 app/controller/add_custom_stream_route.py delete mode 100644 app/controller/add_test_route.py diff --git a/app/controller/add_custom_stream_route.py b/app/controller/add_custom_stream_route.py new file mode 100644 index 0000000..825aeee --- /dev/null +++ b/app/controller/add_custom_stream_route.py @@ -0,0 +1,47 @@ +import json +import random +import time +import uuid +from typing import Annotated, List +from operator import add + +from langchain_core.runnables import RunnableConfig +from typing_extensions import TypedDict +from fastapi import FastAPI +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.constants import START, END +from langgraph.func import task, entrypoint +from langgraph.graph import StateGraph +from langgraph.types import interrupt +from langserve import add_routes +from starlette.responses import StreamingResponse + +from app.controller.add_graph_proxy_route import create_graph +from app.utils.PER_REQ_CONFIG_MODIFIER import PER_REQ_CONFIG_MODIFIER, next_thread_id + + +def add_custom_stream_route(app: FastAPI): + @app.post('/custom_stream') + async def custom_stream(): + # 生成唯一的线程 ID,用于追踪和持久化工作流执行状态 + + config = {"configurable": {"thread_id": next_thread_id()}} + + # 创建工作流图实例 + graph = create_graph() + + async def generator_function(): + """ + 异步生成器函数,实现流式响应 + 使用 astream 方法流式执行工作流,实时返回每个节点的执行结果 + """ + async for chunk in graph.astream( + input={}, # 空输入,工作流从 START 节点自动开始 + config=config, # 配置线程 ID,支持状态持久化和恢复 + stream_mode=['messages', 'updates'] # 同时流式消息和状态更新 + ): + # 将每个执行块格式化为 SSE (Server-Sent Events) 格式 + yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" + + # 返回流式响应,客户端可以实时接收工作流执行进度 + return StreamingResponse(generator_function(), media_type="text/event-stream") diff --git a/app/controller/add_test_route.py b/app/controller/add_test_route.py deleted file mode 100644 index 2a1a4bb..0000000 --- a/app/controller/add_test_route.py +++ /dev/null @@ -1,112 +0,0 @@ -import json -import random -import time -import uuid -from typing import TypedDict, Annotated, List -from operator import add -from fastapi import FastAPI -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.constants import START, END -from langgraph.func import task, entrypoint -from langgraph.graph import StateGraph -from langgraph.types import interrupt -from starlette.responses import StreamingResponse - - -def add_test_route(app: FastAPI): - """ - 将自定义工作流注册为 FastAPI 的流式接口 - """ - - @app.post('/run_workflow') - async def run_workflow(): - # 生成唯一的线程 ID,用于追踪和持久化工作流执行状态 - thread_id = str(uuid.uuid4()) - config = {"configurable": {"thread_id": thread_id}} - - # 创建工作流图实例 - graph = create_graph() - - async def generator_function(): - """ - 异步生成器函数,实现流式响应 - 使用 astream 方法流式执行工作流,实时返回每个节点的执行结果 - """ - async for chunk in graph.astream( - input={}, # 空输入,工作流从 START 节点自动开始 - config=config, # 配置线程 ID,支持状态持久化和恢复 - stream_mode=['messages', 'updates'] # 同时流式消息和状态更新 - ): - # 将每个执行块格式化为 SSE (Server-Sent Events) 格式 - yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" - - # 返回流式响应,客户端可以实时接收工作流执行进度 - return StreamingResponse(generator_function(), media_type="text/event-stream") - - -def create_graph(checkpointer=InMemorySaver()): - """ - 创建 LangGraph 工作流图 - - Args: - checkpointer: 状态检查点保存器,默认为内存保存器 - 支持分布式部署时可替换为 Redis、PostgreSQL 等持久化存储 - - Returns: - 编译后的工作流图实例 - """ - - # 定义工作流状态 schema - class StateSchema(TypedDict): - # name_list: 使用 Annotated 类型注解,指定 add 为归约函数 - # 每次节点返回的 name_list 会自动累加到全局状态中 - name_list: Annotated[List[str], add] - - builder = StateGraph(StateSchema) - - def node_1(state): - """ - 工作流第一个节点:生成 0-100 的随机数 - 返回格式必须匹配 StateSchema,name_list 会被自动累加到状态中 - """ - random_int = random.randint(0, 100) - print(["🧠节点执行", "node_1", random_int]) - return {"name_list": [f"node_1:{random_int}"]} - - def node_2(state): - """ - 工作流第二个节点:生成 100-200 的随机数 - state 参数包含当前累积的状态(可通过 state['name_list'] 访问) - """ - random_int = random.randint(100, 200) - print(["🧠节点执行", "node_2", random_int]) - return {"name_list": [f"node_2:{random_int}"]} - - def node_3(state): - """ - 工作流第三个节点:生成 300-400 的随机数 - 节点执行完成后,所有 name_list 会通过 add 归约函数合并 - """ - random_int = random.randint(300, 400) - print(["🧠节点执行", "node_3", random_int]) - return {"name_list": [f"node_3:{random_int}"]} - - # 注册三个节点到工作流 - builder.add_node(node_1) - builder.add_node(node_2) - builder.add_node(node_3) - - # 定义节点执行顺序:START -> node_1 -> node_2 -> node_3 -> END - builder.add_edge(START, 'node_1') - builder.add_edge('node_1', 'node_2') - builder.add_edge('node_2', 'node_3') - builder.add_edge('node_3', END) - - # 编译工作流图,注入检查点保存器 - # checkpointer 支持: - # 1. 断点续传:工作流中断后可从最近检查点恢复 - # 2. 人机交互:在 interrupt 处暂停等待用户输入 - # 3. 状态持久化:跨请求保持工作流状态 - graph = builder.compile(checkpointer=checkpointer) - - return graph diff --git a/main.py b/main.py index d601590..a6498a0 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,16 @@ import datetime -import logging from contextlib import asynccontextmanager -from typing import List, Union from fastapi import FastAPI from fastapi.openapi.docs import get_swagger_ui_oauth2_redirect_html, get_redoc_html, get_swagger_ui_html from fastapi.responses import RedirectResponse from langchain.chat_models import init_chat_model -from langchain_core.messages import HumanMessage, AIMessage, SystemMessage -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import RunnableGenerator -from langchain_openai import ChatOpenAI from langserve import add_routes -from pydantic import BaseModel, Field from starlette.staticfiles import StaticFiles from app.config.env import env +from app.controller.add_custom_stream_route import add_custom_stream_route from app.controller.add_graph_proxy_route import add_graph_proxy_route -from app.controller.add_test_route import add_test_route from app.utils.get_local_ips import get_local_ips from app.utils.llm_utils import create_llm @@ -123,6 +116,7 @@ model = init_chat_model( add_routes(app=app, runnable=model, path="/doubao") add_graph_proxy_route(app) +add_custom_stream_route(app) if __name__ == "__main__": import uvicorn