前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用度量学习进行特征嵌入:交叉熵和监督对比损失的效果对比

使用度量学习进行特征嵌入:交叉熵和监督对比损失的效果对比

作者头像
deephub
发布2021-02-12 16:34:19
1.5K0
发布2021-02-12 16:34:19
举报
文章被收录于专栏:DeepHub IMBA

分类是机器学习中最简单,最常见的任务之一。例如,在计算机视觉中,您希望能够微调普通卷积神经网络(CNN)的最后一层,以将样本正确分类为某些类别(类)。但是,有几种根本不同的方法可以实现这一目标。

Metric learning(度量学习)是其中之一,今天我想与大家分享如何正确使用它。为了使事情变得实用,我们将研究监督式对比学习(SupCon),它是对比学习的一部分,而后者又是度量学习的一部分,但稍后会介绍更多。

通常如何进行分类

在进行度量学习之前,首先了解通常如何解决分类任务。卷积神经网络是当今实用计算机视觉最重要的思想之一,它由两部分组成:编码器和头部(在这种情况下为分类器)。

首先-拍摄图像并计算一组特征,这些特征可以捕获该图像的重要信息。这是通过卷积和池化操作完成的(这就是为什么它被称为卷积神经网络)。之后,将这些特征解压缩到单个向量中,并使用常规的全连接神经网络执行分类。在实践中,您采用在大型数据集(例如ImageNet)上预先训练的某种模型(例如ResNet,DenseNet,EfficientNet等),并根据您的任务(仅最后一层或整个模型)进行微调)。

然而,这里有几点需要注意。首先,通常只关心网络FC部分的输出。也就是说,你取它的输出,并把它们提供给损失函数,以保持模型学习。换句话说,您并不真正关心网络中间发生了什么(例如,来自编码器的特性)。其次,通常你用一些基本的损失函数来训练这些东西,比如交叉熵。

为了更好地理解这个2步过程(encoder + FC),你可以这样想:encoder将图像映射到一些高维空间(例如,在ResNet18的情况下,我们讨论的是512维,而对于Resnet101 - 2048)。在此之后,FC的目标是在这些代表样本的点之间画一条线,以便将它们映射到类。这两种东西是同时训练的。因此,你试图优化特征,同时“在高维空间中画线”。

这种方法有什么问题吗?嗯,没什么,真的。它实际上运行得很好。但这并不意味着没有别的办法。

度量学习 Metric learning

现代机器学习中最有趣的想法之一(至少对我来说是这样)叫做度量学习(或深度度量学习)。简单地说:如果我们不去关注FC层的输出,而是更仔细地研究编码器生成的特性会怎样?如果我们设法用一些损耗函数来优化这些特性,而不是使用网络输出进行优化,会怎么样呢?这就是度量学习的意义所在:用编码器生成好的特性(嵌入)。

“好”是什么意思呢?好吧,如果你想一下,在计算机视觉的例子中,你想对相似的图像有相似的特征,而对截然不同的图像有截然不同的特征。

监督对比学习 Supervised Contrastive Learning

好的,假设在度量学习中,我们关心的只是“好”特征。但是监督式对比学习有什么意义呢?老实说,这种特定方法没有什么特别之处。这是最近的一篇论文,提出了一些不错的技巧,以及一个有趣的2步方法

  1. 训练一个好的编码器,该编码器能够为图像生成良好的特征。
  2. 冻结编码器,添加FC层,然后进行训练。

您可能想知道常规分类器训练有什么区别。不同之处在于,在常规培训中,您需要同时训练编码器和FC。另一方面,在这里,您首先训练一个不错的编码器,然后将其冻结(不再训练),然后仅训练FC。这种逻辑背后的想法是,如果我们设法首先为图像生成真正好的特征,则应该很容易优化FC(正如我们前面提到的,其目标是优化分离样本的行)。

训练过程的细节

让我们深入了解SupCon实施的细节。

在查看训练循环之前,您应该了解的一件事是要训练哪种模型。这非常简单:编码器(例如ResNet,DenseNet,EffNet等),但没有常规的FC层进行分类。

这里不是分类头,而是投影头。投影头是一个由2个FC层组成的序列,它将编码器的特征映射到一个较低的维度空间(通常是128维度,你甚至可以在上面的图片中看到这个值)。使用投影头的原因是,与来自编码器的几千个特征相比,使用128个精心选择的特征更容易让模型学习。

  1. 构造一批N个图像。与其他度量学习方法不同,您不需要太关心这些样本的选择。能拿多少就拿多少,剩下的由损失来处理。
  2. 将这些图像以成对的方式转发给网络,其中一对图像被构造为[augmentation(image_i), augmentation(image_i)],得到embeddings。并进行标准化。
  3. 以某个图像做为锚点。在批处理中找到同一个类的所有图像。把它们作为正样本。找到所有不同类的图像。把他们当作负样本。
  4. 将SupCon损失应用于第二步归一化嵌入,使正样本彼此靠近,同时使负样本更远离。
  5. 第一阶段训练完成后,删除投影头,并在编码器顶部添加FC(就像在常规分类训练中一样)。开始第二阶段训练的冻结编码器,并微调FC的训练。

这里要记住几件事。首先,在训练完成后,去掉投影头,使用投影头之前的特征是会获得更好的效果。作者解释说,由于我们降低了嵌入的大小,导致信息丢失。其次,增强的选择很重要。作者提出了裁剪和色彩抖动的组合。Supcon一次处理批处理中的所有图像(因此,无需构造对或三元组)。而且批处理中的图像越多,模型学习起来就越容易(因为SupCon具有隐式的正负硬挖掘质量)。第四,你可以在第4步停止。这意味着可以通过嵌入来进行分类,而不需要任何FC层。为了做到这一点,计算所有训练样本的嵌入。然后,在验证时,对每个样本计算一个嵌入,将其与每个训练嵌入进行比较(例如余弦距离),采用其类别。

PyTorch实现

实际上,在PyTorch中有一个SupCon的半官方实现。不幸的是,它包含了非常恼人的隐藏bug。最严重的一个问题是:repo的创造者使用了他自己的resnet实现,由于其中的一些bug,批量大小比普通的torchvision模型低两倍。最重要的是,repo没有验证或可视化,所以你不知道什么时候停止训练。在我的repo中,我修复了所有这些问题,并为稳定的训练增加了更多的技巧。

更准确地说,在我的实现包含了以下功能:

  • 使用albumentations进行扩增
  • Yaml配置
  • t-SNE可视化
  • 使用AMI、NMI、mAP、precision_at_1等PyTorch度量学习进行2步验证(用于投影头前后的特性)。
  • 指数移动平均更稳定的训练,随机移动平均更好的泛化和整体性能。
  • 自动混合精度训练,以便能够训练更大的批大小(大约是2的倍数)。
  • 标签平滑损失,LRFinder为第二阶段的训练(FC)。
  • 支持timm模型和jettify优化器
  • 固定种子,使训练具有确定性。
  • 保存基于验证的权重,日志-定期。txt文件,以及TensorBoard日志。

例子是使用Cifar10和Cifar100数据集来进行测试的,但是添加自己的数据集非常简单。为了运行整个数据处理管道,请执行以下操作:

代码语言:javascript
复制
 python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage1.yml
 python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage1.yml
 python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml
 python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage2.yml

之后,你可以检查可视化t-SNE结果。例如,对于Cifar10和Cifar100,大概是下面这样:

Cifar10 t-SNE, SupCon 损失

Cifar10 t-SNE, Cross Entropy 损失

Cifar100 t-SNE, SupCon 损失

Cifar10 t-SNE, Cross Entropy 损失

总结

度量学习是一个非常强大的东西。但是要达到常规CE / LabelSmoothing可以提供的准确性水平非常困难。此外,在训练期间它在计算上也可能是昂贵的并且不稳定的。我在各种任务(分类,超出分布的预测,对新类的泛化等)上测试了SupCon和其他度量指标损失,使用诸如SupCon之类的优势尚不确定。

那有什么意义?我个人认为有两件事。第一,SupCon(和其他度量学习方法)仍然可以提供比CE更结构化的集群,因为它直接优化了该属性。第二,多一个你可以尝试的技能/工具仍然是非常有益的。因此,通过更好的扩展集或不同的数据集(可能使用更细粒度的类),SupCon 可能会产生更好的结果,而不仅仅是与常规分类训练相当。

本文代码:https://github.com/ivanpanshin/SupCon-Framework

作者:Ivan Panshin

原文地址:https://towardsdatascience.com/how-to-use-metric-learning-embedding-is-all-you-need-f26e01597375

deephub翻译组

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-02-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DeepHub IMBA 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 通常如何进行分类
  • 度量学习 Metric learning
  • 监督对比学习 Supervised Contrastive Learning
  • 训练过程的细节
  • PyTorch实现
  • 总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档