feat: 将自定义工作流部署为模型服务
This commit is contained in:
@@ -0,0 +1,100 @@
|
|||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated, List
|
||||||
|
from operator import add
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.constants import START, END
|
||||||
|
from langgraph.func import task, entrypoint
|
||||||
|
from langgraph.graph import StateGraph
|
||||||
|
from langgraph.types import interrupt
|
||||||
|
from langserve import add_routes
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
|
from app.utils.PER_REQ_CONFIG_MODIFIER import PER_REQ_CONFIG_MODIFIER
|
||||||
|
|
||||||
|
|
||||||
|
def add_graph_proxy_route(app: FastAPI):
|
||||||
|
"""
|
||||||
|
将自定义工作流注册为 FastAPI 的流式接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
add_routes(
|
||||||
|
app=app,
|
||||||
|
runnable=create_graph(),
|
||||||
|
path="/graph_proxy",
|
||||||
|
per_req_config_modifier=PER_REQ_CONFIG_MODIFIER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_graph(checkpointer=InMemorySaver()):
|
||||||
|
"""
|
||||||
|
创建 LangGraph 工作流图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpointer: 状态检查点保存器,默认为内存保存器
|
||||||
|
支持分布式部署时可替换为 Redis、PostgreSQL 等持久化存储
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
编译后的工作流图实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 定义工作流状态 schema
|
||||||
|
class StateSchema(TypedDict):
|
||||||
|
# name_list: 使用 Annotated 类型注解,指定 add 为归约函数
|
||||||
|
# 每次节点返回的 name_list 会自动累加到全局状态中
|
||||||
|
name_list: Annotated[List[str], add]
|
||||||
|
|
||||||
|
builder = StateGraph(StateSchema)
|
||||||
|
|
||||||
|
def node_1(state, config: RunnableConfig):
|
||||||
|
"""
|
||||||
|
工作流第一个节点:生成 0-100 的随机数
|
||||||
|
返回格式必须匹配 StateSchema,name_list 会被自动累加到状态中
|
||||||
|
"""
|
||||||
|
random_int = str(random.randint(0, 100)) + "-->>id:" + config.get('configurable').get('thread_id')
|
||||||
|
print(["🧠节点执行", "node_1", random_int])
|
||||||
|
return {"name_list": [f"node_1:{random_int}"]}
|
||||||
|
|
||||||
|
def node_2(state):
|
||||||
|
"""
|
||||||
|
工作流第二个节点:生成 100-200 的随机数
|
||||||
|
state 参数包含当前累积的状态(可通过 state['name_list'] 访问)
|
||||||
|
"""
|
||||||
|
random_int = random.randint(100, 200)
|
||||||
|
print(["🧠节点执行", "node_2", random_int])
|
||||||
|
return {"name_list": [f"node_2:{random_int}"]}
|
||||||
|
|
||||||
|
def node_3(state):
|
||||||
|
"""
|
||||||
|
工作流第三个节点:生成 300-400 的随机数
|
||||||
|
节点执行完成后,所有 name_list 会通过 add 归约函数合并
|
||||||
|
"""
|
||||||
|
random_int = random.randint(300, 400)
|
||||||
|
print(["🧠节点执行", "node_3", random_int])
|
||||||
|
return {"name_list": [f"node_3:{random_int}"]}
|
||||||
|
|
||||||
|
# 注册三个节点到工作流
|
||||||
|
builder.add_node(node_1)
|
||||||
|
builder.add_node(node_2)
|
||||||
|
builder.add_node(node_3)
|
||||||
|
|
||||||
|
# 定义节点执行顺序:START -> node_1 -> node_2 -> node_3 -> END
|
||||||
|
builder.add_edge(START, 'node_1')
|
||||||
|
builder.add_edge('node_1', 'node_2')
|
||||||
|
builder.add_edge('node_2', 'node_3')
|
||||||
|
builder.add_edge('node_3', END)
|
||||||
|
|
||||||
|
# 编译工作流图,注入检查点保存器
|
||||||
|
# checkpointer 支持:
|
||||||
|
# 1. 断点续传:工作流中断后可从最近检查点恢复
|
||||||
|
# 2. 人机交互:在 interrupt 处暂停等待用户输入
|
||||||
|
# 3. 状态持久化:跨请求保持工作流状态
|
||||||
|
graph = builder.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
return graph
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
async def PER_REQ_CONFIG_MODIFIER(config, request):
|
||||||
|
body = await request.json()
|
||||||
|
return {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id":
|
||||||
|
body.get("config", {}).get("configurable", {}).get("thread_id", "default")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
|
|||||||
from starlette.staticfiles import StaticFiles
|
from starlette.staticfiles import StaticFiles
|
||||||
|
|
||||||
from app.config.env import env
|
from app.config.env import env
|
||||||
|
from app.controller.add_graph_proxy_route import add_graph_proxy_route
|
||||||
from app.controller.add_test_route import add_test_route
|
from app.controller.add_test_route import add_test_route
|
||||||
from app.utils.get_local_ips import get_local_ips
|
from app.utils.get_local_ips import get_local_ips
|
||||||
from app.utils.llm_utils import create_llm
|
from app.utils.llm_utils import create_llm
|
||||||
@@ -121,7 +122,7 @@ model = init_chat_model(
|
|||||||
|
|
||||||
add_routes(app=app, runnable=model, path="/doubao")
|
add_routes(app=app, runnable=model, path="/doubao")
|
||||||
|
|
||||||
add_test_route(app)
|
add_graph_proxy_route(app)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
Reference in New Issue
Block a user