作者:安平博,Xilinx高级工程师;来源:AI加速微信公众号
什么是pass?
Pass是TVM中基于relay IR进行的优化,目的是去除冗余算子,进行硬件友好的算子转换,最终能够提高硬件运行效率。由tensorflow等深度学习框架生成的图机构中,含有很多可以优化的算子,比如expand_dim,len等,其实在编译阶段完全可以优化掉,从而能够减少硬件的计算,以及避免出现硬件不支持的算子。
TVM中在include/tvm/ir/transform.h中对pass进行了抽象,主要包括PassContext,PassInfo,Pass,以及Sequential。其中PassContext包含了pass执行依赖的一些参数,比如优化level,analysis report等。PassInfo是一个用于记录pass信息的类,包括pass的opt-level,名称等。和PassContext的区别是PassContext是pass执行所需要获取的条件。Pass就是执行pass的主体,主要就是pass的函数。比如RemoveUnusedFunctions就是执行pass的一个主体函数,目的就是去除冗余算子。Sequential是一个container,装载所有pass。
一些pass
01. RemoveUnusedFunctions
位于src/relay/backend/vm/removed_unused_funcs.cc中,顾名思义就是去除relay IR中的冗余函数。通过从main函数开始遍历,如果一个函数体没有引用其它函数,而同时又没有被其它函数调用,即从relay图上看是一个孤立算子,那么就从IRModule中删除。
void VisitExpr_(const FunctionNode* func_node) final { auto func = GetRef<Function>(func_node); if (visiting_.find(func) == visiting_.end()) { visiting_.insert(func); for (auto param : func_node->params) { ExprVisitor::VisitExpr(param); } ExprVisitor::VisitExpr(func_node-> body); } }
02. ToBasicBlockNormalForm
函数在文件src/relay/trnaforms/to_basic_block_normal_from.cc中。通过遍历IRModule中的每个function,将每个function转换为基本块形式。转换函数是ToBasicBlockNormalFormAux。这个函数包括两个步骤:一是找到基本块(basic block)的边界,TVM中对边界进行了一步抽象,判断每个expr是否属于同一个scope,如果scope相同那么就可以将这些表达式放在一个基本块中;第二步根据每个表达式所属的scope将表达式归属到一个基本块中。
Expr ToBasicBlockNormalFormAux(const Expr& e) { // calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e); /* The scope of the whole expr is global. * The scope of any subexpr, is the lowest common ancestor of all incoming edge. * We also record the set of expressions whose scope is lifted. */ std::pairscopes = CalcScope(dg); return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); }
DependencyGraph是一个表达式相互依赖的图结构,通过遍历图中每个节点,找到每个节点的scope。CalcScope在文件src/relay/transforms/to_a_normal_from.cc中。这个函数中重点关注以下代码:
… s = LCA(s, expr_scope.at(iit->value)); … if (n->new_scope) { auto child_scope = std::make_shared<ScopeNode>(s); expr_scope.insert({n, child_scope}); } else { expr_scope.insert({n, s}); }
LCA是获得当前节点的父节点的scope的LCA(least common ancestor),然后将这个scope作为这个节点的scope。了解基本块原理的都知道,寻找基本块首先要找到首指令的位置,然后一个首指令到下一个首指令之间的指令就属于一个基本块。而首指令就是那些具有条件和无条件跳转的指令。在TVM中通过new_scope来标记这些节点,比如Ifnode,FunctionNode,LetNode在建立dependency图的时候,这些节点就被标记为new_scope。这样就建立了dependency节点到scope节点的对应map。同时scope节点也被建立起树结构。
接下来就是建立Fill类,这个类中包含了dependency图以及scope的信息,通过其函数ToBasicBlockNormalForm实现基本块转换。它的基本逻辑通过VisitExpr函数遍历dependency节点,将具有相同scope的节点压入到同一个let_list中。Let_list文档中是这样解释的:
/*! * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, * and pass them around freely without fear of AST explosion (or effect duplication). * for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'. * if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);', * the AST will contain 2 'a', as b and c are now variables.
Let_list使得抽象语法树简洁化,不会因为变量的复制导致树的爆炸。具有相同的scope的expr被约束到相同的let_list中,用一个var来表达,这样就将表达式转化为var的形式。一个var也就对应了一个基本块。
03. Legalize
Legalize是实现等价函数的转换。主要代码在src/relay/transforms/legalize.cc中。主函数是:
Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { auto rewriter = Legalizer(legalize_map_attr_name); return PostOrderRewrite(expr, &rewriter); }
在legalize.cc文件中定义了一个继承了ExprRewriter的类,在这个类中实现了对function的替换。我们追踪一下调用的过程。PostOrderRewrite在文件src/relay/ir/expr_functor.cc中。首先建立一个PostOrderRewriter类,然后访问每个节点。在访问节点过程中调用了ExpandDataFlow函数,看一下这个函数的描述:
* * ExpandDataflow manually manages a stack and performs DFS to determine the processing * order of nodes in an input graph. * * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack * and continues iteratively to process the top of the stack. When it finds a node that doesn't * match the dataflow types, or a node who's inputs have all been processed, it visits the current * leaf via fvisit_leaf. * * This function should be used internally to other classes to implement mixed-mode traversals. The * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it * hits a non-dataflow node. * * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining. */
主要目的是有区别的去处理graph中的节点,如果fcheck_visited已经确定该节点处理过或者不需要处理,就跳过,通过fvisit_leaf继续访问下一个节点。而在VisitLeaf函数中就调用了legalizer类中的rewrite_函数实现了legalize功能。在Rewrite_中,通过映射表legalize_map_attr_name实现函数的等价转换。
04. SimplifyInference
实现对batch normalization, layer normalization, instance normalization, group normalization, L2 normalization算子的分解,这样做的目的是可以在之后的优化中,将这些算子融合到其它算子上,减少计算量。代码在src/relay/transforms/simplify_inference.cc中。文件中定义了一个InferenceSimplifier类来处理这个问题。看一下这几个normalization的公式:
1 BN:
2 LN:获得均值和方差是基于同一层不同神经元的数据。归一化公式相同。
3 GN: 将每个输入样本沿着通道进行分组,在每个组内进行归一化。
4 IN:对每个通道的数据进行归一化。
来看一下bacth normalization的处理代码:
Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, Type tdata) { auto ttype = tdata.as<TensorTypeNode>(); CHECK(ttype); const auto param = attrs.as< BatchNormAttrs>(); Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); Expr var_add_eps = Add(moving_var, epsilon); Expr sqrt_var = Sqrt(var_add_eps); Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var); if (param->scale) { scale = Multiply(scale, gamma); } Expr neg_mean = Negative(moving_mean); Expr shift = Multiply(neg_mean, scale); if (param->center) { shift = Add(shift, beta); } auto ndim = ttype->shape.size(); int axis = (param->axis < 0) ? param->axis + ndim : param->axis; scale = ExpandBiasToMatchAxis(scale, ndim, {axis}); shift = ExpandBiasToMatchAxis(shift, ndim, {axis}); Expr out = Multiply(data, scale); out = Add(out, shift); return out; }
可以看到就是将batch norm算子分解成最基本的加减乘除算子。
05. EliminateCommonSubexpr
顾名思义,这个pass的目的是消除公共子表达式。公共子表达式类似这种:
a=b+c
d=b+c
两个表达式具有相同的op,同时又有相同的args,而且args的顺序也一样。那么就可以用一个表达式替换。
这个pass的实现在文件src/relay/transforms/eliminate_common_subexpr.cc中。TVM定义了类CommonSubexprEliminator来处理。重载函数Rewrite_实现了对expr的遍历和重写操作。
Expr Rewrite_(const CallNode* call, const Expr& post) final { … if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef< Op>(op), false)) { return new_expr; } if (fskip_ != nullptr && fskip_(new_expr)) { return new_expr; } auto it = expr_map_.find(new_call->op); if (it != expr_map_.end()) { for (const Expr& candidate_expr : it->second) { if (const CallNode* candidate = candidate_expr.as< CallNode>()) { bool is_equivalent = true; if (!attrs_equal(new_call->attrs, candidate->attrs)) { continue; } for (size_t i = 0; i < new_call->args.size(); i++) { if (!new_call->args[i].same_as(candidate->args[i]) && !IsEqualScalar(new_call->args[i], candidate->args[i])) { is_equivalent = false; break; } } if (!is_equivalent) continue; return GetRef<Call>(candidate); } } } expr_map_[new_call->op].push_back(new_expr); return new_expr; }
使用一个expr_map_映射记录已经遍历过的具有相同op的expr,之后每次遇到相同的op都会对已经记录的expr进行匹配,匹配包括attrs以及args,如果二者都一样的话,证明就是公共子表达式。
没有看过的pass
以上是实现相对简单的pass,TVM中还实现了其它很多pass,就没有一一去读代码了。以后看需要再去读吧。现在做一些罗列:
1 SimplifyExpr
简化一些表达式,具体如何进行简化需要读代码了。
2 CombineParallelConv2D
合并多分支并行的conv2d运算,理解是对多个batch的conv2d进行合并。
3 CombineParalleleDense
将多个batch的dense操作合并为一个batch_matmul操作。
4 CombineParallelBatchMatmul
对多个并行的batch_mamul再进行合并。
这几个combine操作可能是针对GPU器件的一个多数据并行性的优化。
5 FoldConstant
典型的一个常量合并优化。
6 FoldScaleAxis
包含了ForwardFoldScaleAxis和backwardFoldScaleAxis,主要是将scale参数合并到conv/dense操作的权重参数中。
7 CanonicalizeCast
官方解释是: Canonicalize cast expressions to make operator fusion more efficient。理解是对一些cast操作规范化,就是让复杂的cast操作可以更简洁。
8 CanonicalizeOps
规范化一些算子,比如bias_add能够被表示为expand_dims和broadcast_add操作。
…