Introduction: There are many ways to optimize a model, such as model compression, model pruning, and conversion to TensorRT. This article discusses TensorRT optimization of an MXNet model, focusing on two problems: dynamic batch size and unsupported Ops.
Environment:
cuda 10.2.89
cudnn 8.0.3.33
mxnet-cu102 1.8.0.post0
onnx 1.8.1
onnx-simplifier 0.3.6
onnxconverter-common 1.7.0
onnxmltools 1.7.0
onnxoptimizer 0.2.6
onnxruntime 1.8.0
onnxruntime-gpu-tensorrt 1.4.0
tensorrt 7.1.3.4
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

symbol_model = "./face_recg_2.3-symbol.json"
params_model = "./face_recg_2.3-0000.params"
onnx_path = "face_recg.onnx"
input_shape = (1, 3, 224, 192)
# Export the MXNet symbol/params pair to ONNX; the batch dimension is fixed at 1 here.
onnx_mxnet.export_model(symbol_model, params_model, [input_shape], np.float32, onnx_path, verbose=True)
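Before editing the graph, it is worth a quick sanity check of the exported file; a minimal sketch using the onnx package:

import onnx

model = onnx.load("face_recg.onnx")
onnx.checker.check_model(model)  # raises if the model is structurally invalid
print(onnx.helper.printable_graph(model.graph))  # readable dump of ops and tensor shapes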
The MXNet exporter has no dynamic_axes option to configure, so the batch dimension of the exported model is fixed at 1. In ONNX, a shape dimension can be stored as text (a symbolic name), so we can edit the ONNX model directly to support dynamic batch.
import onnx

onnx_file = "face_recg.onnx"
new_onnx_file = "face_recg_new.onnx"

input_maps = {}
output_maps = {}
init_maps = {}
keys = []

model_onnx = onnx.load(onnx_file)
graph = model_onnx.graph
for inp in graph.input:
    input_maps[inp.name] = inp
    keys.append(inp.name)
for out in graph.output:
    output_maps[out.name] = out

# Replace the 'data' input with one whose batch dimension is the symbolic name "batch".
name = 'data'
graph.input.remove(input_maps[name])
new_nv = onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, ["batch", 3, 224, 192])
graph.input.extend([new_nv])

# Do the same for the 'fc5' output.
name = 'fc5'
graph.output.remove(output_maps[name])
new_nv = onnx.helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, ["batch", 1024])
graph.output.extend([new_nv])

onnx.save(model_onnx, new_onnx_file)
Model structure before the change:
Model structure after the change:
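The structure screenshots are not reproduced here; a quick alternative is to print each model's input/output value_info. A minimal sketch (it assumes 'data' is the first entry of graph.input, which holds for this export):

import onnx

def show_io(path):
    m = onnx.load(path)
    for vi in list(m.graph.input)[:1] + list(m.graph.output):
        dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
        print(path, vi.name, dims)

show_io("face_recg.onnx")      # data [1, 3, 224, 192] / fc5 [1, 1024]
show_io("face_recg_new.onnx")  # data ['batch', 3, 224, 192] / fc5 ['batch', 1024]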
Building a TensorRT engine from this model then fails with the following error:
[TensorRT] ERROR: relu0_1: slope tensor must be unidirectional broadcastable to input tensor
The PRelu slope parameter is one-dimensional, so it may fail to broadcast against the input tensor when we run inference in onnxruntime (the graph-optimization phase cannot fix this up either). The idea is therefore straightforward: modify the shape information of the slope parameter directly. If it was originally a 1x64 vector, change its shape to (64, 1, 1); after this change broadcasting works. The core code follows: we build input maps (node information) and initializer maps (weight information), then traverse the graph to find the PRelu slope entries and modify them.
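The effect of the reshape can be illustrated with plain numpy (an illustrative sketch, not the actual PRelu kernel): numpy aligns shapes from the trailing axis, so a (64,) slope lines up with the width axis of an NCHW tensor, while (64, 1, 1) lines up with the channel axis.

import numpy as np

x = np.zeros((2, 64, 224, 192), dtype=np.float32)  # NCHW input
slope = np.zeros((64,), dtype=np.float32)

try:
    _ = np.where(x > 0, x, x * slope)  # (64,) lines up with W=192 -> error
except ValueError as e:
    print("1-D slope fails:", e)

_ = np.where(x > 0, x, x * slope.reshape(64, 1, 1))  # broadcasts over the channel axis
print("(64, 1, 1) slope broadcasts fine")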
The modification code:
for init in model_onnx.graph.initializer:
    init_maps[init.name] = init

for key in keys:
    if "relu" in key:
        # Reshape the 1-D PRelu slope (C,) to (C, 1, 1) so it broadcasts over NCHW input.
        inp = input_maps[key]
        dim_value = inp.type.tensor_type.shape.dim[0].dim_value
        new_shape = [dim_value, 1, 1]
        # Update the value_info entry for the slope input.
        graph.input.remove(inp)
        new_inp = onnx.helper.make_tensor_value_info(inp.name, onnx.TensorProto.FLOAT, new_shape)
        graph.input.extend([new_inp])
        # Update the initializer (the actual weights) with the same shape.
        init = init_maps[key]
        new_init = onnx.helper.make_tensor(init.name, onnx.TensorProto.FLOAT, new_shape, init.float_data)
        graph.initializer.remove(init)
        graph.initializer.extend([new_init])

# Persist the modified model so later steps load the fixed graph.
onnx.save(model_onnx, new_onnx_file)
Model structure before the change:
Model structure after the change:
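Again the screenshots are omitted; a quick check of the slope shapes after the edit (a sketch, assuming the modified graph was saved back to face_recg_new.onnx as above):

import onnx
from onnx import numpy_helper

m = onnx.load("face_recg_new.onnx")
for init in m.graph.initializer:
    if "relu" in init.name:
        print(init.name, numpy_helper.to_array(init).shape)  # e.g. (64, 1, 1)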
Next, simplify the modified model with onnx-simplifier, keeping the batch dimension dynamic:

import onnx
from onnxsim import simplify

onnx_simp_file = "face_recg_simp.onnx"
model_onnx = onnx.load(new_onnx_file)
# Simplify with a sample input shape; dynamic_input_shape keeps the batch dimension dynamic.
model_onnx_simp, check = simplify(model_onnx, input_shapes={"data": [16, 3, 224, 192]}, dynamic_input_shape=True)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_onnx_simp, onnx_simp_file)
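To confirm that dynamic batch survived simplification, one can run the simplified model in onnxruntime at a couple of batch sizes; a minimal sketch (the CPU provider is enough for a shape check):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("face_recg_simp.onnx")
for batch in (1, 4):
    x = np.random.randn(batch, 3, 224, 192).astype(np.float32)
    out = sess.run(["fc5"], {"data": x})[0]
    print(batch, out.shape)  # expected: (batch, 1024)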
With the simplified model in hand, build the TensorRT engine with an explicit-batch network and an optimization profile for the dynamic input:

import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

onnx_simp_file = "face_recg_simp.onnx"
engine_file = "face_recg.trt"
max_batch_size = 16
max_ws = 10 * 1 << 30  # 10 GiB workspace
dynamic_inputs = {
    "data": {
        "min_shape": (1, 3, 224, 192),
        "opt_shape": (32, 3, 224, 192),
        "max_shape": (64, 3, 224, 192),
    }
}
fp16 = True

print("building tensorrt engine")
builder = trt.Builder(TRT_LOGGER)
builder.fp16_mode = fp16
builder.max_batch_size = max_batch_size
config = builder.create_builder_config()
config.max_workspace_size = max_ws
if fp16:
    config.flags |= 1 << int(trt.BuilderFlag.FP16)

# Dynamic shapes require an explicit-batch network plus an optimization profile.
explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(explicit_batch)
if dynamic_inputs:
    profile = builder.create_optimization_profile()
    for k, v in dynamic_inputs.items():
        profile.set_shape(k, v['min_shape'], v['opt_shape'], v['max_shape'])
    config.add_optimization_profile(profile)

with trt.OnnxParser(network, TRT_LOGGER) as parser:
    with open(onnx_simp_file, 'rb') as model:
        parsed = parser.parse(model.read())
        if not parsed:
            for i in range(parser.num_errors):
                print(parser.get_error(i))
    print("network.num_layers", network.num_layers)
    # last_layer = network.get_layer(network.num_layers - 1)
    # network.mark_output(last_layer.get_output(0))
    engine = builder.build_engine(network, config=config)

with open(engine_file, 'wb') as f:
    f.write(bytearray(engine.serialize()))
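For completeness, here is how the serialized engine can be used for inference with a dynamic batch. This is a sketch against the TensorRT 7 Python API; it assumes binding 0 is "data" and binding 1 is "fc5", which matches the network above.

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(TRT_LOGGER)
with open("face_recg.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

batch = 8  # any value inside the profile range [1, 64]
inp = np.ascontiguousarray(np.random.randn(batch, 3, 224, 192).astype(np.float32))
context.set_binding_shape(0, inp.shape)  # tell TRT the actual input shape
out = np.empty(tuple(context.get_binding_shape(1)), dtype=np.float32)

d_inp = cuda.mem_alloc(inp.nbytes)
d_out = cuda.mem_alloc(out.nbytes)
stream = cuda.Stream()
cuda.memcpy_htod_async(d_inp, inp, stream)
context.execute_async_v2([int(d_inp), int(d_out)], stream.handle)
cuda.memcpy_dtoh_async(out, d_out, stream)
stream.synchronize()
print(out.shape)  # (8, 1024)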
Original statement: This article was published on the Tencent Cloud Developer Community with the author's authorization. Reproduction without permission is prohibited.
In case of infringement, please contact cloudcommunity@tencent.com for removal.