随着深度学习的发展,越来越多的应用使用到深度学习技术,而在有些应用中,我们可能无法获取到足够多的训练数据,这时就需要使用一些半监督或无监督方法来完成我们到目标。Google Research从2019-2020期间提出了多个半监督学习方法,一次次打破了各个任务的半监督SOTA,本文就来浅析Google Research于2020年发表的FixMatch方法。
"FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence" by Kihyuk Sohn, David Berthelot, Chun-Liang Li, Zizhao Zhang, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Han Zhang, and Colin Raffel.
(附上基于Tensorflow和Pytorch版本到实现,方便习惯使用不同框架到同学学习)
https://github.com/google-research/fixmatch
https://github.com/kekmodel/FixMatch-pytorch
假设我们要识别一只宝可梦,但是我们的标注数据有限,但是有很多未标注的宝可梦图像。
通常我们只使用标记数据进行训练,但是如果数据过少模型的效果往往不够理想。
而FixMatch的核心思想就是利用未标注的数据来参与训练,虽然这些数据没有被标注,但是直觉上讲,如果我们对这些数据进行一定对扰动,模型应该能够输出同一种结果,来监督模型的训练过程,如此一来可以大大提高模型的泛化能力。
FixMatch算法伪代码如下。
FixMatch 借鉴了 UDA 和 ReMixMatch 的思想,应用了多种数据增强方法,对未标注数据,先使用弱增强方法通过模型生成伪标签,再通过强增强方法处理未标注数据进行预测。弱增强方法包括翻转和水平移动。强增强方法使用RandAugment以及CTAugment,最后使用CutOut进行增强。
特征提取网络使用了Wide Residual Networks结构,具体为Wide ResNet-28-2 with 1.5M parameters,在很多任务中其能有效的提取图像的特征。
为了得到该模型的最优超参数,Google工程师凭借其超强的算力对超参做了大量的消融实验,比如学习率,衰减率、学习率衰减函数、标签样本与无标签样本比例、动量、优化器选择、伪标签中用的阈值,包括sharpen中的τ。
论文最后给出了FixMatch在CIFAR数据集上的实验结果,一个字,强!
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。