I am trying to run a TFLite model on Android to detect objects. To do so, I did the following:
(a) Training:
!python3 object_detection/model_main.py \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--model_dir=training/
(I modified the config file to point to the location of my own TFRecords.)
(b) Exporting the inference graph:
!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--output_directory={output_directory} \
--trained_checkpoint_prefix={last_model_path}
(c) Creating a TFLite-ready graph:
!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--output_directory={output_directory} \
--trained_checkpoint_prefix={last_model_path} \
--add_postprocessing_op=true
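Step (c) writes a frozen graph (tflite_graph.pb) rather than a .tflite file, and the later mention of output_arrays refers to a conversion step that is not reproduced here. A typical version of that conversion can be sketched with the tf.compat.v1 converter. The input size (300x300, the ssd_mobilenet_v2_coco default), the function name and the file paths are assumptions; the tensor names are the ones export_tflite_ssd_graph.py produces with add_postprocessing_op=true:

```python
# Sketch (assumption): convert tflite_graph.pb from step (c) into a .tflite model.
# The 300x300 input size and all file paths are placeholders; match them to your config.

INPUT_ARRAYS = ["normalized_input_image_tensor"]
# The four outputs of the TFLite_Detection_PostProcess custom op, in order:
# boxes (locations), classes, scores, number of detections.
OUTPUT_ARRAYS = [
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
]


def convert_frozen_ssd_graph(graph_pb_path, tflite_path, input_size=300):
    """Convert the exported SSD graph to a float .tflite file; return its size in bytes."""
    import tensorflow as tf  # imported here so the constants above need no TensorFlow

    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_pb_path,
        input_arrays=INPUT_ARRAYS,
        output_arrays=OUTPUT_ARRAYS,
        input_shapes={INPUT_ARRAYS[0]: [1, input_size, input_size, 3]},
    )
    # TFLite_Detection_PostProcess is a custom op, so it must be allowed explicitly.
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    return len(tflite_model)
```

For example, convert_frozen_ssd_graph(os.path.join(output_directory, "tflite_graph.pb"), "detect3.tflite") would produce a model whose four outputs are the four tensors the metadata populator counts later.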
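The TFLite model above was validated independently and works well (outside Android).
Now I need to populate the TFLite model with metadata so that it can be processed by the sample Android code (otherwise, when running on Android, I get the errors "not a valid Zip file" and "no associated files").
The sample .tflite model provided at the same link is already populated with metadata and works fine.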
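As far as I understand, associated files such as a label map are packed into a .tflite file as a zip archive appended to the flatbuffer, which would explain the "not a valid Zip file" error for a model without metadata. A stdlib-only sketch to check whether a given model (for example the working sample model versus detect3.tflite) carries packed files; the function name is illustrative:

```python
import zipfile


def packed_associated_files(model_path):
    """List the files packed into a .tflite model, or [] if there are none.

    Assumption: a model populated with associated files has a zip archive
    appended to it, while a bare model is not a valid zip file.
    """
    if not zipfile.is_zipfile(model_path):
        return []
    with zipfile.ZipFile(model_path) as archive:
        return archive.namelist()
```

Calling it on the bare detect3.tflite should return an empty list, while the sample model from the link should list its label file.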
When I use the following code (based on the example) to add metadata:
populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
populator.populate()
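(the rest of the code is essentially the same as the example; I only changed the metadata description to object detection instead of image classification and specified the location of labelmap.txt), I get the following error: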
<ipython-input-6-173fc798ea6e> in <module>()
1 populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
----> 2 populator.load_metadata_buffer(metadata_buf)
3 populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
4 populator.populate()
1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_lite_support/metadata/metadata.py in _validate_metadata(self, metadata_buf)
540 "The number of output tensors ({0}) should match the number of "
541 "output tensor metadata ({1})".format(num_output_tensors,
--> 542 num_output_meta))
543
544
ValueError: The number of output tensors (4) should match the number of output tensor metadata (1)
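These 4 output tensors are the ones listed in output_arrays in step 2 (someone may correct me there). I don't know how to update the output tensor metadata accordingly.
Can anyone who has recently used object detection with a custom model (and then deployed it on Android) help? Or help me understand how to update the tensor metadata to have 4 entries instead of 1?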
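To see which four tensors the populator is counting, the converted model can be inspected with the TensorFlow Lite Python interpreter; the metadata then needs one output TensorMetadataT per output tensor, in the same order. A sketch, assuming TensorFlow is installed; both helper names are illustrative, and the second one only mirrors the populator's count check so a mismatch shows up before calling populate():

```python
def output_tensor_names(model_path):
    """Return the names of a .tflite model's output tensors, in order."""
    import tensorflow as tf  # imported lazily; only this function needs TensorFlow

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return [detail["name"] for detail in interpreter.get_output_details()]


def check_output_metadata(output_names, metadata_names):
    """Pair each output tensor with its metadata entry, or fail on a count mismatch."""
    if len(output_names) != len(metadata_names):
        raise ValueError(
            "The number of output tensors ({0}) should match the number of "
            "output tensor metadata ({1})".format(len(output_names), len(metadata_names)))
    return list(zip(output_names, metadata_names))
```

With the image-classification template, the metadata has a single output entry, so checking the four detection outputs against it fails exactly like the traceback above.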
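Posted on 2020-10-22 19:36:41
Update, June 10, 2021:
Please see the latest tutorial on the Metadata Writer library on tensorflow.org.
Update:
The Metadata Writer library has been released. It currently supports image classifiers and object detectors, and support for more tasks is in progress.
Here is an example of writing metadata for an object detector model: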
pip install tflite_support_nightly
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata
ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1_1_default_1.tflite"
_LABEL_FILE = "labelmap.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_1_default_1_metadata.tflite"
writer = ObjectDetectorWriter.create_for_inference(
writer_utils.load_file(_MODEL_PATH), [127.5], [127.5], [_LABEL_FILE])
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
# Verify the populated metadata and associated files.
displayer = metadata.MetadataDisplayer.with_model_file(_SAVE_TO_PATH)
print("Metadata populated:")
print(displayer.get_metadata_json())
print("Associated file(s) populated:")
print(displayer.get_packed_associated_file_list())
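The two [127.5] arguments to create_for_inference are the input normalization mean and std recorded in the metadata. The consumer computes (pixel - mean) / std, which maps 8-bit RGB values from [0, 255] to [-1, 1], the range a float SSD MobileNet model expects. A quick check of that arithmetic (the helper name is illustrative):

```python
def normalize(pixel, mean=127.5, std=127.5):
    """Apply the metadata NormalizationOptions formula: (pixel - mean) / std."""
    return (pixel - mean) / std


# The minimum, midpoint and maximum of an 8-bit channel.
print([normalize(p) for p in (0, 127.5, 255)])  # → [-1.0, 0.0, 1.0]
```

If the model was trained with different preprocessing, these two values are what would need to change.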
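Here is a code snippet you can use to populate the metadata of an object detection model so that it is compatible with the TFLite Android app:
import os
import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

model_meta = _metadata_fb.ModelMetadataT()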
model_meta.name = "SSD_Detector"
model_meta.description = (
"Identify which of a known set of objects might be present and provide "
"information about their positions within the given image or a video "
"stream.")
# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
_metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
_metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats
# Creates outputs info.
output_location_meta = _metadata_fb.TensorMetadataT()
output_location_meta.name = "location"
output_location_meta.description = "The locations of the detected boxes."
output_location_meta.content = _metadata_fb.ContentT()
output_location_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.BoundingBoxProperties)
output_location_meta.content.contentProperties = (
_metadata_fb.BoundingBoxPropertiesT())
output_location_meta.content.contentProperties.index = [1, 0, 3, 2]
output_location_meta.content.contentProperties.type = (
_metadata_fb.BoundingBoxType.BOUNDARIES)
output_location_meta.content.contentProperties.coordinateType = (
_metadata_fb.CoordinateType.RATIO)
output_location_meta.content.range = _metadata_fb.ValueRangeT()
output_location_meta.content.range.min = 2
output_location_meta.content.range.max = 2
output_class_meta = _metadata_fb.TensorMetadataT()
output_class_meta.name = "category"
output_class_meta.description = "The categories of the detected boxes."
output_class_meta.content = _metadata_fb.ContentT()
output_class_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.FeatureProperties)
output_class_meta.content.contentProperties = (
_metadata_fb.FeaturePropertiesT())
output_class_meta.content.range = _metadata_fb.ValueRangeT()
output_class_meta.content.range.min = 2
output_class_meta.content.range.max = 2
label_file = _metadata_fb.AssociatedFileT()
# This name should match the basename of the label file later passed to load_associated_files().
label_file.name = os.path.basename("label.txt")
label_file.description = "Label of objects that this model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
output_class_meta.associatedFiles = [label_file]
output_score_meta = _metadata_fb.TensorMetadataT()
output_score_meta.name = "score"
output_score_meta.description = "The scores of the detected boxes."
output_score_meta.content = _metadata_fb.ContentT()
output_score_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.FeatureProperties)
output_score_meta.content.contentProperties = (
_metadata_fb.FeaturePropertiesT())
output_score_meta.content.range = _metadata_fb.ValueRangeT()
output_score_meta.content.range.min = 2
output_score_meta.content.range.max = 2
output_number_meta = _metadata_fb.TensorMetadataT()
output_number_meta.name = "number of detections"
output_number_meta.description = "The number of the detected boxes."
output_number_meta.content = _metadata_fb.ContentT()
output_number_meta.content.contentPropertiesType = (
_metadata_fb.ContentProperties.FeatureProperties)
output_number_meta.content.contentProperties = (
_metadata_fb.FeaturePropertiesT())
# Creates subgraph info.
group = _metadata_fb.TensorGroupT()
group.name = "detection result"
group.tensorNames = [
output_location_meta.name, output_class_meta.name,
output_score_meta.name
]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [
output_location_meta, output_class_meta, output_score_meta,
output_number_meta
]
subgraph.outputTensorGroups = [group]
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
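The buffer built above is then written into the model with the same MetadataPopulator calls used in the question. One detail worth checking: as far as I can tell, the populator requires every associated file recorded in the metadata (here label.txt) to be loaded, matched by file basename, so with a labelmap.txt either rename the file or change label_file.name. A sketch, assuming tflite_support is installed; the helper names are illustrative, and the model is copied first because populate() writes into the model file in place:

```python
import os
import shutil


def missing_associated_files(recorded_names, file_paths):
    """Names recorded in the metadata that have no file with the same basename."""
    loaded = {os.path.basename(p) for p in file_paths}
    return [name for name in recorded_names if name not in loaded]


def populate_model(model_path, output_path, metadata_buf, label_path):
    """Copy the model, then pack metadata_buf and the label file into the copy."""
    from tflite_support import metadata as _metadata  # imported lazily

    shutil.copyfile(model_path, output_path)  # populate() modifies the file in place
    populator = _metadata.MetadataPopulator.with_model_file(output_path)
    populator.load_metadata_buffer(metadata_buf)
    populator.load_associated_files([label_path])
    populator.populate()
    # Report what ended up packed, as in the Metadata Writer example.
    displayer = _metadata.MetadataDisplayer.with_model_file(output_path)
    return displayer.get_packed_associated_file_list()
```

For example, populate_model("detect3.tflite", "detect3_metadata.tflite", metadata_buf, "labelmap.txt") should work once label_file.name is "labelmap.txt"; missing_associated_files(["label.txt"], ["labelmap.txt"]) flags the mismatch beforehand.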
https://stackoverflow.com/questions/64097085