Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Tensorflow Lite Model Maker --- 图像分类篇+源码

Tensorflow Lite Model Maker --- 图像分类篇+源码

原创
作者头像
XianxinMao
修改于 2021-10-11 02:32:18
修改于 2021-10-11 02:32:18
1.3K00
代码可运行
举报
文章被收录于专栏:深度学习框架深度学习框架
运行总次数:0
代码可运行

TFLite_tutorials

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications. 解读: 此处我们想要得到的是 .tflite 格式的模型,用于在移动端或者嵌入式设备上进行部署

下表罗列的是 TFLite Model Maker 目前支持的几个任务类型

Supported Tasks

Task Utility

Image Classification: tutorial, api

Classify images into predefined categories.

Object Detection: tutorial, api

Detect objects in real time.

Text Classification: tutorial, api

Classify text into predefined categories.

BERT Question Answer: tutorial, api

Find the answer in a certain context for a given question with BERT.

Audio Classification: tutorial, api

Classify audio into predefined categories.

Recommendation: demo, api

Recommend items based on the context information for on-device scenario.

If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to TensorFlow Lite model. 解读: 如果你要训练的模型不符合上述的任务类型,那么可以先训练 Tensorflow Model 然后再转换成 TFLite

想用使用 Tensorflow Lite Model Maker 我们需要先安装:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
pip install tflite-model-maker

本质完成的是分类任务 更换不同的模型,看最终的准确率,以及 TFLite 的大小、推断速度、内存占用、CPU占用等

下面的代码片段是用于下载数据集的

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

数据集结构如下所示: flower_photos |__ daisy |__ 100080576_f52e8ee070_n.jpg |__ 14167534527_781ceb1b7a_n.jpg |__ ... |__ dandelion |__ 10043234166_e6dd915111_n.jpg |__ 1426682852_e62169221f_m.jpg |__ ... |__ roses |__ 102501987_3cdb8e5394_n.jpg |__ 14982802401_a3dfb22afb.jpg |__ ... |__ sunflowers |__ 12471791574_bb1be83df4.jpg |__ 15122112402_cafa41934f.jpg |__ ... |__ tulips |__ 13976522214_ccec508fe7.jpg |__ 14487943607_651e8062a1_m.jpg |__ ...

加载数据集并切分

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
assert tf.__version__.startswith('2')

判断是否为 '2' 开头

模型训练结果 train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210 总体来说符合模型的泛化规律

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import os
import time
​
import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
​
assert tf.__version__.startswith('2')
​
image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
​
data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#     print(batch)
#     break
​
validation_data, test_data = rest_data.split(0.5)
​
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'), epochs=20)
​
loss, accuracy = model.evaluate(test_data)
​
model.export(export_dir='./testTFlite', export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))
​
start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)

从上面的输出日志来看,模型经过量化后,准确率并未有多少损失,量化后的模型大小为 4.0MB(efficientnet_lite0) 从下图来看,是单 cpu 在做推断,test_data 的图片有 367 张,总耗时 273.43s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite', quantization_config=config, export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

如果导出的模型是 fp16 的话,模型大小为 6.8MB(efficientnet_lite0),推断速度是 5.54 s,快了很多

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'), epochs=20)

将模型切换为 mobilenet_v2,导出的 fp16 模型大小为 4.6MB,推断速度是 4.36 s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=inception_v3_spec, epochs=20)

将模型切换为 inception_v3,导出的 fp16 模型大小为 43.8MB(inception_v3),推断速度是 25.31 s

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Common Dataset used for tasks.class DataLoader(object):
  """This class provides generic utilities for loading customized domain data that will be used later in model retraining.
​
  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """
​
  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.
​
    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset donesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size
​
  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation.
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Image dataloader
​
class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""
​
  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.
​
    Assume the image data of the same label are in the same subdirectory.Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.Returns:
      ImageDataset containing images and labels and other related info.
    """
   @classmethod
   def from_tfds(cls, name):
     """Loads data from tensorflow_datasets."""
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
ImageNet preprocessing
​
class Preprocessor(object):
  """Preprocessing for image classification."""
​
  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation
​
  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)
​
  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)
​
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
​
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
​
  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)
​
    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
​
    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""
​
  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data:  Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full interger quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data
​
  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors
​
  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)
​
    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Custom classification model that is already retained by data
​
class ClassificationModel(custom_model.CustomModel):
  """"The abstract base class that represents a Tensorflow classification model."""
​
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)
​
  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize a instance with data, deploy mode and other related parameters.Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model
​
  def evaluate(self, data, batch_size=32):
    """Evaluates the model.Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)
​
  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions.
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class CustomModel(abc.ABC):
  """"The abstract base class that represents a Tensorflow classification model."""
​
  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)
​
  def __init__(self, model_spec, shuffle):
    """Initialize a instance with data, deploy mode and other related parameters.Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None
​
  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return
​
  def summary(self):
    self.model.summary()
​
  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
        # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite
​
  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)
​
    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)
​
​
def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details
​
  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details
​
  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Step By Step To Learn Kubernetes Series
flannel借助etcd内的路由表实现k8s集群节点上的每个Pod能相互的通信。
公众号: 云原生生态圈
2021/11/15
4170
CentOS7.7部署k8s(3 master + 3 node + 1 client)
VMware创建7个vm,规格2cpu 2G mem 200G disk,一个NAT网卡
后端云
2020/04/22
2.3K0
CentOS7.7部署k8s(3 master + 3 node + 1 client)
Kubernetes简介和安装
Production-Grade Container Orchestration Automated container deployment, scaling, and management
IT小马哥
2021/06/03
1.4K0
Kubernetes简介和安装
腾讯云CVM上用kubeadm安装Kubernetes集群(版本1.14.0)
kubeadm是Kubernetes官方提供的用于快速安装 Kubernetes 集群的工具,通过将集群的各个组件进行容器化安装管理,通过kubeadm的方式安装集群比二进制的方式安装要方便
马凌鑫
2019/04/02
4.1K0
Kubernetes部署并初始化群集
docker的Cgroup Driver和kubelet的Cgroup Driver不一致,修改如下:
行 者
2023/10/19
3380
Kubernetes | 集群安装 - ClusterInstallation
CentOS 7.x 系统自带的 3.10.x 内核存在一些 Bugs,导致运行的 Docker、Kubernetes 不稳定,例如: rpm -Uvhhttp://www.elrepo.org/elrepo-release-7.0-3.el7.elrepo.noarch.rpm
Zkeq
2023/04/29
1K0
Kubernetes | 集群安装 - ClusterInstallation
CentOS7.7部署k8s(1 master + 2 node)
VMware创建三个vm,规格2cpu 4G mem 200G disk,一个NAT网卡
后端云
2020/04/22
1.4K0
CentOS7.7部署k8s(1 master + 2 node)
安装Kubernetes集群
之前我们在windows机器上用Minikube安装了一个单节点Kubernetes集群,这个只能当做了解k8s的练手,本篇文章我们安装一个拥有一个Master,两个Worker节点的k8s集群,作为熟悉Kubernetes的测试集群。
云原生
2021/05/31
1.1K0
安装Kubernetes集群
外包精通--教你5分钟搞定k8s安装(CentOS)笔记、思路
我们知道k8s的主机角色分为master、worknode,创建k8s集群首先需要初始化k8s的master节点。
Godev
2023/06/25
2.3K0
kubeadm搭建单master节点1.20版本kubernetes集群
由于是云服务器,selinux、firewalld、swap都会默认关闭,iptables规则也会清空,所以仅需要配置下主机名、hosts文件以及配置下kubernetes的转发规则就好,如下:
唐旭
2021/11/02
1.6K1
Kubernetes 使用kubeadm创建集群
注意,安装docker时,需要指Kubenetes支持的版本(参见如下),如果安装的docker版本过高导致,会提示以下问题
授客
2021/09/26
3.5K0
kubernetes系列教程(二)kubeadm离线部署1.14.1集群
本章是kubernetes系列教程第二篇,要深入学习kubernetes,首先需要有一个k8s环境,然而,受制硬件环境,网络环境等因素,要搭建一个环境有一定的困难,让很多初学者望而却步,本章主要介绍通过kubeadm安装工具部署kubernetes集群,考虑到国内网络限制,已将安装镜像通过跳板机下载到本地,方便大家离线安装。
HappyLau谈云计算
2019/08/03
14.2K2
kubernetes系列教程(二)kubeadm离线部署1.14.1集群
使用kubeadm搭建高可用k8s v1.16.3集群
本文通过kubeadm搭建一个高可用的k8s集群,kubeadm可以帮助我们快速的搭建k8s集群,高可用主要体现在对master节点组件及etcd存储的高可用,文中使用到的服务器ip及角色对应如下:
仙人技术
2020/04/29
2.2K1
使用kubeadm搭建高可用k8s v1.16.3集群
部署 Kubernetes + KubeVirt 以及 KubeVirt的基本使用
KubeVirt目的是让虚拟机运行在容器中,下面就用下KubeVirt的几个基本操作:
后端云
2022/06/09
4.2K0
部署 Kubernetes + KubeVirt 以及 KubeVirt的基本使用
kubernetes部署:基于kubeadm的国内镜像源安装
Kubernetes 1.8开始要求关闭系统的Swap,如果不关闭,默认配置下kubelet将无法启动,关闭系统的Swap方法如下:
机械视角
2019/10/23
16.4K0
kubernetes部署:基于kubeadm的国内镜像源安装
使用kubeadm快速启用一个集群
使用kubeadm快速启用一个集群 ================= [图片] CentOS 配置YUM源 ============= cat <<EOF > /etc/yum.repos.d/kubernetes.repo [kubernetes] name=kubernetes baseurl=https://mirrors.ustc.edu.cn/kubernetes/yum/repos/kubernetes-el7-$basearch enabled=1 EOF setenforce 0 yum
小陈运维
2022/05/06
3490
CentOS 7 上安装配置 Kubernetes 集群
安装和配置 Kubernetes 集群的过程是比较繁琐的,这里阐述在 Mac 上利用 virtualbox 配置 CentOS 7 上的 Kubernetes 集群的过程。
星哥玩云
2022/07/28
5660
kubeadm部署K8S集群并使用containerd做容器运行时
去年12月份,当Kubernetes社区宣布1.20版本之后会逐步弃用dockershim,当时也有很多自媒体在宣传Kubernetes弃用Docker。其实,我觉得这是一种误导,也许仅仅是为了蹭热度。
没有故事的陈师傅
2021/04/08
3K0
Kubernetes(k8s)-安装k8s(containerd版)
我们上一章介绍了Docker基本情况,目前在规模较大的容器集群基本都是Kubernetes,但是Kubernetes涉及的东西和概念确实是太多了,而且随着版本迭代功能在还增加,笔者有些功能也确实没用过,所以只能按照我自己的理解来讲解。
运维小路
2024/12/23
6570
Kubernetes(k8s)-安装k8s(containerd版)
kubeadm部署高可用kubernetes
(2)启用 ELRepo 仓库 ELRepo 仓库是基于社区的用于企业级 Linux 仓库,提供对 RedHat Enterprise (RHEL) 和 其他基于 RHEL的 Linux 发行版(CentOS、Scientific、Fedora 等)的支持。 ELRepo 聚焦于和硬件相关的软件包,包括文件系统驱动、显卡驱动、网络驱动、声卡驱动和摄像头驱动等。
全栈程序员站长
2022/09/15
1K0
推荐阅读
相关推荐
Step By Step To Learn Kubernetes Series
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验