feat: 自定义流式调用接口

This commit is contained in:
martsforever
2026-03-31 17:13:17 +08:00
parent e1f57b6f09
commit 9ad6c69a35
3 changed files with 49 additions and 120 deletions
+47
View File
@@ -0,0 +1,47 @@
import json
import random
import time
import uuid
from typing import Annotated, List
from operator import add
from langchain_core.runnables import RunnableConfig
from typing_extensions import TypedDict
from fastapi import FastAPI
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import START, END
from langgraph.func import task, entrypoint
from langgraph.graph import StateGraph
from langgraph.types import interrupt
from langserve import add_routes
from starlette.responses import StreamingResponse
from app.controller.add_graph_proxy_route import create_graph
from app.utils.PER_REQ_CONFIG_MODIFIER import PER_REQ_CONFIG_MODIFIER, next_thread_id
def add_custom_stream_route(app: FastAPI):
@app.post('/custom_stream')
async def custom_stream():
# 生成唯一的线程 ID,用于追踪和持久化工作流执行状态
config = {"configurable": {"thread_id": next_thread_id()}}
# 创建工作流图实例
graph = create_graph()
async def generator_function():
"""
异步生成器函数,实现流式响应
使用 astream 方法流式执行工作流,实时返回每个节点的执行结果
"""
async for chunk in graph.astream(
input={}, # 空输入,工作流从 START 节点自动开始
config=config, # 配置线程 ID,支持状态持久化和恢复
stream_mode=['messages', 'updates'] # 同时流式消息和状态更新
):
# 将每个执行块格式化为 SSE (Server-Sent Events) 格式
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 返回流式响应,客户端可以实时接收工作流执行进度
return StreamingResponse(generator_function(), media_type="text/event-stream")
-112
View File
@@ -1,112 +0,0 @@
import json
import random
import time
import uuid
from typing import TypedDict, Annotated, List
from operator import add
from fastapi import FastAPI
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import START, END
from langgraph.func import task, entrypoint
from langgraph.graph import StateGraph
from langgraph.types import interrupt
from starlette.responses import StreamingResponse
def add_test_route(app: FastAPI):
"""
将自定义工作流注册为 FastAPI 的流式接口
"""
@app.post('/run_workflow')
async def run_workflow():
# 生成唯一的线程 ID,用于追踪和持久化工作流执行状态
thread_id = str(uuid.uuid4())
config = {"configurable": {"thread_id": thread_id}}
# 创建工作流图实例
graph = create_graph()
async def generator_function():
"""
异步生成器函数,实现流式响应
使用 astream 方法流式执行工作流,实时返回每个节点的执行结果
"""
async for chunk in graph.astream(
input={}, # 空输入,工作流从 START 节点自动开始
config=config, # 配置线程 ID,支持状态持久化和恢复
stream_mode=['messages', 'updates'] # 同时流式消息和状态更新
):
# 将每个执行块格式化为 SSE (Server-Sent Events) 格式
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
# 返回流式响应,客户端可以实时接收工作流执行进度
return StreamingResponse(generator_function(), media_type="text/event-stream")
def create_graph(checkpointer=InMemorySaver()):
"""
创建 LangGraph 工作流图
Args:
checkpointer: 状态检查点保存器,默认为内存保存器
支持分布式部署时可替换为 Redis、PostgreSQL 等持久化存储
Returns:
编译后的工作流图实例
"""
# 定义工作流状态 schema
class StateSchema(TypedDict):
# name_list: 使用 Annotated 类型注解,指定 add 为归约函数
# 每次节点返回的 name_list 会自动累加到全局状态中
name_list: Annotated[List[str], add]
builder = StateGraph(StateSchema)
def node_1(state):
"""
工作流第一个节点:生成 0-100 的随机数
返回格式必须匹配 StateSchemaname_list 会被自动累加到状态中
"""
random_int = random.randint(0, 100)
print(["🧠节点执行", "node_1", random_int])
return {"name_list": [f"node_1:{random_int}"]}
def node_2(state):
"""
工作流第二个节点:生成 100-200 的随机数
state 参数包含当前累积的状态(可通过 state['name_list'] 访问)
"""
random_int = random.randint(100, 200)
print(["🧠节点执行", "node_2", random_int])
return {"name_list": [f"node_2:{random_int}"]}
def node_3(state):
"""
工作流第三个节点:生成 300-400 的随机数
节点执行完成后,所有 name_list 会通过 add 归约函数合并
"""
random_int = random.randint(300, 400)
print(["🧠节点执行", "node_3", random_int])
return {"name_list": [f"node_3:{random_int}"]}
# 注册三个节点到工作流
builder.add_node(node_1)
builder.add_node(node_2)
builder.add_node(node_3)
# 定义节点执行顺序:START -> node_1 -> node_2 -> node_3 -> END
builder.add_edge(START, 'node_1')
builder.add_edge('node_1', 'node_2')
builder.add_edge('node_2', 'node_3')
builder.add_edge('node_3', END)
# 编译工作流图,注入检查点保存器
# checkpointer 支持:
# 1. 断点续传:工作流中断后可从最近检查点恢复
# 2. 人机交互:在 interrupt 处暂停等待用户输入
# 3. 状态持久化:跨请求保持工作流状态
graph = builder.compile(checkpointer=checkpointer)
return graph
+2 -8
View File
@@ -1,23 +1,16 @@
import datetime
import logging
from contextlib import asynccontextmanager
from typing import List, Union
from fastapi import FastAPI
from fastapi.openapi.docs import get_swagger_ui_oauth2_redirect_html, get_redoc_html, get_swagger_ui_html
from fastapi.responses import RedirectResponse
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableGenerator
from langchain_openai import ChatOpenAI
from langserve import add_routes
from pydantic import BaseModel, Field
from starlette.staticfiles import StaticFiles
from app.config.env import env
from app.controller.add_custom_stream_route import add_custom_stream_route
from app.controller.add_graph_proxy_route import add_graph_proxy_route
from app.controller.add_test_route import add_test_route
from app.utils.get_local_ips import get_local_ips
from app.utils.llm_utils import create_llm
@@ -123,6 +116,7 @@ model = init_chat_model(
add_routes(app=app, runnable=model, path="/doubao")
add_graph_proxy_route(app)
add_custom_stream_route(app)
if __name__ == "__main__":
import uvicorn