Ray DAG 源码解读
什么是DAG
Ray DAG(Directed Acyclic Graph)是 Ray 计算框架中用于表示任务依赖关系的有向无环图结构。它定义了一组计算任务(Tasks 或 Actors)及其依赖关系,以 DAG 形式组织任务的执行,简称计算图。
DAG的作用有:
- 延迟计算,与remote不同,bind方法仅会构建计算图,而不是立即执行。可以打包复杂的调用关系,然后直接执行图,DAG可以重复执行;
- 是实现workflow的核心组件;
- 是实现Accelerate DAG的核心组件。
DAG的使用方法
使用方法参考官方手册,DAG节点可以包含Class,ClassMethod,Function,Input, Output等类型的节点。
更直观的方式是将DAG的可视化,当前支持生成图片或者生成ascii格式的DAG示意图,如下所示:
1 | import ray |

DAGNode
DAG是由DAGNode组成,每个DAGNode代表一个Task或者Actor,图的边代表数据的依赖关系,由被依赖者指向依赖者。整个图从上到下分别是InputNode,各个任务Node以及OutputNode。
DAGNode
DAGNode是对操作和其参数的一个封装,其包含实际执行的逻辑,例如,针对FunctionNode而言,就是函数对象,以及该函数的所有入参。DAGNode还会记录当前节点的上下游信息,并且能够按依赖顺序执行每一个Node。
PyObjScanner
DAGNode中有多个函数均使用到了PyObjScanner,这是利用pickle库来扫描给定对象中所有的DAGNodeBase类型的对象。比如,在获取节点的上游节点(所有依赖的节点)collect_upstream_nodes
以及扫描所有children节点get_all_child_nodes
时,就是通过扫描所有入参,来确定在执行当前Node之前,有哪些Node需要先完成计算。
apply_recursive
DAGNode执行的核心是这个递归执行函数,该函数从叶子节点开始,根据依赖关系递归的执行。apply_recursive
提供了一个Cachingfn
的内部类,该类能够缓存当前节点执行结果(也有可能是一个future对象)而避免节点被重复计算。当Cachingfn
对象创建后,会将原来的执行函数替换成自己。
如果
apply_recursive
不是第一次被调用,则直接返回cache中缓存的结果。如果是第一次被调用,首先根据传入的执行函数fn来生成
Cachingfn
对象,然后将fn函数替换成Cachingfn
对象,该对象提供__call__
方法,该方法是一个warpper,调用原始回调函数fn,并且记录缓存,并返回结果。
以FunctionNode为例,apply_recursive
的回调函数是将FunctionNode中记录的函数构造成remote对象,然后直接调用remote方法。所以该回调函数的参数需要是一个数据对象,例如:标准的数据类型,或者是一个object_ref。但是DAGNode的arg是另外一个DAGNode,并不能直接被回调函数所执行,所以,在执行当前节点之前,需要对入参进行替换。
apply_and_replace_all_child_nodes
该函数就提供了替换DAGNode入参类型的功能。该函数接受一个回调函数作为参数,该回调函数即对为Node的apply_recursive
替换也就是将当前的DAGNode,通过apply_recursive
替换成当前节点执行完成后的结果。而这个apply_recursive
函数会递归的处理子节点的所有子节点。
最终,会递归到InputNode,获取到InputNode的值,然后交给InputNode的后续节点计算,然后再将计算结果交个在后续的节点,直到执行完毕,执行完成后会返回一个object_ref,可以使用get方法获取实际的计算结果。
PR点:get_toplevel_child_nodes没有引用了,可以删除
FunctionNode
FunctionNode是DAGNode的子类,代表一个函数节点,该节点通过remote函数的.bind
方法生成,除了基类的属性之外,FunctionNode还需要保存函数体。
execute_impl
即调用该函数的remote方法:
1 | def _execute_impl(self, *args, **kwargs): |
ClassNode和ClassMethodNode
ClassNode是DAGNode的子类,代表一个类,该节点通过remote类的.bind
方法生成,该Node中保存了Class本身。ClassNode的execute_impl
函数仅执行了remote方法,将对象在集群中实例化。
当使用该Node对应类的函数时,需要对这些函数在进行一次bind
。ClassNode通过__getattr__
获得其中Class的函数,该函数也提供了.bind
方法,调用可以生成ClassMethodNode对象。在对ClassMethod进行bind是,还会记录actor的句柄。
1 | class ClassNode: |
ClassMethodNode也是DAGNode的子类,代表一个类的成员函数,分为class_method_call以及class_method_output两种类型,class_method_output类型是一个特殊的Node,该node仅作为存储ClassMethodNode的结果使用。ClassMethodNode的bind函数返回的就是一个class_method_output类型的tuple。class_method_output类型的Node执行即为返回其中实际存储的CLassMethod的返回结果。简单的说,就是将一个函数返回的tuple做了一个封装。
1 | def _execute_impl(self, *args, **kwargs): |
PR点:其实这个output类型的Node仅作为数据的存储,和ClassMethodNode混在一起比较难理解,不如搞一个新的Node类型
InputNode和InputAttributeNode
InputNode是DAGNode的子类,代表一组输入,InputAttributeNode也是DAGNode的子类,代表一组输入中的一个元素,可以通过下标或者.
操作符来获取。
1 | with ray.dag.InputNode as inp: |
上述代码是使用整个input作为fun函数的入参,这个参数就是InputNode;
1 | with ray.dag.InputNode as inp: |
这段代码中,fun1和fun2使用了input中的部分数据,这里的参数就是InputAttributeNode。
InputNode的execute_impl
比较简单,就是返回args,InputAttributeNode内部有InputNode的引用,通过__get_attr__
以及__get_item__
函数来获取InputNode中的arg。
MultiOutputNode
MultiOutputNode是DAGNode的子类,可以储存多个输出。比如,我们想执行一个计算图,但是希望获取超过1个的结果,这时就需要构建一个MultiOutputNode。
MultiOutputNode的execute_impl
就是将入参转成一个list返回。
CollectiveOutputNode
CollectiveOutputNode是ClassMethodNode的子类,这个Node不能被执行,所以仅用于计算图编译中。该类使用other_args_to_resolve传入一个_CollectiveOperation。
_CollectiveOperation目前仅支持allReduce,该类提供了nccl_group的初始化功能,将all_reduce的节点放到一个group中。其execute
方法就是执行group的all_reduce方法
1 | type_hint = self._type_hint |
1 | import torch |