Ray DAG 源码解读

什么是DAG

Ray DAG(Directed Acyclic Graph)是 Ray 计算框架中用于表示任务依赖关系的有向无环图结构。它定义了一组计算任务(Tasks 或 Actors)及其依赖关系,以 DAG 形式组织任务的执行,简称计算图。

DAG的作用有:

  1. 延迟计算,与remote不同,bind方法仅会构建计算图,而不是立即执行。可以打包复杂的调用关系,然后直接执行图,DAG可以重复执行;
  2. 是实现workflow的核心组件;
  3. 是实现Accelerate DAG的核心组件。

DAG的使用方法

使用方法参考官方手册,DAG节点可以包含Class,ClassMethod,Function,Input, Output等类型的节点。

更直观的方式是将DAG的可视化,当前支持生成图片或者生成ascii格式的DAG示意图,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import ray
import time
import random
from ray.dag.input_node import InputNode
from ray.dag.output_node import MultiOutputNode
ray.init(address = "10.218.163.33:6274")
import os
@ray.remote
class Actor:
def __init__(self, init_value, fail_after=None, sys_exit=False):
self.i = init_value
self.fail_after = fail_after
self.sys_exit = sys_exit

self.count = 0

def echo(self, x):
self.count += 1
return x

def sleep(self, x):
time.sleep(x)
return x

@ray.method(num_returns=2)
def return_two(self, x):
return x, x + 1

def read_input(self, x):
return x


a = Actor.remote(0)
b = Actor.remote(0)
single_fetch = True
with InputNode() as i:
o1, o2 = a.return_two.bind(i)
o3 = b.echo.bind(o1)
o4 = b.echo.bind(o2)
dag = MultiOutputNode([o3, o4])

compiled_dag = dag.experimental_compile()
for _ in range(3):
refs = compiled_dag.execute(1)
if single_fetch:
for i, ref in enumerate(refs):
res = ray.get(ref)
assert res == i + 1
else:
res = ray.get(refs)
assert res == [1, 2]
compiled_dag.visualize()
compiled_dag.teardown()
image-20250210140344597

DAGNode

DAG是由DAGNode组成,每个DAGNode代表一个Task或者Actor,图的边代表数据的依赖关系,由被依赖者指向依赖者。整个图从上到下分别是InputNode,各个任务Node以及OutputNode。

PlantUML diagram

DAGNode

DAGNode是对操作和其参数的一个封装,其包含实际执行的逻辑,例如,针对FunctionNode而言,就是函数对象,以及该函数的所有入参。DAGNode还会记录当前节点的上下游信息,并且能够按依赖顺序执行每一个Node。

PyObjScanner

DAGNode中有多个函数均使用到了PyObjScanner,这是利用pickle库来扫描给定对象中所有的DAGNodeBase类型的对象。比如,在获取节点的上游节点(所有依赖的节点)collect_upstream_nodes以及扫描所有children节点get_all_child_nodes时,就是通过扫描所有入参,来确定在执行当前Node之前,有哪些Node需要先完成计算。

apply_recursive

DAGNode执行的核心是这个递归执行函数,该函数从叶子节点开始,根据依赖关系递归的执行。apply_recursive提供了一个Cachingfn的内部类,该类能够缓存当前节点执行结果(也有可能是一个future对象)而避免节点被重复计算。当Cachingfn对象创建后,会将原来的执行函数替换成自己。

  • 如果apply_recursive不是第一次被调用,则直接返回cache中缓存的结果。

  • 如果是第一次被调用,首先根据传入的执行函数fn来生成Cachingfn对象,然后将fn函数替换成Cachingfn对象,该对象提供__call__方法,该方法是一个warpper,调用原始回调函数fn,并且记录缓存,并返回结果。

以FunctionNode为例,apply_recursive的回调函数是将FunctionNode中记录的函数构造成remote对象,然后直接调用remote方法。所以该回调函数的参数需要是一个数据对象,例如:标准的数据类型,或者是一个object_ref。但是DAGNode的arg是另外一个DAGNode,并不能直接被回调函数所执行,所以,在执行当前节点之前,需要对入参进行替换。

apply_and_replace_all_child_nodes

该函数就提供了替换DAGNode入参类型的功能。该函数接受一个回调函数作为参数,该回调函数即对为Node的apply_recursive替换也就是将当前的DAGNode,通过apply_recursive替换成当前节点执行完成后的结果。而这个apply_recursive函数会递归的处理子节点的所有子节点。

最终,会递归到InputNode,获取到InputNode的值,然后交给InputNode的后续节点计算,然后再将计算结果交个在后续的节点,直到执行完毕,执行完成后会返回一个object_ref,可以使用get方法获取实际的计算结果。

PR点:get_toplevel_child_nodes没有引用了,可以删除

FunctionNode

FunctionNode是DAGNode的子类,代表一个函数节点,该节点通过remote函数的.bind方法生成,除了基类的属性之外,FunctionNode还需要保存函数体。

execute_impl即调用该函数的remote方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
def _execute_impl(self, *args, **kwargs):
"""Executor of FunctionNode by ray.remote().

Args and kwargs are to match base class signature, but not in the
implementation. All args and kwargs should be resolved and replaced
with value in bound_args and bound_kwargs via bottom-up recursion when
current node is executed.
"""
return (
ray.remote(self._body)
.options(**self._bound_options)
.remote(*self._bound_args, **self._bound_kwargs)
)

ClassNode和ClassMethodNode

ClassNode是DAGNode的子类,代表一个类,该节点通过remote类的.bind方法生成,该Node中保存了Class本身。ClassNode的execute_impl函数仅执行了remote方法,将对象在集群中实例化。

当使用该Node对应类的函数时,需要对这些函数在进行一次bind。ClassNode通过__getattr__获得其中Class的函数,该函数也提供了.bind方法,调用可以生成ClassMethodNode对象。在对ClassMethod进行bind是,还会记录actor的句柄。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class ClassNode:
def __getattr__(self, method_name: str):
# User trying to call .bind() without a bind class method
if method_name == "bind" and "bind" not in dir(self._body):
raise AttributeError(f".bind() cannot be used again on {type(self)} ")
# Raise an error if the method is invalid.
getattr(self._body, method_name)
call_node = _UnboundClassMethodNode(self, method_name, {})
return call_node

class _UnboundClassMethodNode:
def bind(self, *args, **kwargs):
other_args_to_resolve = {
PARENT_CLASS_NODE_KEY: self._actor,
PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call,
}

node = ClassMethodNode(
self._method_name,
args,
kwargs,
self._options,
other_args_to_resolve=other_args_to_resolve,
)
self._actor._last_call = node
return node

ClassMethodNode也是DAGNode的子类,代表一个类的成员函数,分为class_method_call以及class_method_output两种类型,class_method_output类型是一个特殊的Node,该node仅作为存储ClassMethodNode的结果使用。ClassMethodNode的bind函数返回的就是一个class_method_output类型的tuple。class_method_output类型的Node执行即为返回其中实际存储的CLassMethod的返回结果。简单的说,就是将一个函数返回的tuple做了一个封装。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def _execute_impl(self, *args, **kwargs):
"""Executor of ClassMethodNode by ray.remote()

Args and kwargs are to match base class signature, but not in the
implementation. All args and kwargs should be resolved and replaced
with value in bound_args and bound_kwargs via bottom-up recursion when
current node is executed.
"""
if self.is_class_method_call:
method_body = getattr(self._parent_class_node, self._method_name)
# Execute with bound args.
return method_body.options(**self._bound_options).remote(
*self._bound_args,
**self._bound_kwargs,
)
else:
assert self._class_method_output is not None
return self._bound_args[0][self._class_method_output.output_idx]

PR点:其实这个output类型的Node仅作为数据的存储,和ClassMethodNode混在一起比较难理解,不如搞一个新的Node类型

InputNode和InputAttributeNode

InputNode是DAGNode的子类,代表一组输入,InputAttributeNode也是DAGNode的子类,代表一组输入中的一个元素,可以通过下标或者.操作符来获取。

1
2
with ray.dag.InputNode as inp:
x = fun.bind(inp)

上述代码是使用整个input作为fun函数的入参,这个参数就是InputNode;

1
2
3
with ray.dag.InputNode as inp:
x = fun1.bind(inp[0])
y = fun2.bind(inp.x)

这段代码中,fun1和fun2使用了input中的部分数据,这里的参数就是InputAttributeNode。

InputNode的execute_impl比较简单,就是返回args,InputAttributeNode内部有InputNode的引用,通过__get_attr__以及__get_item__函数来获取InputNode中的arg。

MultiOutputNode

MultiOutputNode是DAGNode的子类,可以储存多个输出。比如,我们想执行一个计算图,但是希望获取超过1个的结果,这时就需要构建一个MultiOutputNode。

MultiOutputNode的execute_impl就是将入参转成一个list返回。

CollectiveOutputNode

CollectiveOutputNode是ClassMethodNode的子类,这个Node不能被执行,所以仅用于计算图编译中。该类使用other_args_to_resolve传入一个_CollectiveOperation。

_CollectiveOperation目前仅支持allReduce,该类提供了nccl_group的初始化功能,将all_reduce的节点放到一个group中。其execute方法就是执行group的all_reduce方法

1
2
3
4
5
6
7
8
9
type_hint = self._type_hint
if type_hint.communicator_id is not None:
return type_hint.communicator_id
if communicator_id is None:
communicator_id = _init_communicator(
self._actor_handles, type_hint.get_custom_communicator()
)
type_hint.set_communicator_id(communicator_id)
return communicator_id
1
2
3
4
5
6
7
8
import torch

if not isinstance(send_buf, torch.Tensor):
raise ValueError("Expected a torch tensor")
communicator = self.get_communicator()
recv_buf = torch.empty_like(send_buf)
communicator.allreduce(send_buf, recv_buf, self._op)
return recv_buf