
Classifying Mobile Gallery Images in PyTorch


Author | n0obcoder

Source | Medium

Editor | 代码医生团队

This small project sounds like a good real-world application of a deep-neural-network-based image classifier, and building a classifier for your own phone's gallery can be a fun experience.

Step 1: Building the Dataset

First, list all the classes you want the image classifier to choose its outputs from.

Since this is a mobile gallery image classification project, the classes chosen are ones that come up frequently while browsing a phone's gallery.

Here are the chosen classes:

  • Cars
  • Memes
  • Mountains
  • Selfies
  • Trees
  • Screenshots

Once you have the list of all the desired classes, you have to collect images for each of them.

There are a couple of different ways to collect image data:

  • Manual collection - use images already sitting in your phone's gallery, or take pictures of the things listed as target classes yourself.
  • Web scraping - images can be scraped from the web in several ways; a small Python script can download images of a given class (one hedged sketch follows the Stack Overflow link below).

After downloading, the images have to be sorted into one directory per class, so we end up with six directories, one for each class.
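For reference, this is exactly the per-class layout that torchvision's ImageFolder (used in Step 2) expects; the directory names below match the listing printed later in this step:

```
train/
├── Cars/
├── Memes/
├── Mountains/
├── Selfies/
├── Trees/
└── Whatsapp_Screenshots/
```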

Both of the collection methods above were used here. A download script is easy to find on sites like Stack Overflow, but since there is no nice stash of other people's screenshots on the internet, those had to be collected from a phone.

https://stackoverflow.com/questions/20716842/python-download-images-from-google-image-search
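The thread above contains one such script. As a hedged alternative (an assumption on my part, not the approach the author used), the icrawler package ships a ready-made Google Images downloader; the keywords and target directories below are illustrative placeholders:

```python
# A minimal scraping sketch using icrawler -- an assumption, not the article's script.
# One crawl per class, each saved straight into its class directory.
from icrawler.builtin import GoogleImageCrawler

for keyword, class_dir in [('car', 'train/Cars'), ('mountain', 'train/Mountains')]:
    crawler = GoogleImageCrawler(storage={'root_dir': class_dir})
    crawler.crawl(keyword=keyword, max_num=100)  # up to 100 images per keyword
```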

A mobile gallery image dataset has already been built, so feel free to use it instead of assembling one yourself:

https://www.kaggle.com/n0obcoder/mobile-gallery-image-classification-data

Download the dataset and extract it in the root directory, so that the Python script or Jupyter notebook sits alongside the dataset directory. Make sure the dataset directory's path matches the one shown below.

```python
import matplotlib.pyplot as plt
%matplotlib inline
import os, glob, sys

def q(text = ''): # Just an exit function
    print(text)
    sys.exit()

# Input data files are available in the "/kaggle/input/" directory.
data_dir = 'mobile-gallery-image-classification-data/mobile_gallery_image_classification/train'

for class_dir in glob.glob(data_dir + os.sep + '*'):
    print(class_dir)

# Any results you write to the current directory ('/kaggle/working') are saved as output.
```

Output:

```
mobile-gallery-image-classification-data/mobile_gallery_image_classification/train\Cars
mobile-gallery-image-classification-data/mobile_gallery_image_classification/train\Memes
mobile-gallery-image-classification-data/mobile_gallery_image_classification/train\Mountains
mobile-gallery-image-classification-data/mobile_gallery_image_classification/train\Selfies
mobile-gallery-image-classification-data/mobile_gallery_image_classification/train\Trees
mobile-gallery-image-classification-data/mobile_gallery_image_classification/train\Whatsapp_Screenshots
```

Here are a few examples from the mobile gallery image dataset.

A few sample images from the training split of the Mobile Image Gallery dataset, with their classes: Memes (top left), Cars (top right), Trees (bottom right), and Mountains (bottom left).

Step 2: Data Preprocessing and Building the DataLoaders

With the dataset ready, the next step is some data preprocessing: resizing the images, randomly flipping them along the horizontal axis, converting the images (pixels as integers from 0 to 255) into tensors (pixels as floats from 0.0 to 1.0), and, last but not least, normalizing the tensors using the ImageNet statistics (mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]). Note that we are dealing with color (3-channel) images, not grayscale (black-and-white) ones.

Next, create a dataset object from the data path and the transforms/preprocessing to be applied to the image data.

Then split the dataset randomly into training and validation datasets by choosing a split percentage.

```python
# Let's start by loading in the image data
import torch
from torchvision import datasets, transforms

# Defining the transforms that we want to apply to the data:
# - Resizing the image to (224, 224)
# - Randomly flipping the image horizontally (with the default probability of 0.5)
# - Converting the image to a Tensor (scaling the pixel values to between 0 and 1)
# - Normalizing the 3-channel data using the 'ImageNet' stats
data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

print('data_transforms: ', data_transforms)

dataset = datasets.ImageFolder(data_dir, transform = data_transforms)
print('dataset: ', dataset)

# We need to split the dataset between train and val datasets
train_percentage = 0.8
train_size = int(len(dataset)*train_percentage)
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

print('\nnumber of examples in train_dataset: ', len(train_dataset))
print('number of examples in val_dataset  : ', len(val_dataset))
```
Output:

```
data_transforms:  Compose(
    Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
dataset:  Dataset ImageFolder
    Number of datapoints: 1266
    Root Location: mobile-gallery-image-classification-data/mobile_gallery_image_classification/train
    Transforms (if any): Compose(
                             Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
                             RandomHorizontalFlip(p=0.5)
                             ToTensor()
                             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                         )
    Target Transforms (if any): None

number of examples in train_dataset:  1012
number of examples in val_dataset  :  254
```

Finally, pick a batch size and create the training and validation DataLoader objects. Batching lets training and validation run as large matrix multiplications, which makes the whole process dramatically faster.

```python
# Defining dataloaders which would return data in batches
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader   = torch.utils.data.DataLoader(val_dataset,   batch_size = batch_size, shuffle = False)

print('number of batches in train_loader with a batch_size of {}: {}'.format(batch_size, len(train_loader)))
print('number of batches in val_loader with a batch_size of {}: {}'.format(batch_size, len(val_loader)))
```
Output:

```
number of batches in train_loader with a batch_size of 64: 16
number of batches in val_loader with a batch_size of 64: 4
```
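The original post moves straight on, but a quick sanity check of the loaders is cheap: pull one batch, confirm the tensor shapes, and plot a sample after undoing the ImageNet normalization. The un-normalization step below is my addition, not part of the original code; it reuses the train_loader and dataset objects defined above.

```python
import numpy as np

# Fetch a single batch from the training loader
imgs, labels = next(iter(train_loader))
print('imgs.shape  : ', imgs.shape)    # expected: torch.Size([64, 3, 224, 224])
print('labels.shape: ', labels.shape)  # expected: torch.Size([64])

# Reverse the Normalize transform so the image looks natural when plotted
mean = np.array([0.485, 0.456, 0.406])
std  = np.array([0.229, 0.224, 0.225])
img = imgs[0].permute(1, 2, 0).numpy()  # CHW -> HWC
img = np.clip(img * std + mean, 0, 1)

plt.imshow(img)
plt.title(dataset.classes[labels[0].item()])
```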

Step 3: Choosing a Suitable Model and Making the Necessary Tweaks

We will use a convolutional-neural-network-based architecture, because nothing beats a CNN when it comes to images, or any data with spatial relationships for that matter. Since many tried-and-tested CNN architectures already exist, there is no need to invent a new one.

Rather than writing a CNN architecture from scratch, we will pick one of the many models that already exist. There are two main reasons for this:

  1. These architectures have been tried and tested on a variety of datasets and have delivered excellent results.
  2. These architectures have been trained on huge public datasets, and their pretrained weights are publicly available.

torchvision offers plenty of pretrained models to choose from, such as AlexNet, ResNet, VGG, Inception, DenseNet, and more. The model chosen here is ResNet34, because it is neither too deep nor too shallow; if you plan to run this on a machine without a GPU, it should not demand excessive compute during training.

https://pytorch.org/docs/stable/torchvision/index.html#torchvision

One thing needs attention, though. In a CNN architecture like this, the number of neurons in the final linear layer corresponds to the number of classes in the dataset it was trained on. ResNet34 was originally trained on ImageNet, which has 1000 classes, but we want the model to predict only the classes present in our dataset (6 in this case). So we replace the model's last linear layer with a new linear layer of 6 neurons, outputting predictions for 6 classes.

```python
import torch
import torch.nn as nn
from torchvision import models

# Defining the model
model = models.resnet34(pretrained = True)

# The original architecture of resnet34 has 1000 neurons (corresponding to the 1000 classes
# it was originally trained on) in the final layer.
# So we need to change the final layer according to the number of classes in our dataset
print('model.fc before: ', model.fc)
model_fc_in_features = model.fc.in_features
model.fc = nn.Linear(model_fc_in_features, len(dataset.classes))
print('model.fc after : ', model.fc)
```
Output:

```
model.fc before:  Linear(in_features=512, out_features=1000, bias=True)
model.fc after :  Linear(in_features=512, out_features=6, bias=True)
```
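A quick way to confirm the surgery worked end to end is a dummy forward pass (this check is my addition): the model should now return one score per class.

```python
# Sanity check: one fake RGB image at ImageNet resolution should yield 6 class scores
dummy = torch.randn(1, 3, 224, 224)
model.eval()
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # torch.Size([1, 6])
```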

Step 4: Training by Freezing and Unfreezing Layers

Note that because we are using a pretrained model, its filters (kernels) have already learned to recognize certain features. Let's look a little more closely at what those features actually are.

Filters in the initial convolutional layers learn simple, basic features such as edges, colors, and textures; those in the middle layers may learn shapes such as circles and polygons; and filters in the deeper layers learn far more complex patterns such as faces or flower petals. The picture below makes this clearer.

Low-level features include edges, colors, and textures; mid-level features include simple shapes and geometry; high-level features include complex shapes and objects such as faces and flowers.

Clearly, we would like to keep the filters in the initial and middle layers, since we need them to recognize the edges, colors, textures, and simple shapes in the input image. What we may not want to keep are the filters in the last few convolutional and linear layers. So the model should be fine-tuned on the custom dataset by training only its last few (convolutional or linear) layers, with a small learning rate.

```python
# Now let's have a look at the requires_grad attributes of all the parameters
for name, param in model.named_parameters():
    print('name: {} has requires_grad: {}'.format(name, param.requires_grad))
```
Output:

```
name: conv1.weight has requires_grad: True
name: bn1.weight has requires_grad: True
name: bn1.bias has requires_grad: True
name: layer1.0.conv1.weight has requires_grad: True
name: layer1.0.bn1.weight has requires_grad: True
...
name: layer4.2.bn2.weight has requires_grad: True
name: layer4.2.bn2.bias has requires_grad: True
name: fc.weight has requires_grad: True
name: fc.bias has requires_grad: True
```

All parameters start out trainable (requires_grad = True means the parameter is learnable).

Let's check the names of the model's top-level layers, so that we can freeze all but the last two of them.

```python
for name, module in model.named_children():
    print('name: ', name)
```
Output:

```
name:  conv1
name:  bn1
name:  relu
name:  maxpool
name:  layer1
name:  layer2
name:  layer3
name:  layer4
name:  avgpool
name:  fc
```

So we freeze all but a few of the network's layers. This ensures that only the unfrozen layers at the end of the network are fine-tuned, while everything else stays unchanged.

```python
# We would freeze all but the last few layers (layer4 and fc)
for name, param in model.named_parameters():
    if ('layer4' in name) or ('fc' in name):
        param.requires_grad = True
    else:
        param.requires_grad = False
```

Print requires_grad for all the parameters again to make sure the desired changes have taken effect.

```python
for name, param in model.named_parameters():
    print('name: {} has requires_grad: {}'.format(name, param.requires_grad))
```
Output:

```
name: conv1.weight has requires_grad: False
name: bn1.weight has requires_grad: False
name: bn1.bias has requires_grad: False
name: layer1.0.conv1.weight has requires_grad: False
...
name: layer3.5.bn2.bias has requires_grad: False
name: layer4.0.conv1.weight has requires_grad: True
name: layer4.0.bn1.weight has requires_grad: True
...
name: layer4.2.bn2.bias has requires_grad: True
name: fc.weight has requires_grad: True
name: fc.bias has requires_grad: True
```

And yes, the changes are in place! (Notice that the parameters in 'layer4' and 'fc' have requires_grad = True, while all the remaining parameters have requires_grad = False.)
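As an extra check (not in the original post), it is easy to tally how much of the network is actually still trainable after the freeze:

```python
# Count trainable vs. total parameters after freezing everything but layer4 and fc
total     = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('trainable parameters: {:,} of {:,} ({:.1f}%)'.format(trainable, total, 100.0 * trainable / total))
```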

Step 5: The Loss Function and the Optimizer

The data is ready to be fed into the model, and the model will return predictions. But how do we know whether those predictions are correct?

This is where the loss function comes into play!

Predictions and ground-truth labels can only be compared once each has a scalar representation, and the loss function provides exactly that scalar.

But how do we make sure the loss keeps falling as the model trains, so that the predictions get better with every iteration? That is the optimizer's job: it updates the model's weights in the direction that reduces the loss.

Cross-entropy loss is the standard loss function used the world over for multi-class classification problems, and the Adam optimizer is one of the most popular optimizer choices.
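To make the "scalar" point concrete, here is a tiny self-contained toy example with made-up numbers: cross-entropy takes the raw scores for one image and its true class index, and collapses them into a single comparable number.

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5, 0.1, 0.0, -1.0, 0.3]])  # made-up raw scores for 6 classes
target = torch.tensor([0])                                 # the true class index
print(loss_fn(logits, target))                             # a single scalar, approx. tensor(0.5541)
```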

```python
import torch.optim as optim
from torch.optim import lr_scheduler

# Now we define the Loss Function
loss_fn = nn.CrossEntropyLoss()

# Define the optimizer
lr = 0.00001
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)
```

Step 6: Training and Validation

With all the data processing done, a suitable model chosen, some layers frozen, and a loss function and optimizer selected, we are finally ready to start training the neural network.

Unleash the ResNet34 and let it soak up all it can from the training data!
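One caveat before running the next cell: the trainer() helper it calls is defined in the article's Kaggle kernel but not reproduced in this excerpt. A minimal sketch of what such a loop could look like, assuming only the call signature and the four returned loss lists used below, is:

```python
# A minimal sketch of trainer() -- the article's Kaggle kernel defines the real one.
# Assumes loader = {'train': ..., 'val': ...} and returns per-epoch and per-batch losses.
def trainer(loader, model, loss_fn, optimizer, epochs, log_interval):
    train_losses, val_losses = [], []
    batch_train_losses, batch_val_losses = [], []
    print('Training started...')
    for epoch in range(1, epochs + 1):
        print('epoch >>> {}/{}'.format(epoch, epochs))
        for phase in ['train', 'val']:
            model.train() if phase == 'train' else model.eval()
            running_loss = 0.0
            for batch_idx, (imgs, labels) in enumerate(loader[phase]):
                # Gradients are only needed during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    loss = loss_fn(model(imgs), labels)
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item()
                (batch_train_losses if phase == 'train' else batch_val_losses).append(loss.item())
                if batch_idx % log_interval == 1:  # log every log_interval batches
                    print('batch_loss at batch_idx {:02d}/{}: {}'.format(batch_idx, len(loader[phase]), loss.item()))
            epoch_loss = running_loss / len(loader[phase])
            (train_losses if phase == 'train' else val_losses).append(epoch_loss)
            print('>>> {} loss at epoch {}/{}: {}'.format(phase, epoch, epochs, epoch_loss))
    return train_losses, val_losses, batch_train_losses, batch_val_losses
```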

```python
loader = {'train': train_loader, 'val': val_loader}

epochs = 5
log_interval = 2

# Let's train the model for 5 epochs!
train_losses, val_losses, batch_train_losses, batch_val_losses = trainer(loader, model, loss_fn, optimizer, epochs = epochs, log_interval = log_interval)

# Plotting the epoch losses
plt.plot(train_losses)
plt.plot(val_losses)
plt.legend(['train losses', 'val_losses'])
plt.title('Loss vs Epoch')

plt.figure()
plt.plot(batch_train_losses)
plt.title('batch_train_losses')

plt.figure()
plt.plot(batch_val_losses)
plt.title('batch_val_losses')

# Saving the model (architecture and weights)
torch.save(model, 'stage1.pth')
```
Output:

```
Training started...
epoch >>> 1/5
___TRAINING___
batch_loss at batch_idx 01/16: 1.8830947875976562
batch_loss at batch_idx 03/16: 1.7403203248977661
batch_loss at batch_idx 05/16: 1.7298262119293213
batch_loss at batch_idx 07/16: 1.5611944198608398
/opt/conda/lib/python3.6/site-packages/PIL/TiffImagePlugin.py:780: UserWarning: Corrupt EXIF data. Expecting to read 4 bytes but only got 0. warnings.warn(str(msg))
batch_loss at batch_idx 09/16: 1.5026812553405762
batch_loss at batch_idx 11/16: 1.3258850574493408
batch_loss at batch_idx 13/16: 1.2495317459106445
batch_loss at batch_idx 15/16: 1.305968165397644
>>> train loss at epoch 1/5: 1.5462412377120007
___VALIDATION___
batch_loss at batch_idx 01/16: 1.1894911527633667
batch_loss at batch_idx 03/16: 1.152886152267456
>>> val loss at epoch 1/5: 1.1689875689078504
=========================
... (per-batch lines for epochs 2-5 omitted) ...
>>> train loss at epoch 2/5: 0.8703589484154471
>>> val loss at epoch 2/5: 0.7019397386415737
>>> train loss at epoch 3/5: 0.5295033304116471
>>> val loss at epoch 3/5: 0.4676538405455942
>>> train loss at epoch 4/5: 0.3578496524703361
>>> val loss at epoch 4/5: 0.3362177138722788
>>> train loss at epoch 5/5: 0.2632993612836001
>>> val loss at epoch 5/5: 0.258364795347837
=========================
```

Loss plots after training for 5 epochs.

Here is a trick worth sharing: after training the network for a few epochs, freeze every layer except the final linear layer and train a little longer.

You may be wondering why this step is needed.

Remember how the pretrained model's last linear layer was thrown away and replaced with a new layer whose neuron count equals the number of classes in the custom dataset? When that happened, the new layer's weights were randomly initialized, so once all the convolutional layers have been trained (they are needed to extract the various features from the input image), this final layer still needs proper training on its own.

```python
# We will now freeze 'layer4' and train just the 'fc' layer of the model for a few more epochs
for name, param in model.named_parameters():
    if 'layer4' in name:
        param.requires_grad = False # layer4 parameters would not get trained now

# Define the new learning rate and the new optimizer, which again contains only the parameters with requires_grad = True
lr = 0.0003
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)

epochs = 3
log_interval = 2

# Let's train the model for 3 more epochs!
train_losses, val_losses, batch_train_losses, batch_val_losses = trainer(loader, model, loss_fn, optimizer, epochs = epochs, log_interval = log_interval)

# Plotting the epoch losses
plt.plot(train_losses)
plt.plot(val_losses)
plt.legend(['train losses', 'val_losses'])
plt.title('Loss vs Epoch')

plt.figure()
plt.plot(batch_train_losses)
plt.title('batch_train_losses')

plt.figure()
plt.plot(batch_val_losses)
plt.title('batch_val_losses')

# Saving the model (architecture and weights)
torch.save(model, 'stage2.pth')
```
Output:

```
Training started...
epoch >>> 1/3
___TRAINING___
batch_loss at batch_idx 01/16: 0.20289725065231323
batch_loss at batch_idx 03/16: 0.2349197268486023
batch_loss at batch_idx 05/16: 0.2194989025592804
batch_loss at batch_idx 07/16: 0.20219461619853973
batch_loss at batch_idx 09/16: 0.27012479305267334
batch_loss at batch_idx 11/16: 0.20639048516750336
batch_loss at batch_idx 13/16: 0.1523684412240982
batch_loss at batch_idx 15/16: 0.14577656984329224
>>> train loss at epoch 1/3: 0.2009116342887577
___VALIDATION___
batch_loss at batch_idx 01/16: 0.20299889147281647
batch_loss at batch_idx 03/16: 0.19083364307880402
>>> val loss at epoch 1/3: 0.20429044950196124
=========================
... (per-batch lines for epochs 2-3 omitted) ...
>>> train loss at epoch 2/3: 0.14645167593310474
>>> val loss at epoch 2/3: 0.17527046447663797
>>> train loss at epoch 3/3: 0.13451774695174026
>>> val loss at epoch 3/3: 0.15460531577819914
=========================
```

Below are the loss plots after freezing layer4.

Loss plots after freezing layer4.

This trick works well in practice. There are standard ways of training networks, but no hard and fast rules, so feel free to experiment with the training schedule, and if you get good results following a different sequence of training steps, do share it in the comments section below.
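One concrete variation to experiment with (my suggestion, not from the article) is discriminative learning rates: instead of hard-freezing layers, give earlier layers much smaller learning rates via optimizer parameter groups.

```python
# Discriminative learning rates via parameter groups -- a variation to experiment with
optimizer = optim.Adam([
    {'params': model.layer3.parameters(), 'lr': 1e-5},  # earlier block: tiny updates
    {'params': model.layer4.parameters(), 'lr': 1e-4},  # later block: small updates
    {'params': model.fc.parameters(),     'lr': 3e-4},  # fresh head: largest updates
])
```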

Step 7: Test Time!

The neural network has now been trained on the custom mobile gallery dataset, and it should classify any given image into one of the six classes it was trained on.

All that is left is to read in a test image, apply the same preprocessing to it that was applied while training the network, and hope to see some nice predictions come back from the network.

```python
import cv2
import torch.nn.functional as F

# Making a 'predict' function which takes the 'model' and the path of the 'test image' as inputs,
# and predicts the class that the test image belongs to.
def predict(model, test_img_path):
    img = cv2.imread(test_img_path)

    # Visualizing the test image
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    img = transforms.Compose([transforms.ToPILImage()])(img)
    img = data_transforms(img)
    img = img.view(1, img.shape[0], img.shape[1], img.shape[2]) # Adding the batch dimension

    model.eval()
    with torch.no_grad():
        logits = model(img)

    probs = F.softmax(logits, dim = 1)
    max_prob, ind = torch.max(probs, 1)

    print('This Neural Network thinks that the given image belongs to >>> {} <<< class with confidence of {}%'.format(dataset.classes[ind], round(max_prob.item()*100, 2)))

test_data_dir = 'mobile_gallery_image_classification_data/mobile_gallery_image_classification/test'

test_img_list = []
for class_dir in glob.glob(test_data_dir + os.sep + '*'):
    test_img_list.append(class_dir)

# Loading the trained model (architecture as well as the weights) for making inferences
model = torch.load('stage2.pth')

# Select the test image index (choose a number from 0 to 6)
test_img_index = 3
predict(model, test_img_list[test_img_index])
```

Output:

```
This Neural Network thinks that the given image belongs to >>> Memes <<< class with confidence of 95.21%
```

The test image.

The neural network thinks the given image belongs to the Memes class, with 95.21% confidence.
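Since predict() is self-contained, scoring every image gathered from the test directory is just a loop (a fresh figure per call keeps the visualizations from overwriting each other):

```python
# Run the classifier over every collected test image
for test_img_path in test_img_list:
    plt.figure()
    predict(model, test_img_path)
```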

And that's it: you have just built a mobile gallery image classifier! This is only one idea; all kinds of creative applications can be built around an image classifier.

I strongly recommend forking the public Kaggle kernel and playing around with the code:

https://www.kaggle.com/n0obcoder/mobile-gallery-image-classification-using-pytorch

You can also download the GitHub repository containing the Jupyter notebook:

https://github.com/n0obcoder/Mobile-Gallery-Image-Classification-in-PyTorch
