前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MLIR-Toy-实践-3-Dialect转换

MLIR-Toy-实践-3-Dialect转换

作者头像
hunterzju
发布2022-04-28 17:49:57
2.4K0
发布2022-04-28 17:49:57
举报
文章被收录于专栏:编译器开发

上篇文章为Toy添加了一个新Op(toy.or)表示逻辑或。本文介绍如何将OrOp降低到其他方言对应的Op,主要用到了RewritePatternConversionPattern相关的内容。

RewritePattern与ConversionPattern

MLIR是一种图类型的IR表示,而RewritePattern提供了一个图模式匹配的接口,可以更方便进行图优化。比如ToyTutorial-ch3中使用的优化pattern:将两个嵌套的transport转换为一个返回输入数据的节点。

Image

RewritePattern的实现有两种方式,一种是采用c++实现,需要定义一个转换结构体继承mlir::OpRewritePattern<TransposeOp>,并重写matchAndRewrite()方法,该方法中实现了IR结构的修改逻辑。比如上文中提到的Transpose逻辑优化,在transpose嵌套transpose操作时,两次转置操作抵消,直接返回输入参数。定义该Pattern后,创建一个标准化pass(在toy.cpp中实现),并将Pattern注册到该Pass中(在ToyCombine.cpp中实现)。

代码语言:javascript
复制
/// ToyCombine.cpp 定义pattern:transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  // ....
  mlir::LogicalResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current transpose.
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();

    // Input defined by another transpose? If not, no match.
    if (!transposeInputOp)
      return failure();

    // Otherwise, we have a redundant transpose. Use the rewriter.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()});
    return success();
  }
};

// toy.cpp 创建pass
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

// ToyCombine.cpp 注册pattarn
void TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyRedundantTranspose>(context);
}

// 在ops.td中声明允许标准化操作
let hasCanonicalizer = 1;

RewritePattern的另一种实现方式是采用DRR描述Pattern,然后通过TableGen来生成c代码。在ToyTutorial-ch3中采用DRR方式实现了Reshape操作的优化。DRR定义规则如下:

代码语言:javascript
复制
class Pattern<
    dag sourcePattern, list<dag> resultPatterns,
    list<dag> additionalConstraints = [],
    dag benefitsAdded = (addBenefit 0)>;

比如,定义对Reshape操作的优化如下,其中sourcePattern是(ReshapeOp(ReshapeOp arg)),resultPatterns是(ReshapeOp arg),约束Constraints和优先级benefits都省略没有定义:

代码语言:javascript
复制
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;

ConversionPattern是一种特殊的RewritePattern,用于实现Dialect之间的转换。在Dialect转换过程中,可能会对Operation中的操作数做修改,因而ConversionPatternRewritePattern一个主要区别是matchAndRewrite()接口函数中多了一个operands参数,用于对Operation中的操作数修改。

代码语言:javascript
复制
struct MyConversionPattern : public ConversionPattern {
  /// The `matchAndRewrite` hooks on ConversionPatterns take an additional
  /// `operands` parameter, containing the remapped operands of the original
  /// operation.
  virtual LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const;
};

ConversionPattern实现Dialect转换

上一篇文章中,我们给Toy Dialect添加了一个逻辑或操作OrOp,下文结合Conversion Pattern的使用记录下将Toy Dialect中的OrOp转换到其他dialect的过程。Dialect转换需要指定Conversion Target(目标方言)和Rewrite Patterns(匹配Operation)。

首先指定Conversion Target,这里将MLIR Dialect转换到Affine, MemRef and Standard 三种Dialect,为后续转换到可运行的LLVM Dialect做准备。具体实现在LowerToAffineLoops.cpp中,指定了合法和非法的Dialect以及Operation:

代码语言:javascript
复制
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine`, `MemRef` and `Standard` dialects.
  target.addLegalDialect<AffineDialect, memref::MemRefDialect,
                         StandardOpsDialect>();

  // We also define the Toy dialect as Illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operations that don't want
  // to lower, `toy.print`, as `legal`.
  target.addIllegalDialect<toy::ToyDialect>();
  target.addLegalOp<toy::PrintOp>();

接下来指定转换匹配的Pattern,具体实现如上一节描述,先定义一个转换Pattern类,该类继承了ConversionPattern;然后重载其中的matchAndRewrite()方法来指定转换操作;接下来将这些Pattern添加到转换context中;最后执行转换。

代码语言:javascript
复制
  // 转换pattern定义
  struct TransposeOpLowering : public ConversionPattern {
    // ...
    LogicalResult
    matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                    ConversionPatternRewriter &rewriter)
    // ...
  }
  // 添加pattern到context
  RewritePatternSet patterns(&getContext());
  patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
               ReturnOpLowering, TransposeOpLowering>(&getContext());
  
  // 应用转换过程
  applyPartialConversion(getFunction(), target, std::move(patterns))

OrOp转换

将新添加的toy.or进行转换,需要实现一个转换Pattern,并将其添加到转换Context中。参考已经实现的AddMul操作,其都是将Toy Dialect先通过Affine Dialect将循环展开,然后转换到Standard Dialect中的对应Op。这里有一个问题是,在Standard Dialect中Add和Mul都有对应的浮点和整型操作,但是Or仅支持整型操作(这是符合运算逻辑的,对于整型逻辑或才有意义),但是输入数据是浮点型F64。因此,OrOp需要做一个浮点转整型的操作。同时由于后续操作都是在浮点上操作的,因此还需要将OrOp的结果操作数从整型转回浮点。

代码语言:javascript
复制
// 对OrOp添加Float转Int逻辑,并映射到Standard::OrOp
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(
        op, operands, rewriter,
        [loc, op](OpBuilder &builder, ValueRange memRefOperands,
              ValueRange loopIvs) {
          // ...

          // 对toy.or添加float转int逻辑,利用standard中的FPToUIOp
          auto opname = op->getName();
          if (opname.getStringRef().str() == "toy.or") {
            auto castLhs = builder.create<FPToUIOp>(loc, builder.getI64Type(), loadedLhs);
            auto castRhs = builder.create<FPToUIOp>(loc, builder.getI64Type(), loadedRhs);
            
            return builder.create<LoweredBinaryOp>(loc, castLhs, castRhs);
          }
          // ....
        });
    return success();
  }
};
// OrOp pattern定义
using OrOpLowering = BinaryOpLowering<toy::OrOp, OrOp>;

//
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  // ...
  buildAffineLoopNest(
      rewriter, loc, lowerBounds, tensorType.getShape(), steps,
      [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
        Value valueToStore = processIteration(nestedBuilder, operands, ivs);
        // 将"toy.or" Op的结果从Int转换为Float
        if(nestedBuilder.getI64Type() == valueToStore.getType()) {
          valueToStore = nestedBuilder.create<UIToFPOp>(loc, nestedBuilder.getF64Type(), valueToStore);
        }
        nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
      });
  // ...
}

总结

MLIR中基于Pattern对IR图进行操作,提供了一个便捷且标准化的接口,带来了很大便利性,但是也增加了学习成本。Conversion Pattern提供了一套在Dialect间进行转换的通路,别且多个Dialect可以共存,有点类似于插件的感觉。由于目前了解有限,总感觉各个Dialect之间的抽象关系不是很明确,而且暂时没找到一个文档解释对各个Dialect的抽象层级进行一个比较系统的解释,可能是个有待改进的点吧。

另外,关于OrOp类型转换的问题,感觉处理得有点野路子,欢迎有想法的朋友指正。

本文使用 Zhihu On VSCode 创作并发布

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021-12-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • RewritePattern与ConversionPattern
  • ConversionPattern实现Dialect转换
  • OrOp转换
  • 总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档