PocketFlow百行源码解析

我将逐行解析 pocketflow/__init__.py 文件的内容。这是一个轻量级的 LLM 工作流框架的核心实现。

import asyncio, warnings, copy, time

导入必要的 Python 标准库:

  • asyncio: 用于异步编程
  • warnings: 用于发出警告信息
  • copy: 用于对象的复制
  • time: 用于时间相关操作
class BaseNode:
    def __init__(self): self.params,self.successors={},{}
    def set_params(self,params): self.params=params

定义基础节点类 BaseNode

  • __init__: 初始化两个字典 – params(参数)和successors(后继节点)
  • set_params: 设置节点参数
    def next(self,node,action="default"):
        if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
        self.successors[action]=node; return node

添加后继节点的方法:

  • 如果该动作已有后继节点,发出警告
  • 将新节点添加到后继节点字典中
  • 返回新节点以支持链式调用
    def prep(self,shared): pass
    def exec(self,prep_res): pass
    def post(self,shared,prep_res,exec_res): pass

定义三个核心方法:

  • prep: 准备数据
  • exec: 执行核心逻辑
  • post: 处理结果
    这些方法在基类中是空实现,由子类重写
    def _exec(self,prep_res): return self.exec(prep_res)
    def _run(self,shared): p=self.prep(shared); e=self._exec(p); return self.post(shared,p,e)
    def run(self,shared): 
        if self.successors: warnings.warn("Node won't run successors. Use Flow.")  
        return self._run(shared)

内部运行机制:

  • _exec: 执行包装器
  • _run: 执行完整的 prep->exec->post 流程
  • run: 公开的运行方法,提醒用户单个节点不会运行后继节点
    def __rshift__(self,other): return self.next(other)
    def __sub__(self,action):
        if isinstance(action,str): return _ConditionalTransition(self,action)
        raise TypeError("Action must be a string")

运算符重载:

  • >>: 用于添加默认后继节点
  • -: 用于创建条件转换
class _ConditionalTransition:
    def __init__(self,src,action): self.src,self.action=src,action
    def __rshift__(self,tgt): return self.src.next(tgt,self.action)

条件转换辅助类:

  • 存储源节点和动作
  • 支持使用 >> 添加目标节点
class Node(BaseNode):
    def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait
    def exec_fallback(self,prep_res,exc): raise exc
    def _exec(self,prep_res):
        for self.cur_retry in range(self.max_retries):
            try: return self.exec(prep_res)
            except Exception as e:
                if self.cur_retry==self.max_retries-1: return self.exec_fallback(prep_res,e)
                if self.wait>0: time.sleep(self.wait)

实际节点类:

  • 添加重试机制
  • 支持失败回退
  • 支持重试等待时间
class BatchNode(Node):
    def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in (items or [])]

批处理节点:

  • 对输入列表中的每个项目执行处理
  • 返回结果列表
class Flow(BaseNode):
    def __init__(self,start=None): super().__init__(); self.start_node=start
    def start(self,start): self.start_node=start; return start
    def get_next_node(self,curr,action):
        nxt=curr.successors.get(action or "default")
        if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
        return nxt
    def _orch(self,shared,params=None):
        curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
        while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
        return last_action
    def _run(self,shared): p=self.prep(shared); o=self._orch(shared); return self.post(shared,p,o)
    def post(self,shared,prep_res,exec_res): return exec_res

工作流类:

  • 管理节点之间的流转
  • 支持参数传递
  • 处理节点执行结果
  • 实现节点链的遍历执行
class BatchFlow(Flow):
    def _run(self,shared):
        pr=self.prep(shared) or []
        for bp in pr: self._orch(shared,{**self.params,**bp})
        return self.post(shared,pr,None)

批处理工作流:

  • 对多组参数执行同一工作流
  • 合并参数字典
class AsyncNode(Node):
    async def prep_async(self,shared): pass
    async def exec_async(self,prep_res): pass
    async def exec_fallback_async(self,prep_res,exc): raise exc
    async def post_async(self,shared,prep_res,exec_res): pass
    async def _exec(self,prep_res): 
        for i in range(self.max_retries):
            try: return await self.exec_async(prep_res)
            except Exception as e:
                if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
                if self.wait>0: await asyncio.sleep(self.wait)
    async def run_async(self,shared): 
        if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")  
        return await self._run_async(shared)
    async def _run_async(self,shared): p=await self.prep_async(shared); e=await self._exec(p); return await self.post_async(shared,p,e)
    def _run(self,shared): raise RuntimeError("Use run_async.")

异步节点:

  • 提供异步版本的核心方法
  • 支持异步重试和回退
  • 强制使用异步接口
class AsyncBatchNode(AsyncNode,BatchNode):
    async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]

异步批处理节点:

  • 结合异步和批处理功能
  • 顺序处理批量项目
class AsyncParallelBatchNode(AsyncNode,BatchNode):
    async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))

并行异步批处理节点:

  • 并行处理批量项目
  • 使用 asyncio.gather 实现并发
class AsyncFlow(Flow,AsyncNode):
    async def _orch_async(self,shared,params=None):
        curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
        while curr: curr.set_params(p); last_action=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
        return last_action
    async def _run_async(self,shared): p=await self.prep_async(shared); o=await self._orch_async(shared); return await self.post_async(shared,p,o)
    async def post_async(self,shared,prep_res,exec_res): return exec_res

异步工作流:

  • 支持混合同步和异步节点
  • 提供异步工作流执行
class AsyncBatchFlow(AsyncFlow,BatchFlow):
    async def _run_async(self,shared):
        pr=await self.prep_async(shared) or []
        for bp in pr: await self._orch_async(shared,{**self.params,**bp})
        return await self.post_async(shared,pr,None)

异步批处理工作流:

  • 顺序执行多组参数的异步工作流
class AsyncParallelBatchFlow(AsyncFlow,BatchFlow):
    async def _run_async(self,shared): 
        pr=await self.prep_async(shared) or []
        await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr))
        return await self.post_async(shared,pr,None)

并行异步批处理工作流:

  • 并行执行多组参数的异步工作流
  • 最大化并发性能

这个框架通过组合基本的节点和工作流概念,提供了一个灵活的系统来构建 LLM 应用。它支持同步/异步操作、批处理、并行处理等多种模式,同时保持了代码的简洁性和可扩展性。



让我解释一下为什么 exec() 方法不需要 shared 参数:

  • 职责分离原则
  • prep: 负责从共享存储中读取和准备数据
  • exec: 专注于纯计算/处理逻辑
  • post: 负责将结果写回共享存储并决定下一步动作
  • 纯函数设计
  • exec() 被设计成一个”纯函数”,它只处理通过 prep() 准备好的数据
  • 这样的设计使得 exec() 更容易测试、调试和重试
  • 不直接访问共享状态可以避免副作用,提高代码的可预测性
  • 数据流设计
graph LR
    prep["prep(shared)"] --> |prep_res| exec["exec(prep_res)"]
    exec --> |exec_res| post["post(shared, prep_res, exec_res)"]
  • prepshared 读取数据并返回 prep_res
  • exec 接收 prep_res 并返回 exec_res
  • post 接收所有上下文并更新 shared
  • 重试机制考虑
def _exec(self,prep_res):
    for self.cur_retry in range(self.max_retries):
        try: return self.exec(prep_res)
        except Exception as e:
            if self.cur_retry==self.max_retries-1: return self.exec_fallback(prep_res,e)
            if self.wait>0: time.sleep(self.wait)
  • 当执行失败需要重试时,如果 exec() 直接访问 shared,可能会导致状态不一致
  • 通过参数传递的方式,每次重试都使用相同的输入,保证了重试的一致性
  • 并行处理考虑
class AsyncParallelBatchNode(AsyncNode,BatchNode):
    async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
  • 在并行处理场景下,如果 exec() 直接访问共享状态,需要处理复杂的并发控制
  • 通过参数传递的方式,每个并行任务都有自己的独立数据副本,避免了并发问题

这种设计模式在很多框架中都很常见,例如:

  • MapReduce 中的 map 函数就是纯函数
  • React 中的 render 函数也是纯函数
  • 函数式编程中的不可变性原则

总的来说,这是一个经过深思熟虑的设计决策,它使得:

  1. 代码更容易测试和维护
  2. 重试机制更可靠
  3. 并行处理更安全
  4. 数据流更清晰
  5. 副作用更可控

发表评论

人生梦想 - 关注前沿的计算机技术 acejoy.com 🐾 步子哥の博客 🐾 背多分论坛 🐾 知差(chai)网 🐾 DeepracticeX 社区 🐾 老薛主机 🐾