Loading [MathJax]/jax/input/TeX/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Tensorflow Lite Model Maker --- 图像分类篇+源码

Tensorflow Lite Model Maker --- 图像分类篇+源码

原创
作者头像
XianxinMao
修改于 2021-10-11 02:32:18
修改于 2021-10-11 02:32:18
1.3K00
代码可运行
举报
文章被收录于专栏:深度学习框架深度学习框架
运行总次数:0
代码可运行

TFLite_tutorials

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications. 解读: 此处我们想要得到的是 .tflite 格式的模型,用于在移动端或者嵌入式设备上进行部署

下表罗列的是 TFLite Model Maker 目前支持的几个任务类型

Supported Tasks

Task Utility

Image Classification: tutorial, api

Classify images into predefined categories.

Object Detection: tutorial, api

Detect objects in real time.

Text Classification: tutorial, api

Classify text into predefined categories.

BERT Question Answer: tutorial, api

Find the answer in a certain context for a given question with BERT.

Audio Classification: tutorial, api

Classify audio into predefined categories.

Recommendation: demo, api

Recommend items based on the context information for on-device scenario.

If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to TensorFlow Lite model. 解读: 如果你要训练的模型不符合上述的任务类型,那么可以先训练 Tensorflow Model 然后再转换成 TFLite

想用使用 Tensorflow Lite Model Maker 我们需要先安装:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
pip install tflite-model-maker

本质完成的是分类任务 更换不同的模型,看最终的准确率,以及 TFLite 的大小、推断速度、内存占用、CPU占用等

下面的代码片段是用于下载数据集的

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

数据集结构如下所示: flower_photos |__ daisy |__ 100080576_f52e8ee070_n.jpg |__ 14167534527_781ceb1b7a_n.jpg |__ ... |__ dandelion |__ 10043234166_e6dd915111_n.jpg |__ 1426682852_e62169221f_m.jpg |__ ... |__ roses |__ 102501987_3cdb8e5394_n.jpg |__ 14982802401_a3dfb22afb.jpg |__ ... |__ sunflowers |__ 12471791574_bb1be83df4.jpg |__ 15122112402_cafa41934f.jpg |__ ... |__ tulips |__ 13976522214_ccec508fe7.jpg |__ 14487943607_651e8062a1_m.jpg |__ ...

加载数据集并切分

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
assert tf.__version__.startswith('2')

判断是否为 '2' 开头

模型训练结果 train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210 总体来说符合模型的泛化规律

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import os
import time
​
import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
​
assert tf.__version__.startswith('2')
​
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
​
data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#     print(batch)
#     break
​
validation_data, test_data = rest_data.split(0.5)
​
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'), epochs=20)
​
loss, accuracy = model.evaluate(test_data)
​
model.export(export_dir='./testTFlite', export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))
​
start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)

从上面的输出日志来看,模型经过量化后,准确率并未有多少损失,量化后的模型大小为 4.0MB(efficientnet_lite0) 从下图来看,是单 cpu 在做推断,test_data 的图片有 367 张,总耗时 273.43s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite', quantization_config=config, export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

如果导出的模型是 fp16 的话,模型大小为 6.8MB(efficientnet_lite0),推断速度是 5.54 s,快了很多

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'), epochs=20)

将模型切换为 mobilenet_v2,导出的 fp16 模型大小为 4.6MB,推断速度是 4.36 s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=inception_v3_spec, epochs=20)

将模型切换为 inception_v3,导出的 fp16 模型大小为 43.8MB(inception_v3),推断速度是 25.31 s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Common Dataset used for tasks.class DataLoader(object):
  """This class provides generic utilities for loading customized domain data that will be used later in model retraining.
​
  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """
​
  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.
​
    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset donesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size
​
  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation.
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Image dataloader
​
class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""
​
  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.
​
    Assume the image data of the same label are in the same subdirectory.Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.Returns:
      ImageDataset containing images and labels and other related info.
    """
   @classmethod
   def from_tfds(cls, name):
     """Loads data from tensorflow_datasets."""
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
ImageNet preprocessing
​
class Preprocessor(object):
  """Preprocessing for image classification."""
​
  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation
​
  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)
​
  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)
​
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
​
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
​
  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)
​
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
​
    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""
​
  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data:  Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full interger quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data
​
  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors
​
  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)
​
    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Custom classification model that is already retained by data
​
class ClassificationModel(custom_model.CustomModel):
  """"The abstract base class that represents a Tensorflow classification model."""
​
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)
​
  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize a instance with data, deploy mode and other related parameters.Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model
​
  def evaluate(self, data, batch_size=32):
    """Evaluates the model.Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)
​
  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions.
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class CustomModel(abc.ABC):
  """"The abstract base class that represents a Tensorflow classification model."""
​
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)
​
  def __init__(self, model_spec, shuffle):
    """Initialize a instance with data, deploy mode and other related parameters.Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None
​
  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return
​
  def summary(self):
    self.model.summary()
​
  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
        # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite
​
  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)
​
    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)
​
​
def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details
​
  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details
​
  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
TensorFlow Lite for Android 初探(附demo)
TensorFlow Lite 是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlow Lite 支持 Android、iOS 甚至树莓派等多种平台。
OpenCV学堂
2019/11/13
1.2K0
TensorFlow巨浪中的巨人:大数据领域的引领者 TensorFlow实战【上进小菜猪大数据系列】
大数据时代的到来带来了海量数据的处理和分析需求。在这个背景下,TensorFlow作为一种强大的深度学习框架,展现了其在大数据领域中的巨大潜力。本文将深入探索TensorFlow在大数据处理和分析中的应用,介绍其在数据预处理、模型构建、分布式训练和性能优化等方面的优势和特点。
上进小菜猪
2023/10/16
3040
TensorFlow巨浪中的巨人:大数据领域的引领者 TensorFlow实战【上进小菜猪大数据系列】
图像分类-flower_photos 实验研究
flower_photos 数据量比较小,所以 simple_cnn 可以在 trainset 上拟合到 0.99,意思就是数据复杂度 < 模型复杂度
XianxinMao
2021/08/22
6120
tensorflow版本的tansformer训练IWSLT数据集
代码来源:https://github.com/Kyubyong/transformer
西西嘛呦
2020/08/26
2K0
精度、延迟两不误,移动端性能新SOTA,谷歌TF开源轻量级EfficientNet
今天,谷歌在 GitHub 与 TFHub 上同时发布了 EfficientNet-Lite,该模型运行在 TensorFlow Lite 上,且专门针对移动设备 CPU、GPU 以及 EdgeTPU 做了优化。EfficientNet-Lite 为边缘设备带来了 EfficientNet 上强大的性能,并且提供五个不同版本,让用户能够根据自己的应用场景灵活地在低延迟与高精度之间选择。
机器之心
2020/03/25
5700
Text classification with TensorFlow Hub: Movie reviews
This notebook classifies movie reviews as positive or negative using the text of the review. This is an example of binary—or two-class—classification, an important and widely applicable kind of machine learning problem.
XianxinMao
2021/07/31
2840
TensorFlow Lite for Android 初探(附demo)一. TensorFlow Lite二. tflite 格式三. 常用的 Java API四. TensorFlow Lite
我们知道大多数的 AI 是在云端运算的,但是在移动端使用 AI 具有无网络延迟、响应更加及时、数据隐私等特性。
fengzhizi715
2018/12/07
3.2K0
TensorFlow-Slim图像分类库
本文介绍了如何使用深度学习模型进行图像分类,并探讨了在训练和评估模型时出现的问题及解决方案。
chaibubble
2018/01/02
2.6K0
TensorFlow-Slim图像分类库
Load and preprocess images
This tutorial shows how to load and preprocess an image dataset in three ways. First, you will use high-level Keras preprocessing utilities and layers to read a directory of images on disk. Next, you will write your own input pipeline from scratch using tf.data. Finally, you will download a dataset from the large catalog available in TensorFlow Datasets.
XianxinMao
2021/07/29
7160
【他山之石】Pytorch/Tensorflow-gpu训练并行加速trick(含代码)
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
马上科普尚尚
2021/01/28
1.6K0
【他山之石】Pytorch/Tensorflow-gpu训练并行加速trick(含代码)
图像分类任务中,Tensorflow 与 Keras 到底哪个更厉害?
转载来源:AI 研习社编译的技术博客 原标题:Tensorflow Vs Keras? — Comparison by building a model for image classificatio
崔庆才
2019/09/04
9640
图像分类任务中,Tensorflow 与 Keras 到底哪个更厉害?
迁移学习之快速搭建【卷积神经网络】
卷积神经网络 概念认识:https://cloud.tencent.com/developer/article/1822928
一颗小树x
2021/05/14
2K0
迁移学习之快速搭建【卷积神经网络】
使用自己的数据集训练MobileNet、ResNet实现图像分类(TensorFlow)| CSDN博文精选
之前写了一篇博客《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》https://panjinquan.blog.csdn.net/article/details/81560537,本博客就是此博客的框架基础上,完成对MobileNet的图像分类模型的训练,其相关项目的代码也会统一更新到一个Github中,强烈建议先看这篇博客《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》后,再来看这篇博客。
AI科技大本营
2019/12/23
7K0
使用NVIDIA TAO工具包优化Arm Ethos-U NPUs的AI模型
本文翻译自:《Optimizing AI models for Arm Ethos-U NPUs using the NVIDIA TAO Toolkit》
GPUS Lady
2023/10/28
4680
使用NVIDIA TAO工具包优化Arm Ethos-U NPUs的AI模型
AIoT应用创新大赛-基于TFML的迁移学习实践
NXP eIQ平台提供了嵌入式平台集成化的机器学习应用部署能力,支持BYOD(Bring Your Own Data)和BYOM(Bring You Own Model)的两种建模应用的工作流。
flavorfan
2022/02/23
2.2K0
AIoT应用创新大赛-基于TFML的迁移学习实践
keras.Model
Model groups layers into an object with training and inference features.
狼啸风云
2022/06/08
1.2K0
[源码解析] TensorFlow 分布式之 ParameterServerStrategy V2
对于 ParameterServerStrategy V2,我们将从几个方面来研究:如何与集群建立连接,如何生成变量,如何获取数据,如何运行。其中,变量和作用域我们在前文已经研究过,运行在 MirroredStrategy 里面也介绍,所以本文主要看看如何使用,如何初始化。在下一篇之中会重点看看如何分发计算。
罗西的思考
2022/05/15
1.3K0
TensorFlow 2.0到底怎么样?简单的图像分类任务探一探
从历史角度看,TensorFlow 是机器学习框架的「工业车床」:具有复杂性和陡峭学习曲线的强大工具。如果你之前用过 TensorFlow 1.x,你就会知道复杂与难用是在说什么。
机器之心
2019/04/29
1K0
TensorFlow 2.0到底怎么样?简单的图像分类任务探一探
《高效迁移学习:Keras与EfficientNet花卉分类项目全解析》
想象一下:如果一个已经会弹钢琴的人学习吉他,会比完全不懂音乐的人快得多。因为TA已经掌握了乐理知识、节奏感和手指灵活性,这些都可以迁移到新乐器的学习中。这正是迁移学习(Transfer Learning)的核心思想——将已掌握的知识迁移到新任务中。
机器学习司猫白
2025/03/12
1280
《高效迁移学习:Keras与EfficientNet花卉分类项目全解析》
基于Tensorflow2 Lite在Android手机上实现图像分类
Tensorflow2之后,训练保存的模型也有所变化,基于Keras接口搭建的网络模型默认保存的模型是h5格式的,而之前的模型格式是pb。Tensorflow2的h5格式的模型转换成tflite格式模型非常方便。本教程就是介绍如何使用Tensorflow2的Keras接口训练分类模型并使用Tensorflow Lite部署到Android设备上。
夜雨飘零
2020/07/22
3.4K0
基于Tensorflow2 Lite在Android手机上实现图像分类
推荐阅读
相关推荐
TensorFlow Lite for Android 初探(附demo)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验