我将逐行解析 pocketflow/__init__.py
文件的内容。这是一个轻量级的 LLM 工作流框架的核心实现。
import asyncio, warnings, copy, time
导入必要的 Python 标准库:
asyncio
: 用于异步编程warnings
: 用于发出警告信息copy
: 用于对象的复制time
: 用于时间相关操作
class BaseNode:
def __init__(self): self.params,self.successors={},{}
def set_params(self,params): self.params=params
定义基础节点类 BaseNode
:
__init__
: 初始化两个字典 –params
(参数)和successors
(后继节点)set_params
: 设置节点参数
def next(self,node,action="default"):
if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
self.successors[action]=node; return node
添加后继节点的方法:
- 如果该动作已有后继节点,发出警告
- 将新节点添加到后继节点字典中
- 返回新节点以支持链式调用
def prep(self,shared): pass
def exec(self,prep_res): pass
def post(self,shared,prep_res,exec_res): pass
定义三个核心方法:
prep
: 准备数据exec
: 执行核心逻辑post
: 处理结果
这些方法在基类中是空实现,由子类重写
def _exec(self,prep_res): return self.exec(prep_res)
def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e)
def run(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
return self._run(shared)
内部运行机制:
_exec
: 执行包装器_run
: 执行完整的 prep->exec->post 流程run
: 公开的运行方法,提醒用户单个节点不会运行后继节点
def __rshift__(self,other): return self.next(other)
def __sub__(self,action):
if isinstance(action,str): return _ConditionalTransition(self,action)
raise TypeError("Action must be a string")
运算符重载:
>>
: 用于添加默认后继节点-
: 用于创建条件转换
class _ConditionalTransition:
def __init__(self,src,action): self.src,self.action=src,action
def __rshift__(self,tgt): return self.src.next(tgt,self.action)
条件转换辅助类:
- 存储源节点和动作
- 支持使用
>>
添加目标节点
class Node(BaseNode):
def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait
def exec_fallback(self,prep_res,exc): raise exc
def _exec(self,prep_res):
for self.cur_retry in range(self.max_retries):
try: return self.exec(prep_res)
except Exception as e:
if self.cur_retry==self.max_retries-1: return self.exec_fallback(prep_res,e)
if self.wait>0: time.sleep(self.wait)
实际节点类:
- 添加重试机制
- 支持失败回退
- 支持重试等待时间
class BatchNode(Node):
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]
批处理节点:
- 对输入列表中的每个项目执行处理
- 返回结果列表
class Flow(BaseNode):
def __init__(self,start=None): super().__init__(); self.start_node=start
def start(self,start): self.start_node=start; return start
def get_next_node(self,curr,action):
nxt=curr.successors.get(action or "default")
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return nxt
def _orch(self,shared,params=None):
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
return last_action
def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
def post(self,shared,prep_res,exec_res): return exec_res
工作流类:
- 管理节点之间的流转
- 支持参数传递
- 处理节点执行结果
- 实现节点链的遍历执行
class BatchFlow(Flow):
def _run(self,shared):
pr=self.prep(shared) or []
for bp in pr: self._orch(shared,{**self.params,**bp})
return self.post(shared,pr,None)
批处理工作流:
- 对多组参数执行同一工作流
- 合并参数字典
class AsyncNode(Node):
async def prep_async(self,shared): pass
async def exec_async(self,prep_res): pass
async def exec_fallback_async(self,prep_res,exc): raise exc
async def post_async(self,shared,prep_res,exec_res): pass
async def _exec(self,prep_res):
for i in range(self.max_retries):
try: return await self.exec_async(prep_res)
except Exception as e:
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
if self.wait>0: await asyncio.sleep(self.wait)
async def run_async(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
return await self._run_async(shared)
async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e)
def _run(self,shared): raise RuntimeError("Use run_async.")
异步节点:
- 提供异步版本的核心方法
- 支持异步重试和回退
- 强制使用异步接口
class AsyncBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
异步批处理节点:
- 结合异步和批处理功能
- 顺序处理批量项目
class AsyncParallelBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
并行异步批处理节点:
- 并行处理批量项目
- 使用
asyncio.gather
实现并发
class AsyncFlow(Flow,AsyncNode):
async def _orch_async(self,shared,params=None):
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
return last_action
async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
async def post_async(self,shared,prep_res,exec_res): return exec_res
异步工作流:
- 支持混合同步和异步节点
- 提供异步工作流执行
class AsyncBatchFlow(AsyncFlow,BatchFlow):
async def _run_async(self,shared):
pr=await self.prep_async(shared) or []
for bp in pr: await self._orch_async(shared,{**self.params,**bp})
return await self.post_async(shared,pr,None)
异步批处理工作流:
- 顺序执行多组参数的异步工作流
class AsyncParallelBatchFlow(AsyncFlow,BatchFlow):
async def _run_async(self,shared):
pr=await self.prep_async(shared) or []
await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr))
return await self.post_async(shared,pr,None)
并行异步批处理工作流:
- 并行执行多组参数的异步工作流
- 最大化并发性能
这个框架通过组合基本的节点和工作流概念,提供了一个灵活的系统来构建 LLM 应用。它支持同步/异步操作、批处理、并行处理等多种模式,同时保持了代码的简洁性和可扩展性。
让我解释一下为什么 exec()
方法不需要 shared
参数:
- 职责分离原则
prep
: 负责从共享存储中读取和准备数据exec
: 专注于纯计算/处理逻辑post
: 负责将结果写回共享存储并决定下一步动作
- 纯函数设计
exec()
被设计成一个”纯函数”,它只处理通过prep()
准备好的数据- 这样的设计使得
exec()
更容易测试、调试和重试 - 不直接访问共享状态可以避免副作用,提高代码的可预测性
- 数据流设计
graph LR
prep["prep(shared)"] --> |prep_res| exec["exec(prep_res)"]
exec --> |exec_res| post["post(shared, prep_res, exec_res)"]
prep
从shared
读取数据并返回prep_res
exec
接收prep_res
并返回exec_res
post
接收所有上下文并更新shared
- 重试机制考虑
def _exec(self,prep_res):
for self.cur_retry in range(self.max_retries):
try: return self.exec(prep_res)
except Exception as e:
if self.cur_retry==self.max_retries-1: return self.exec_fallback(prep_res,e)
if self.wait>0: time.sleep(self.wait)
- 当执行失败需要重试时,如果
exec()
直接访问shared
,可能会导致状态不一致 - 通过参数传递的方式,每次重试都使用相同的输入,保证了重试的一致性
- 并行处理考虑
class AsyncParallelBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
- 在并行处理场景下,如果
exec()
直接访问共享状态,需要处理复杂的并发控制 - 通过参数传递的方式,每个并行任务都有自己的独立数据副本,避免了并发问题
这种设计模式在很多框架中都很常见,例如:
- MapReduce 中的 map 函数就是纯函数
- React 中的 render 函数也是纯函数
- 函数式编程中的不可变性原则
总的来说,这是一个经过深思熟虑的设计决策,它使得:
- 代码更容易测试和维护
- 重试机制更可靠
- 并行处理更安全
- 数据流更清晰
- 副作用更可控