关心训练加速,以及这个领域里pytorch做的怎么样?

  • 有哪些组件(mlir, lazy tensor core, torch script, jit, fx)

  • 这些组件间关系是什么(是替代品?是上下游关系?)

基本概念

  • 计算图中间表示 (Intermediate Representation):简称为 IR,程序编译过程中,源代码与目标代码之间翻译的中介,IR的设计对编译器来说非常关键,好的IR要考虑从源代码到目标代码编译的完备性、编译优化的易用性和性能
image-20220802112036905
  • 张量计算图:节点代表张量,边代表函数关系,DAG图。叶节点(源),输出节点(汇),计算图是动态的,通过可微分张量计算过程自动生成

  • LLVM:

  • JIT:内生于PyTorch框架的一种DSL和Compiler栈的集合,目标是为了PyTorch的使用者也可以拥有便携、高性能执行模型推理的方法。pytorch本身是一个eager模式设计的深度学习框架,易于debug和观察,但是不利于性能的解耦和优化。PyTorch JIT模式就是一种将原本eager模式的表达转变并固定为一个计算图,便于进行优化和序列化。(向TF靠拢)

  • torchscript:Pytorch中的IR,JIT模式的具体形式,是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。

  • Eager 模式:Python + Python runtime。这种模式是更 Pythonic 的编程模式,可以让用户很方便的使用 python 的语法来使用并调试框架。

    • 易用性、表达性和可调试性是 PyTorch 的核心原则。易用性的关键驱动因素之一是 PyTorch 执行默认是Eager,即逐个操作执行保留了程序的命令性。但是,Eager Execution 不提供基于编译器的优化,例如,当计算可以表示为图形时的优化。
  • Script 模式:TorchScript + PyTorch JIT。这种模式会对 eager 模式的模型创建一个中间表示(intermediate representation,IR),这个 IR 经过内部优化的,并且可以使用 PyTorch JIT 编译器去运行模型,不再依赖 python runtime,也可以使用 C++ 加载运行。

模型分析

之前在UP-mb15上运行使用trace、script、summary的记录:http://go.d5.sensetime.com/sunye/202205,给出分析模型的对比表:

获取方式 str summary trace script
描述 中序遍历模型结构,从根节点到叶节点,每个节点是一个nn.Module 输入一个tensor,对运行过程中实际走过的nn.Module执行hook函数 需要输入来让模型 forward 一遍,以通过该输入的流转路径,获得,从而得到原本模型 forward 的 IR。不适合适合解决带有if和for-loop的情况 将一个Python代码通过语法分析编译为TorchScript代码(IR),继而导出相应的可优化可部署模型。只支持python代码的子集
关键函数(调用栈) print(model) nn.Moule._str_ model.modules() summary(model, input_size=input_size, device=‘cuda’) model.apply(register_hook) graph(model, *input_to_model, verbose)
torch.jit.trace(model, *args)
torch._C._create_function_from_trace(…)
torch.jit.script(model)
数量 1616 Module 1557 Module 7517 node
输出 模型树形结构,包含模型参数,如shape, kernel_size, stride, padding, bias 线性结构,记录实际运行过程每个module的属性:input_shape, output_shape, trainable, nb_params 部分计算图(及其参数),onnx 完整计算图(及其参数),可部署模型
文件大小 93KB 165KB 仅计算图结构:1.4M
计算图参数+权重:4.2G
4.2G
耗时 0.14s 7.55s (较慢) —(运行失败)

trace与script调用栈

下面以缩进的方式梳理一下trace与script的调用流程,从上到下为顺序执行,缩进为进一步调用。仅保留最关键的部分

trace

把训练后得到的 eager 模型以及模型需要的输入数据作为接口输入,然后 tracer 会把数据在 eager 模型里运行一次。trace只记录走过的tensor和对tensor的操作,不会记录任何控制流信息,如if条件句和循环。因为没有记录控制流的另外的路,也没办法对其进行优化。好处是trace深度嵌入python语言,复用了所有python的语法,在计算流中记录数据流。

torch.jit.trace(model, *args) 	# 将model作为参数func传入
	if isinstance(func, torch.nn.Module): trace_module(func.(__self__), {"forward": example_inputs},... )	# 如果func是nn.Module或者forward函数,直接追踪forward
	torch._C._create_function_from_trace(name, func, example_inputs, ...)	# C++入口,func作为参数traced_fn传入
    	auto state = std::make_shared<TracingState>();	# state这个数据结构会在forward过程中存储trace到的计算过程
        auto out_stack = traced_fn(inputs);	 # 开始forward,在计算发生时,会把计算记录到state中
        	createOperatorFromC10_withTracingHandledHere(op, input_size, output_size)	# 具体记录 operation 的过程
            	tracer_state = jit::tracer::getTracingState();	# 获取 tracer_state
                auto symbol = Symbol::fromQualString(op.schema().name());	# 得到计算类型(比如相加)
                # 创建一个计算节点,然后创建计算输入,最后把计算节点insert 到 graph 中,完成一次对计算的记录
                node = graph->create(symbol, 0);
                Value* none = graph->insertNode(graph->createNone())->output();
                node->addInput(none);
                graph->insertNode(node);

script

script会去理解所有的code,真正像一个编译器一样去进行lexer、parser、Semantic analusis的分析「也就是词法分析、语法分析、句法分析,形成AST树,最后再将AST树线性化」。

  • script相当于一个嵌入在Python/Pytorch的DSL,其语法只是pytorch语法的子集,这意味着存在一些op和语法script不支持,编译时会遇到问题。
  • 会省略常量节点,并需要类型转换,如果没有类型提供则默认是 Tensor 类型。
  • script的编译优化方式更像是CPU上的传统编译优化,重点对于图进行硬件无关优化,并对IF、loop这样的statement进行优化。
torch.jit.script(model)		# 将model作为参数obj传入
	ast = get_jit_def(obj, obj.__name__)	# 得到ast语法树
    	dedent_src = dedent(source)		# dedent_src 为包含了要script函数的字符串
    	py_ast = ast.parse(dedent_src)	# 调用python ast包将字符串解析为Python的ast
        ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)	# ctx中包含了函数所有原信息
    	fn_def = py_ast.body[0]		# Python 解析出的 AST
        return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)	# 将python 的ast 转化为torchjit 使用的ast格式
    		return Def(Ident(r, def_name), decl, build_stmts(ctx, body))	# ctx 包含 source code 所有信息, body是 Python ast 解析结果, 那么build_stmts是关键步骤
    			build_stmts(ctx, body)	# 这里将body作为参数stmts传入
        			return [build_stmt(ctx, s) for s in stmts]
        				build_stmt = StmtBuilder()	# `build_stmt` 是一个StmtBuilder()的instance
            			def __call__(self, ctx, node)	# 根据解析出的ast的类型返回相应的build方法,从截图可以看到`a+2`是一个`Assign`类型,因此会调用build_Assign
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))	# c++ 入口,根据ast得到ir,将ast作为参数def传入,跟着def找调用流程
 		auto cu = get_python_cu();
		auto defined_functions = cu->define(QualifiedName(name.prefix()),..., {def}, {pythonResolver(rcb)}, ...)	# 实现在torch/csrc/jit/frontend/ir_emitter.cpp
        to_ir(def, _resolver, self, method);	# 跟随 def,找到了一个转化为 IR 的关键的struct to_ir ,其输入中有 def,也就是 ast,_resolver 是 Python 中传过来的解析名字的函数,我们可以在内部找到关键部分
        	method.setSchema(emitDef(def, self, graph->block()));
            	auto stmts_list = def.statements();
            	emitStatements(stmts_list.begin(), stmts_list.end()); 	# emitDef 中会调用emitStatements
                	# 根据stmt.kind(),会进入而各种emit里面,其中一定可以找到graph->insertNode(graph->create(.....)); 类似的操作,对应我们建立IR graph
                    switch (stmt.kind()) {case TK_IF: emitIf(If(stmt));...}		

ast.parse样例

results = ast.parse("""def test(a):
     a = a + 2
     return a + 1""")

image-20220802153344606

  • ast.body 是一个 list,其长度等于解析的 string 中包含的函数的个数

    • 元素0为FunctionDef,具体为一个Add,left 是Name类型,id a,right是Num,也就是2,这个Binop即解析的a = a + 2
    • 元素1为Return,其中 value 是一个Binop
  • 因为dedent返回的一定是一个 single top-level function, 因此直接取用第 0 个元素,即 py_ast.body[0] 就可以了。

trace与script混用

  • script存在一些op和语法不支持,不会记录常量节点
  • trace不会记录任何控制流信息,如if条件句和循环

为了避免上述问题,将两者同时使用:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

官方示例:

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

Graph Executor

进行IR的优化和pass变换,作用:

  1. 优化Graph结构
  2. 维护DCG和反向图
  3. 产生类汇编代码并压栈

LazyTensor

LazyTensor 是 PyTorch 中一个全新的跟踪系统。它包括其他跟踪系统 (jit.trace) 不提供的安全保证,因为如果输入的属性发生更改,它会回溯并重新编译,否则会使用缓存的计算。比jit好用得多!Lazy Tensor 跟踪前向和后向传递,并删除了 jit 脚本和跟踪图中存在的许多硬件供应商难以支持的 Python 特性。解决Script和Trace在可用性方面的限制。LazyTensor有些像是一种被动的,adaptive的trace机制

使用LT能够做什么?

MLIR

https://github.com/llvm/torch-mlir

https://github.com/llvm/torch-mlir/blob/main/docs/ltc_examples.md

MLIR发布会:https://www.youtube.com/watch?v=qzljG6DKgic

image-20220802115306515

image-20220802115455711

image-20220802115602193

image-20220802115707050

image-20220802115731197