feat: 自定义流式调用接口
This commit is contained in:
@@ -0,0 +1,47 @@
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from typing import Annotated, List
|
||||
from operator import add
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from typing_extensions import TypedDict
|
||||
from fastapi import FastAPI
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.func import task, entrypoint
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.types import interrupt
|
||||
from langserve import add_routes
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.controller.add_graph_proxy_route import create_graph
|
||||
from app.utils.PER_REQ_CONFIG_MODIFIER import PER_REQ_CONFIG_MODIFIER, next_thread_id
|
||||
|
||||
|
||||
def add_custom_stream_route(app: FastAPI):
|
||||
@app.post('/custom_stream')
|
||||
async def custom_stream():
|
||||
# 生成唯一的线程 ID,用于追踪和持久化工作流执行状态
|
||||
|
||||
config = {"configurable": {"thread_id": next_thread_id()}}
|
||||
|
||||
# 创建工作流图实例
|
||||
graph = create_graph()
|
||||
|
||||
async def generator_function():
|
||||
"""
|
||||
异步生成器函数,实现流式响应
|
||||
使用 astream 方法流式执行工作流,实时返回每个节点的执行结果
|
||||
"""
|
||||
async for chunk in graph.astream(
|
||||
input={}, # 空输入,工作流从 START 节点自动开始
|
||||
config=config, # 配置线程 ID,支持状态持久化和恢复
|
||||
stream_mode=['messages', 'updates'] # 同时流式消息和状态更新
|
||||
):
|
||||
# 将每个执行块格式化为 SSE (Server-Sent Events) 格式
|
||||
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 返回流式响应,客户端可以实时接收工作流执行进度
|
||||
return StreamingResponse(generator_function(), media_type="text/event-stream")
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from typing import TypedDict, Annotated, List
|
||||
from operator import add
|
||||
from fastapi import FastAPI
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.func import task, entrypoint
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.types import interrupt
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
|
||||
def add_test_route(app: FastAPI):
|
||||
"""
|
||||
将自定义工作流注册为 FastAPI 的流式接口
|
||||
"""
|
||||
|
||||
@app.post('/run_workflow')
|
||||
async def run_workflow():
|
||||
# 生成唯一的线程 ID,用于追踪和持久化工作流执行状态
|
||||
thread_id = str(uuid.uuid4())
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
|
||||
# 创建工作流图实例
|
||||
graph = create_graph()
|
||||
|
||||
async def generator_function():
|
||||
"""
|
||||
异步生成器函数,实现流式响应
|
||||
使用 astream 方法流式执行工作流,实时返回每个节点的执行结果
|
||||
"""
|
||||
async for chunk in graph.astream(
|
||||
input={}, # 空输入,工作流从 START 节点自动开始
|
||||
config=config, # 配置线程 ID,支持状态持久化和恢复
|
||||
stream_mode=['messages', 'updates'] # 同时流式消息和状态更新
|
||||
):
|
||||
# 将每个执行块格式化为 SSE (Server-Sent Events) 格式
|
||||
yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 返回流式响应,客户端可以实时接收工作流执行进度
|
||||
return StreamingResponse(generator_function(), media_type="text/event-stream")
|
||||
|
||||
|
||||
def create_graph(checkpointer=InMemorySaver()):
|
||||
"""
|
||||
创建 LangGraph 工作流图
|
||||
|
||||
Args:
|
||||
checkpointer: 状态检查点保存器,默认为内存保存器
|
||||
支持分布式部署时可替换为 Redis、PostgreSQL 等持久化存储
|
||||
|
||||
Returns:
|
||||
编译后的工作流图实例
|
||||
"""
|
||||
|
||||
# 定义工作流状态 schema
|
||||
class StateSchema(TypedDict):
|
||||
# name_list: 使用 Annotated 类型注解,指定 add 为归约函数
|
||||
# 每次节点返回的 name_list 会自动累加到全局状态中
|
||||
name_list: Annotated[List[str], add]
|
||||
|
||||
builder = StateGraph(StateSchema)
|
||||
|
||||
def node_1(state):
|
||||
"""
|
||||
工作流第一个节点:生成 0-100 的随机数
|
||||
返回格式必须匹配 StateSchema,name_list 会被自动累加到状态中
|
||||
"""
|
||||
random_int = random.randint(0, 100)
|
||||
print(["🧠节点执行", "node_1", random_int])
|
||||
return {"name_list": [f"node_1:{random_int}"]}
|
||||
|
||||
def node_2(state):
|
||||
"""
|
||||
工作流第二个节点:生成 100-200 的随机数
|
||||
state 参数包含当前累积的状态(可通过 state['name_list'] 访问)
|
||||
"""
|
||||
random_int = random.randint(100, 200)
|
||||
print(["🧠节点执行", "node_2", random_int])
|
||||
return {"name_list": [f"node_2:{random_int}"]}
|
||||
|
||||
def node_3(state):
|
||||
"""
|
||||
工作流第三个节点:生成 300-400 的随机数
|
||||
节点执行完成后,所有 name_list 会通过 add 归约函数合并
|
||||
"""
|
||||
random_int = random.randint(300, 400)
|
||||
print(["🧠节点执行", "node_3", random_int])
|
||||
return {"name_list": [f"node_3:{random_int}"]}
|
||||
|
||||
# 注册三个节点到工作流
|
||||
builder.add_node(node_1)
|
||||
builder.add_node(node_2)
|
||||
builder.add_node(node_3)
|
||||
|
||||
# 定义节点执行顺序:START -> node_1 -> node_2 -> node_3 -> END
|
||||
builder.add_edge(START, 'node_1')
|
||||
builder.add_edge('node_1', 'node_2')
|
||||
builder.add_edge('node_2', 'node_3')
|
||||
builder.add_edge('node_3', END)
|
||||
|
||||
# 编译工作流图,注入检查点保存器
|
||||
# checkpointer 支持:
|
||||
# 1. 断点续传:工作流中断后可从最近检查点恢复
|
||||
# 2. 人机交互:在 interrupt 处暂停等待用户输入
|
||||
# 3. 状态持久化:跨请求保持工作流状态
|
||||
graph = builder.compile(checkpointer=checkpointer)
|
||||
|
||||
return graph
|
||||
Reference in New Issue
Block a user