Map-reduce operations are essential for efficient task decomposition and parallel processing. The approach breaks a task into smaller sub-tasks, processes each sub-task in parallel, and aggregates the results of all completed sub-tasks.
Consider this example: given a general topic from the user, generate a list of related subjects, generate a joke about each subject, and pick the best joke from the resulting list. In this design pattern, a first node may generate a list of objects (e.g., related subjects), and we want to apply another node (e.g., joke generation) to each of those objects (e.g., subjects). Two main challenges arise, however.
(1) The number of objects (e.g., subjects) may be unknown when we lay out the graph (meaning the number of edges may be unknown); (2) the state passed into the downstream node should be different for each generated object.
LangGraph addresses these challenges with its Send API. Using a conditional edge, Send can dispatch different states (e.g., subjects) to multiple instances of a node (e.g., joke generation). Importantly, the state being sent can differ from the core graph's state, allowing flexible and dynamic workflow management.
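Before wiring this up in LangGraph, the map-reduce shape itself can be sketched in plain Python. The functions below are hypothetical stand-ins for the LLM calls, just to make the decompose/map/reduce data flow concrete:

```python
# Plain-Python sketch of the map-reduce shape (no LLM involved).
def generate_topics(topic: str) -> list[str]:
    # stands in for an LLM call that lists related subjects
    return ["dog", "elephant"]

def generate_joke(subject: str) -> str:
    # stands in for an LLM call that writes one joke
    return f"a joke about {subject}"

def best_joke(jokes: list[str]) -> str:
    # stands in for an LLM call that picks the best joke
    return max(jokes, key=len)

subjects = generate_topics("animals")          # decompose
jokes = [generate_joke(s) for s in subjects]   # map (fan-out)
print(best_joke(jokes))                        # prints "a joke about elephant"
```

The LangGraph version below follows exactly this shape, but runs the map step as parallel node instances.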
1. Import the necessary libraries
First, we import a few required libraries, including operator, typing, and pydantic, along with the relevant LangGraph and LangChain modules.
import operator
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.types import Send
from langgraph.graph import END, StateGraph, START
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
2. Define the models and prompts
Next, we define prompt templates for generating subjects, generating jokes, and selecting the best joke, together with the corresponding data models.
# Prompt templates
subjects_prompt = """Generate a comma separated list of between 2 and 5 examples related to: {topic}."""
joke_prompt = """Generate a joke about {subject}"""
best_joke_prompt = """Below are a bunch of jokes about {topic}. Select the best one! Return the ID of the best one.
{jokes}"""
class Subjects(BaseModel):
    """List of subjects related to the topic."""
    subjects: list[str] = Field(description="list of subjects related to the topic")

class Joke(BaseModel):
    joke: str

class BestJoke(BaseModel):
    id: int = Field(description="Index of the best joke, starting with 0", ge=0)
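These Pydantic classes are the schemas that with_structured_output will ask the model to match. Locally, the same schema parses a response-shaped dict like this (a sketch using Pydantic directly, independent of any LLM call):

```python
from pydantic import BaseModel, Field

class Subjects(BaseModel):
    """List of subjects related to the topic."""
    subjects: list[str] = Field(description="list of subjects related to the topic")

# A raw dict shaped like a model response parses into a typed object;
# anything violating the schema would raise a ValidationError instead.
parsed = Subjects.model_validate({"subjects": ["dog", "cat"]})
print(parsed.subjects)  # prints ['dog', 'cat']
```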
3. Initialize the model
We initialize our model with LangChain's ChatOpenAI class.
model = ChatOpenAI(
    temperature=0,
    model="GLM-4-Plus",
    openai_api_key="your api key",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
)
4. Define the state graph components
We need to define the overall state of the graph and the per-node state.
class OverallState(TypedDict):
    topic: str
    subjects: list
    jokes: Annotated[list, operator.add]
    best_selected_joke: str

class JokeState(TypedDict):
    subject: str
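The Annotated[list, operator.add] reducer is what lets the parallel joke branches all write to the same jokes key without clobbering each other: LangGraph combines each node's partial update with the reducer. A minimal sketch of that merge semantics:

```python
import operator

# Each parallel node returns a partial update like {"jokes": [one_joke]}.
partial_updates = [["joke about dogs"], ["joke about cats"], ["joke about lions"]]

jokes: list[str] = []
for update in partial_updates:
    jokes = operator.add(jokes, update)  # equivalent to jokes + update

print(len(jokes))  # prints 3
```

Without the reducer annotation, each branch's write would replace the key instead of appending to it.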
5. Define the node functions
Next, we define the three main node functions: generating subjects, generating jokes, and selecting the best joke.
def generate_topics(state: OverallState):
    prompt = subjects_prompt.format(topic=state["topic"])
    response = model.with_structured_output(Subjects).invoke(prompt)
    return {"subjects": response.subjects}

def generate_joke(state: JokeState):
    prompt = joke_prompt.format(subject=state["subject"])
    response = model.with_structured_output(Joke).invoke(prompt)
    return {"jokes": [response.joke]}

def best_joke(state: OverallState):
    jokes = "\n\n".join(state["jokes"])
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=jokes)
    response = model.with_structured_output(BestJoke).invoke(prompt)
    return {"best_selected_joke": state["jokes"][response.id]}
6. Define the mapping logic
We need a function that maps each generated subject to the joke-generation node.
def continue_to_jokes(state: OverallState):
    return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
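Each Send pairs a target node name with its own per-item state, which is how every joke branch receives a JokeState rather than the full OverallState. A plain-tuple stand-in (hypothetical; the real routing is handled by langgraph.types.Send) makes the fan-out shape visible:

```python
overall_state = {"topic": "animals", "subjects": ["dog", "cat"]}

# One (node_name, per_item_state) packet per subject, mirroring
# [Send("generate_joke", {"subject": s}) for s in state["subjects"]].
packets = [("generate_joke", {"subject": s}) for s in overall_state["subjects"]]

for node, state in packets:
    print(node, state)
```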
7. Build the state graph
Putting all the components together, we build our state graph.
graph = StateGraph(OverallState)
graph.add_node("generate_topics", generate_topics)
graph.add_node("generate_joke", generate_joke)
graph.add_node("best_joke", best_joke)
graph.add_edge(START, "generate_topics")
graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])
graph.add_edge("generate_joke", "best_joke")
graph.add_edge("best_joke", END)
app = graph.compile()
8. Visualize the state graph
Use IPython's Image class to visualize the state graph.
from IPython.display import Image
Image(app.get_graph().draw_mermaid_png())
9. Invoke the graph to generate jokes
Finally, we invoke the state graph to generate jokes and select the best one.
for s in app.stream({"topic": "animals"}):
    print(s)
The output looks like this:
{'generate_topics': {'subjects': ['dog', 'cat', 'elephant', 'giraffe', 'lion']}}
{'generate_joke': {'jokes': ["Why do elephants never use computers? Because they're afraid of the mouse!"]}}
{'generate_joke': {'jokes': ["Why did the giraffe become a vegetarian? Because it didn't want to stick its neck out for meat!"]}}
{'generate_joke': {'jokes': ["Why don't cats play poker in the jungle? Too many cheetahs!"]}}
{'generate_joke': {'jokes': ["Why don't lions play cards in the jungle? Too many cheetahs!"]}}
{'generate_joke': {'jokes': ["Why did the dog sit in the shade? Because he didn't want to be a hot dog!"]}}
{'best_joke': {'best_selected_joke': "Why did the dog sit in the shade? Because he didn't want to be a hot dog!"}}
With the steps above, we have built a program that generates jokes and selects the best one. I hope this tutorial helps!
Notes
The code on the official site does not annotate the output format, which caused my run to fail:
class Subjects(BaseModel):
    subjects: list[str]

def generate_joke(state: JokeState):
    prompt = joke_prompt.format(subject=state["subject"])
    response = model.with_structured_output(Joke).invoke(prompt)
    return {"jokes": [response.joke]}
Adding a more detailed annotation is enough to make it run successfully:
class Subjects(BaseModel):
    subjects: list[str] = Field(description="list of subjects related to the topic, the subject of list is the string format")
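The Field description matters because it is embedded in the JSON schema that the structured-output call passes to the model. A quick check with Pydantic's model_json_schema shows the difference between the bare and annotated versions:

```python
from pydantic import BaseModel, Field

class SubjectsBare(BaseModel):
    subjects: list[str]

class SubjectsAnnotated(BaseModel):
    subjects: list[str] = Field(
        description="list of subjects related to the topic, the subject of list is the string format"
    )

bare = SubjectsBare.model_json_schema()["properties"]["subjects"]
annotated = SubjectsAnnotated.model_json_schema()["properties"]["subjects"]

print("description" in bare)       # prints False
print("description" in annotated)  # prints True
```

The annotated schema gives the model an explicit hint about the expected field content, which is what made the call reliable here.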
Reference: https://langchain-ai.github.io/langgraph/how-tos/map-reduce/#define-the-graph
If you have any questions, feel free to ask in the comments.