前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MLIR_对自定义IR Dialect编写bufferization pass

MLIR_对自定义IR Dialect编写bufferization pass

作者头像
BBuf
发布2024-07-01 14:14:36
490
发布2024-07-01 14:14:36
举报
文章被收录于专栏:GiantPandaCVGiantPandaCV

最近在整理先前实习做的一些工作,主要是对AI compiler做基于mlir的重构,以下是之前写的compiler frontend的一个比较基础的pass,针对自定义的IR Dialect做bufferization。

一、bufferization概念

Bufferization 是MLIR中一个重要的过程,它主要负责将具有tensor(张量)语义的操作转换为具有memref(内存引用)语义的操作。

  • Tensor在MLIR中代表抽象值类型的数据序列,它们并不直接对应于内存中的位置。
  • MemRef(Memory Reference)则代表对内存区域的具体引用,提供了更低级别的缓冲区访问能力。
  • Bufferization将tensor的语义转换为memref的语义,memref提供了更直接、更具体的内存访问方式,减少了编译器需要处理的抽象层次。

二、实现

以下是在XPU上自定义TIR的一个conv2d MLIR示例。该pass的功能是将func和op的tensor type转为memref type(TIR->MTIR),整体实现共包含两个pass、六个pattern。

代码语言:MLIR
复制
// Input of the bufferization pipeline: tensor-semantics TIR for a single
// fixed-point conv2d (tir.float2fix -> tir.conv2d-fix -> tir.fix2float).
// Arguments, results and constants are all tensors at this stage.
module {
  func.func @XPUFunc(%arg0: tensor<1x8x8x256xf32>) -> tensor<1x4x4x256xf32> attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
    %0 = "tir.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> tensor<256x2x2x256xi8>
    %1 = "tir.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> tensor<256xi8>
    %2 = "tir.float2fix"(%arg0) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (tensor<1x8x8x256xf32>) -> tensor<1x8x8x256xi8>
    %3 = "tir.conv2d-fix"(%2, %0, %1) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (tensor<1x8x8x256xi8>, tensor<256x2x2x256xi8>, tensor<256xi8>) -> tensor<1x4x4x256xi8>
    %4 = "tir.fix2float"(%3) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (tensor<1x4x4x256xi8>) -> tensor<1x4x4x256xf32>
    return %4 : tensor<1x4x4x256xf32>
  }
}

ODS自定义OP .td写法示例

代码语言:TableGen
复制
// NOTE(review): the upload/download ODS snippet includes "tir_op_base.td"
// (lowercase) while this one includes "Tir_op_base.td"; include paths are
// case-sensitive on most filesystems — confirm the real file name.
include "Tir_op_base.td"
// tir.const: materializes a constant tensor from the optional inline `value`
// ElementsAttr. The description below also mentions an external `path`, but
// no `path` attribute is declared in `arguments` — TODO confirm whether it is
// still supported. Because `value` is optional, lowering code must handle the
// case where it is absent.
def Tir_ConstOp :
    Tir_Op<"const", [ConstantLike, Pure, FirstAttrDerivedResultType]> {
  let summary = "Represent a constant tensor with values";
  let description = [{
    The constant operator providing initialized values for tensors.

    The initial values come either in `DenseElementsAttr` `value`, or from an
    external binary file specified in `path`.
  }];
  let arguments = (ins
    OptionalAttr<ElementsAttr>:$value
  );
  let results = (outs Tir_Tensor:$output);
  // The C++ fold hook is implemented outside this snippet.
  let hasFolder = 1;
}
...
...
2.1 global_bufferize pass

实现分为两个pass。第一个是global_bufferize pass,它将func的参数(argument)和返回值(return)的tensor type转为memref。代码和注释如下所示:

代码语言:C++
复制
/// @brief Early bufferization on global input/output and constants.
///
/// First rewrites the function signature (updateFuncOp): tensor arguments
/// become memref arguments bridged back to tensors with `tir.upload`, and
/// every result becomes a trailing memref output argument. Then a partial
/// conversion
///   * replaces every non-scalar `tir.const` with `memtx.const` + `tir.upload`
///     (ConstOpConverter), and
///   * replaces `func.return %v...` with one `tir.download` per value into the
///     output arguments plus an operand-less `func.return` (ReturnOpConverter).
/// The pass fails if any illegal op is left after the conversion.
class GlobalBufferize : public impl::GlobalBufferizeBase<GlobalBufferize> {
public:
  void runOnOperation() override {
    // getContext() comes from the Pass base class; the context is needed to
    // build the conversion target and the pattern set.
    auto *ctx = &getContext();

    // The target decides which ops are legal after conversion; everything
    // it marks illegal must be rewritten by the patterns below.
    ConversionTarget target(*ctx);
    // A rank-0 (scalar) tir.const is kept as-is; any other tir.const is
    // illegal and is rewritten into memtx.const + tir.upload.
    target.addDynamicallyLegalOp<tir::ConstOp>([](Operation *op) {
      auto ttype = op->getResult(0).getType().cast<RankedTensorType>();
      return ttype.getRank() == 0;
    });
    // Ops produced by the patterns are always legal. ConstOpConverter emits
    // memtx::ConstOp, so that is the constant op that must be legal here.
    target.addLegalOp<memtx::ConstOp>();
    target.addLegalOp<tir::UpLoadOp>();
    target.addLegalOp<tir::DownLoadOp>();
    // Results are written into the output memref arguments via tir.download,
    // so only a return without operands is legal.
    target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
        [](Operation *op) { return op->getNumOperands() == 0; });

    mlir::func::FuncOp func = getOperation();
    // Rewrite the signature first: ReturnOpConverter relies on the output
    // memref arguments that updateFuncOp appends.
    updateFuncOp(func);

    RewritePatternSet convertPatterns(ctx);
    convertPatterns.insert<ConstOpConverter, ReturnOpConverter>(ctx);
    // A failure means some op the target marks illegal could not be
    // rewritten; report it instead of continuing with half-converted IR.
    if (failed(applyPartialConversion(func, target,
                                      std::move(convertPatterns))))
      signalPassFailure();
  }
};

} //
/// Factory for the global bufferization pass (GlobalBufferize).
std::unique_ptr<mlir::Pass> tir::createGlobalBufferizePass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<GlobalBufferize>();
  return pass;
}

以上是global_bufferize pass的主要部分:先调用updateFuncOp改写函数签名,再按照target中定义的合法性规则应用两个转换pattern。下面先看updateFuncOp。

代码语言:C++
复制
/// Returns the memref type with the same shape and element type as the
/// ranked tensor `type`. No layout or memory space is passed, so
/// MemRefType::get's defaults apply.
static inline MemRefType tensorToMemRef(RankedTensorType type) {
  ArrayRef<int64_t> shape = type.getShape();
  Type elementType = type.getElementType();
  return MemRefType::get(shape, elementType);
}
/// Rewrites the signature of `func` to pass buffers instead of tensors.
///
/// * Every ranked-tensor argument is retyped in place to the equivalent
///   memref. A `tir.upload` (memref -> tensor) is inserted at the start of
///   the entry block, and all other uses of the argument are redirected to
///   the upload's result, so the existing body keeps operating on tensors.
/// * Every result becomes an extra trailing memref argument (an output
///   buffer), and the function is given no results. ReturnOpConverter later
///   stores the returned values into these arguments.
///
/// Non-tensor arguments are kept unchanged. Every result must be a ranked
/// tensor (the cast below asserts otherwise). Declarations without a body
/// are left untouched, because there is no entry block to hold the upload
/// ops or the new arguments.
static void updateFuncOp(mlir::func::FuncOp func) {
  if (func.isExternal())
    return;

  // OpBuilder over the body inserts new ops at the start of the entry block.
  mlir::OpBuilder builder(func.getBody());
  auto funcType = func.getFunctionType();

  // Input types of the rewritten signature: converted inputs, then outputs.
  llvm::SmallVector<Type, 4> newInputTypes;
  newInputTypes.reserve(funcType.getNumInputs() + funcType.getNumResults());

  for (auto indexed : llvm::enumerate(funcType.getInputs())) {
    auto tensorType = indexed.value().dyn_cast<RankedTensorType>();
    if (!tensorType) {
      newInputTypes.push_back(indexed.value());
      continue;
    }
    auto memrefType = tensorToMemRef(tensorType);
    auto arg = func.getArgument(indexed.index());
    arg.setType(memrefType);
    // tir.upload: memref argument -> tensor value consumed by the body.
    auto upload =
        builder.create<tir::UpLoadOp>(func.getLoc(), tensorType, arg);
    // Redirect every use except the upload's own operand.
    arg.replaceAllUsesExcept(upload->getResult(0), upload);
    newInputTypes.push_back(memrefType);
  }

  // Each result turns into a trailing output-buffer argument.
  for (Type resultType : funcType.getResults()) {
    auto memrefType = tensorToMemRef(resultType.cast<RankedTensorType>());
    newInputTypes.push_back(memrefType);
    func.front().addArguments(memrefType, builder.getUnknownLoc());
  }

  // `{}` is an empty result list.
  func.setType(FunctionType::get(func.getContext(), newInputTypes, {}));
}

总结:updateFuncOp 函数将函数的输入参数和输出结果从 RankedTensorType 转换为 MemRefType(输出结果变为追加在参数列表末尾的 memref 输出参数),并为每个 memref 类型的输入创建一个 tir.upload Op(memref->tensor),使函数体内仍然使用 tensor。再来看两个 ConversionPattern:对于 ConstOpConverter,实现上是用自定义的 memtx.const(直接产生 memref)+ tir.upload(memref->tensor)替换了原来产生 tensor 的 tir.const。

代码语言:C++
复制
/// Rewrites a (non-scalar) `tir.const` producing a tensor into
/// `memtx.const` producing a memref with the same value, followed by a
/// `tir.upload` (memref -> tensor) so existing tensor users are unaffected.
///
/// `value` is declared as an optional attribute on tir.const, so the pattern
/// fails to match when it is absent instead of dereferencing a missing value.
struct ConstOpConverter : public OpConversionPattern<ConstOp> {
  using OpConversionPattern<ConstOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto value = op.getValue();
    if (!value)
      return rewriter.notifyMatchFailure(op, "tir.const has no inline value");

    auto tensorType = op.getOutput().getType().cast<RankedTensorType>();
    auto memRefType = tensorToMemRef(tensorType);
    Value buffer =
        rewriter.create<memtx::ConstOp>(op.getLoc(), memRefType, *value)
            .getResult();
    rewriter.replaceOpWithNewOp<tir::UpLoadOp>(op, tensorType, buffer);
    return success();
  }
};

/// Lowers `func.return %r0, ..., %rN` into one `tir.download %ri, %out_i`
/// per returned value, followed by an operand-less `func.return`.
///
/// The output buffers are taken to be the last N function arguments, in the
/// same order as the returned values (the layout produced by updateFuncOp).
struct ReturnOpConverter : public OpConversionPattern<mlir::func::ReturnOp> {
  using OpConversionPattern<mlir::func::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto func = op->getParentOfType<mlir::func::FuncOp>();
    unsigned numReturned = op.getNumOperands();
    // Without enough arguments the unsigned subtraction below would wrap
    // around and getArgument would index out of range.
    if (func.getNumArguments() < numReturned)
      return rewriter.notifyMatchFailure(op, "missing output buffer arguments");
    unsigned firstOutputArg = func.getNumArguments() - numReturned;

    for (auto indexed : llvm::enumerate(adaptor.getOperands())) {
      Value outputBuffer = func.getArgument(firstOutputArg + indexed.index());
      rewriter.create<tir::DownLoadOp>(op.getLoc(), indexed.value(),
                                       outputBuffer);
    }
    rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
    return success();
  }
};

对于ReturnOpConverter,它为每个返回值创建一个tir.download(tensor->memref),把结果写入对应的输出memref参数,再用不带操作数的return替换原来的returnOp。global_bufferize pass后的结果如下:可以看到func的参数转为了memref;新增的tir.upload把输入memref转为tensor,tir.download把输出tensor写回memref参数;memtx.const+tir.upload则用于常量从memref到tensor的转换。

代码语言:MLIR
复制
// After global_bufferize: the input and the former result are memref
// arguments (%arg0, %arg1) and the function returns nothing. tir.upload
// bridges memrefs into the unchanged tensor body, tir.download writes the
// final tensor into %arg1, and constants became memtx.const + tir.upload.
module {
  func.func @XPUFunc(%arg0: memref<1x8x8x256xf32>, %arg1: memref<1x4x4x256xf32>) attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
    %0 = "tir.upload"(%arg0) : (memref<1x8x8x256xf32>) -> tensor<1x8x8x256xf32>
    %1 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> memref<256x2x2x256xi8>
    %2 = "tir.upload"(%1) : (memref<256x2x2x256xi8>) -> tensor<256x2x2x256xi8>
    %3 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> memref<256xi8>
    %4 = "tir.upload"(%3) : (memref<256xi8>) -> tensor<256xi8>
    %5 = "tir.float2fix"(%0) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (tensor<1x8x8x256xf32>) -> tensor<1x8x8x256xi8>
    %6 = "tir.conv2d-fix"(%5, %2, %4) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (tensor<1x8x8x256xi8>, tensor<256x2x2x256xi8>, tensor<256xi8>) -> tensor<1x4x4x256xi8>
    %7 = "tir.fix2float"(%6) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (tensor<1x4x4x256xi8>) -> tensor<1x4x4x256xf32>
    "tir.download"(%7, %arg1) : (tensor<1x4x4x256xf32>, memref<1x4x4x256xf32>) -> ()
    return
  }
}

下面是新增的ODS自定义Op

代码语言:TableGen
复制
// NOTE(review): spelled "tir_op_base.td" here but "Tir_op_base.td" in the
// tir.const snippet; include paths are case-sensitive on most filesystems —
// confirm the real file name.
include "tir_op_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Memref operand type of the tir bridge ops: a strided memref whose element
// type satisfies Tir_ElementType.
def Tir_MemRef : StridedMemRefOf<[Tir_ElementType]>;
// tir.upload: takes a memref operand and yields a tensor result (memref ->
// tensor bridge used by global_bufferize).
// NOTE(review): declared NoMemoryEffect although it reads `mem`; that lets
// the optimizer CSE/move it across writes to the same buffer — confirm this
// is intended.
def Tir_UpLoadOp : Tir_Op<"upload", [NoMemoryEffect]> {
  let arguments = (ins Tir_MemRef:$mem);

  let results = (outs Tir_Tensor:$output);
}

// tir.download: takes a tensor and a destination memref and has no results
// (tensor -> memref bridge). No side-effect trait is declared, so it is
// presumably treated as having unknown effects and is not erased as dead code.
def Tir_DownLoadOp : Tir_Op<"download"> {
  let arguments = (ins Tir_Tensor:$tensor, Tir_MemRef:$mem);
}
2.2 tir2mtir_convert pass

直接上结果:我们的目的是对IR做bufferization,即IR中除memref之外不再出现tensor类型。在前一个pass global_bufferize得到上面所示的IR后,在此基础上继续编写第二个pass——tir2mtir_convert,其输出如下:

代码语言:MLIR
复制
// After tir2mtir_convert: no tensor-typed SSA values remain. Each tir.upload
// became memref.alloc + memtx.copy, each tir op became the mtir op of the
// same name writing into a freshly allocated output buffer (passed as the
// last operand), and tir.download became memtx.copy into %arg1.
module {
  func.func @XPUFunc(%arg0: memref<1x8x8x256xf32>, %arg1: memref<1x4x4x256xf32>) attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
    %alloc = memref.alloc() : memref<1x8x8x256xf32>
    "memtx.copy"(%arg0, %alloc) : (memref<1x8x8x256xf32>, memref<1x8x8x256xf32>) -> ()
    %0 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> memref<256x2x2x256xi8>
    %alloc_0 = memref.alloc() : memref<256x2x2x256xi8>
    "memtx.copy"(%0, %alloc_0) : (memref<256x2x2x256xi8>, memref<256x2x2x256xi8>) -> ()
    %1 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> memref<256xi8>
    %alloc_1 = memref.alloc() : memref<256xi8>
    "memtx.copy"(%1, %alloc_1) : (memref<256xi8>, memref<256xi8>) -> ()
    %alloc_2 = memref.alloc() : memref<1x8x8x256xi8>
    "mtir.float2fix"(%alloc, %alloc_2) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (memref<1x8x8x256xf32>, memref<1x8x8x256xi8>) -> ()
    %alloc_3 = memref.alloc() : memref<1x4x4x256xi8>
    "mtir.conv2d-fix"(%alloc_2, %alloc_0, %alloc_1, %alloc_3) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (memref<1x8x8x256xi8>, memref<256x2x2x256xi8>, memref<256xi8>, memref<1x4x4x256xi8>) -> ()
    %alloc_4 = memref.alloc() : memref<1x4x4x256xf32>
    "mtir.fix2float"(%alloc_3, %alloc_4) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (memref<1x4x4x256xi8>, memref<1x4x4x256xf32>) -> ()
    "memtx.copy"(%alloc_4, %arg1) : (memref<1x4x4x256xf32>, memref<1x4x4x256xf32>) -> ()
    return
  }
}

pass如下

代码语言:C++
复制
/// Converts tensor-semantics `tir` ops into buffer-semantics `mtir` ops.
///
/// Runs two full conversions on the function and stops at the first failure:
///  1. tir / tir.upload / tir.download -> memref.alloc, memtx.copy and mtir
///     ops, using setupTirToMTirLegality and populateTirToMTirPatterns;
///  2. elimination of the remaining bufferization materializations, where an
///     op is legal once `typeConverter.isLegal(op)` holds.
struct ConvertTirToMTirPass
    : public impl::ConvertTirToMTirBase<ConvertTirToMTirPass> {
  void runOnOperation() override {
    mlir::func::FuncOp f = getOperation();
    auto &context = getContext();
    mlir::bufferization::BufferizeTypeConverter typeConverter;

    // Step 1: tir -> mtir.
    ConversionTarget target(context);
    setupTirToMTirLegality(typeConverter, target);
    RewritePatternSet patterns(&context);
    populateTirToMTirPatterns(typeConverter, patterns);
    if (failed(applyFullConversion(f, target, std::move(patterns)))) {
      // The IR is only partly converted; running the finalize step on it
      // would only add follow-up errors.
      signalPassFailure();
      return;
    }

    // Step 2: remove the bufferization materializations left by step 1.
    ConversionTarget finalizeTarget(context);
    finalizeTarget.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    RewritePatternSet finalizePatterns(&context);
    populateEliminateBufferizeMaterializationsPatterns(typeConverter,
                                                       finalizePatterns);
    if (failed(applyFullConversion(f, finalizeTarget,
                                   std::move(finalizePatterns))))
      signalPassFailure();
  }
};

} // end anonymous namespace

/// Factory for the tir -> mtir conversion pass, which runs on func.func ops.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
mxir::createConvertTirToMTirPass() {
  std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> pass =
      std::make_unique<ConvertTirToMTirPass>();
  return pass;
}

下面来看具体的Legality和pattern

代码语言:C++
复制
/// Configures `target` for the tir -> mtir conversion.
///
/// Legal: the memref, mtir, memtx, affine and arith dialects, plus func.func
/// and func.return. Illegal: every op of the tir dialect. The bufferization
/// materialization ops are made legal via
/// populateBufferizeMaterializationLegality. `typeConverter` is not used.
void xcompiler::mxir::setupTirToMTirLegality(
    mlir::bufferization::BufferizeTypeConverter &typeConverter,
    ConversionTarget &target) {
  target.addLegalDialect<memref::MemRefDialect, mtir::MTIRDialect,
                         memtx::MemTxDialect, AffineDialect,
                         arith::ArithDialect>();
  target.addLegalOp<mlir::func::FuncOp, mlir::func::ReturnOp>();
  target.addIllegalDialect<tir::TirDialect>();
  mlir::bufferization::populateBufferizeMaterializationLegality(target);
}

/// Registers the tir -> mtir conversion patterns and extends `typeConverter`
/// with an argument materialization for scalar (rank-0) tensors.
///
/// The tensor -> memref type conversion comes from BufferizeTypeConverter.
/// TirOpConverter relies on it, because it casts the converted result type
/// to MemRefType. No extra `RankedTensorType` conversion is registered here:
/// a conversion callback returning `llvm::None` only means "not applicable,
/// try the next conversion", so it would not disable tensor conversion.
void xcompiler::mxir::populateTirToMTirPatterns(
    mlir::bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();

  // For a rank-0 tensor, reuse the first incoming value as-is. Otherwise
  // report that this materialization does not apply.
  // NOTE(review): the returned value keeps inputs[0]'s type, which need not
  // equal `type` — confirm downstream code tolerates this.
  typeConverter.addArgumentMaterialization(
      [](OpBuilder &builder, TensorType type, ValueRange inputs,
         Location loc) -> Optional<Value> {
        if (type.getRank() == 0 && !inputs.empty())
          return inputs[0];
        return llvm::None;
      });

  // memtx.const, tir.upload, tir.download and the generic tir -> mtir op
  // conversion.
  patterns.add<ConstOpConverter, UpLoadOpConverter, DownLoadOpConverter,
               TirOpConverter>(typeConverter, context);
}

四个pattern

代码语言:C++
复制
/// Creates a `memref.alloc` of `type` at `op`'s location using `builder`.
///
/// If `op` has an integer attribute "id", the allocation is tagged with a
/// "name" string attribute "<op name without dialect>.<id>.out", for example
/// "conv2d-fix.3.out" for a `tir.conv2d-fix` op with id = 3.
static memref::AllocOp createAllocForOp(Operation *op, MemRefType type,
                                        OpBuilder &builder) {
  memref::AllocOp allocOp =
      builder.create<memref::AllocOp>(op->getLoc(), type);
  IntegerAttr idAttr = op->getAttrOfType<IntegerAttr>("id");
  if (!idAttr)
    return allocOp;

  std::string bufferName = op->getName().stripDialect().str();
  bufferName += ".";
  bufferName += std::to_string(idAttr.getInt());
  bufferName += ".out";
  allocOp->setAttr("name", builder.getStringAttr(bufferName));
  return allocOp;
}
/// Rewrites `memtx.const` into `arith.constant` carrying the same value.
///
/// Only DenseElementsAttr payloads are handled. Other ElementsAttr kinds,
/// such as the `dense_resource<...>` values in the example IR, make the
/// pattern fail to match instead of asserting in `cast<>`.
/// NOTE(review): setupTirToMTirLegality marks the whole memtx dialect legal,
/// and a full conversion does not rewrite legal ops, so this pattern is not
/// expected to fire in ConvertTirToMTirPass (the example output still holds
/// memtx.const). Also, arith.constant presumably takes the attribute's
/// (tensor) type rather than `op`'s memref type — confirm before relying on it.
struct ConstOpConverter : public OpConversionPattern<memtx::ConstOp> {
  using OpConversionPattern<memtx::ConstOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memtx::ConstOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto denseAttr = op.getValue().dyn_cast<mlir::DenseElementsAttr>();
    if (!denseAttr)
      return rewriter.notifyMatchFailure(op,
                                         "value is not a DenseElementsAttr");
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, denseAttr);
    return success();
  }
};
/// Rewrites `tir.upload %mem` (memref -> tensor) into a fresh `memref.alloc`
/// with the same shape and element type, plus a `memtx.copy %mem -> %alloc`.
/// The upload's result is then replaced by the allocated buffer.
///
/// The new buffer type is built from shape and element type only, so any
/// layout of the source memref is not carried over.
struct UpLoadOpConverter : public OpConversionPattern<UpLoadOp> {
  using OpConversionPattern<UpLoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Conversion patterns should read operands through the adaptor, which
    // yields the already-converted values.
    Value source = adaptor.getMem();
    auto sourceType = source.getType().dyn_cast<MemRefType>();
    if (!sourceType)
      return rewriter.notifyMatchFailure(op, "expected a ranked memref operand");

    auto bufferType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType());
    auto alloc = rewriter.create<memref::AllocOp>(op.getLoc(), bufferType);
    rewriter.create<memtx::CopyOp>(op.getLoc(), source, alloc.getMemref());
    rewriter.replaceOp(op, alloc.getMemref());
    return success();
  }
};
/// Rewrites `tir.download %value, %mem` into `memtx.copy %value -> %mem`,
/// taking both operands from the adaptor (i.e. their converted values).
struct DownLoadOpConverter : public OpConversionPattern<DownLoadOp> {
  using OpConversionPattern<DownLoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DownLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value source = adaptor.getTensor();
    Value destination = adaptor.getMem();
    rewriter.replaceOpWithNewOp<memtx::CopyOp>(op, source, destination);
    return success();
  }
};
/// Generic conversion for every op implementing TirOpInterface, for example
/// `tir.conv2d-fix` -> `mtir.conv2d-fix`.
///
/// The single tensor result is replaced by a freshly allocated output buffer
/// (createAllocForOp). That buffer is appended as the last operand of the new
/// `mtir.<name>` op, which has no results and gets a copy of all attributes.
///
/// The pattern fails to match, leaving the op for the conversion driver to
/// report as illegal, when:
///   * the op does not have exactly one ranked-tensor result,
///   * the result type does not convert to a memref, or
///   * no `mtir.<name>` op is registered.
/// Previously an unregistered target only printed a message while the tir op
/// was still erased, silently dropping the computation.
class TirOpConverter : public OpInterfaceConversionPattern<TirOpInterface> {
public:
  using OpInterfaceConversionPattern<
      TirOpInterface>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(TirOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *rawOp = op.getOperation();
    if (rawOp->getNumResults() != 1)
      return rewriter.notifyMatchFailure(rawOp, "expected exactly one result");
    auto tensorType =
        rawOp->getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!tensorType)
      return rewriter.notifyMatchFailure(rawOp,
                                         "expected a ranked tensor result");
    auto memrefType = getTypeConverter()
                          ->convertType(tensorType)
                          .dyn_cast_or_null<MemRefType>();
    if (!memrefType)
      return rewriter.notifyMatchFailure(
          rawOp, "result type does not convert to a memref");

    // Resolve the target op before creating anything, so a failed match
    // leaves the IR untouched.
    auto opName = "mtir." + rawOp->getName().stripDialect().str();
    auto *ctx = getContext();
    if (!RegisteredOperationName::lookup(opName, ctx))
      return rewriter.notifyMatchFailure(
          rawOp, "Op not supported in tir to mtir conversion");

    Location loc = op.getLoc();
    Value output = createAllocForOp(rawOp, memrefType, rewriter).getMemref();
    // Destination-passing style: converted inputs, then the output buffer.
    SmallVector<Value, 4> bufferOprs(operands.begin(), operands.end());
    bufferOprs.push_back(output);
    rewriter.insert(Operation::create(loc, OperationName(opName, ctx), {},
                                      bufferOprs,
                                      rawOp->getAttrDictionary()));
    rewriter.replaceOp(rawOp, output);
    return success();
  }
};

MTIR ODS自定义Op .td写法示例

代码语言:TableGen
复制
// mtir.conv2d-fix: buffer-form fixed-point 2D convolution. All operands are
// memrefs, and the op declares no results; `output` is the caller-provided
// destination buffer, which TirOpConverter appends as the last operand.
// `bias` is the only Optional operand, so its presence can be inferred from
// the operand count without AttrSizedOperandSegments.
// The attribute set mirrors the tir.conv2d-fix attributes, which the
// converter copies verbatim.
def MTIR_Conv2dFixOp :
    MTIR_Op<"conv2d-fix", []> {
  let summary = "2D Convolution Fix Operator";
  let description = [{
    Performs a 2D convolution-fix over the given tensor input, using the weight
    tensor.
  }];

  let arguments = (ins
    MTIR_MemRef:$input,
    MTIR_MemRef:$weight,
    Optional<MTIR_MemRef>:$bias,
    MTIR_MemRef:$output,

    I32ArrayAttr:$kernel,
    I32ArrayAttr:$stride,
    OptionalAttr<I32ArrayAttr>:$dilation,
    OptionalAttr<StrAttr>:$pad_mode,
    OptionalAttr<I32ArrayAttr>:$pad,
    OptionalAttr<StrAttr>:$nonlinear,
    OptionalAttr<I32Attr>:$hsigmoid_in,
    OptionalAttr<I32Attr>:$shift_hsigmoid,
    OptionalAttr<I32Attr>:$shift_hswish,
    OptionalAttr<I32Attr>:$group
  );
}
...
...

总结:通过上面两个pass,即完成了自定义TIR->MTIR的bufferization,最终得到的IR如下:

代码语言:MLIR
复制
// Final IR after global_bufferize + tir2mtir_convert: the function takes its
// input and output as memrefs, and every computation is an mtir/memtx op
// operating on memref buffers (no tensor-typed SSA values remain).
module {
  func.func @XPUFunc(%arg0: memref<1x8x8x256xf32>, %arg1: memref<1x4x4x256xf32>) attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
    %alloc = memref.alloc() : memref<1x8x8x256xf32>
    "memtx.copy"(%arg0, %alloc) : (memref<1x8x8x256xf32>, memref<1x8x8x256xf32>) -> ()
    %0 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> memref<256x2x2x256xi8>
    %alloc_0 = memref.alloc() : memref<256x2x2x256xi8>
    "memtx.copy"(%0, %alloc_0) : (memref<256x2x2x256xi8>, memref<256x2x2x256xi8>) -> ()
    %1 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> memref<256xi8>
    %alloc_1 = memref.alloc() : memref<256xi8>
    "memtx.copy"(%1, %alloc_1) : (memref<256xi8>, memref<256xi8>) -> ()
    %alloc_2 = memref.alloc() : memref<1x8x8x256xi8>
    "mtir.float2fix"(%alloc, %alloc_2) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (memref<1x8x8x256xf32>, memref<1x8x8x256xi8>) -> ()
    %alloc_3 = memref.alloc() : memref<1x4x4x256xi8>
    "mtir.conv2d-fix"(%alloc_2, %alloc_0, %alloc_1, %alloc_3) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (memref<1x8x8x256xi8>, memref<256x2x2x256xi8>, memref<256xi8>, memref<1x4x4x256xi8>) -> ()
    %alloc_4 = memref.alloc() : memref<1x4x4x256xf32>
    "mtir.fix2float"(%alloc_3, %alloc_4) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (memref<1x4x4x256xi8>, memref<1x4x4x256xf32>) -> ()
    "memtx.copy"(%alloc_4, %arg1) : (memref<1x4x4x256xf32>, memref<1x4x4x256xf32>) -> ()
    return
  }
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-06-25,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、bufferization概念
  • 二、实现
    • 2.1global_bufferize pass
      • 2.2tir2mtir_convert pass
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档