前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >炼丹终结者出现 | 单卡3.29s可训练精度94%的Backbone,仅仅眨眼功夫,离大谱

炼丹终结者出现 | 单卡3.29s可训练精度94%的Backbone,仅仅眨眼功夫,离大谱

作者头像
集智书童公众号
发布2024-04-12 20:05:14
2250
发布2024-04-12 20:05:14
举报
文章被收录于专栏:集智书童

CIFAR-10是机器学习中使用最广泛的数据库之一,每年支持数千个研究项目。为了加速研究并降低实验成本,我们引入了针对CIFAR-10的训练方法,这些方法在单个NVIDIA A100 GPU上运行时,能够在3.29秒内达到94%的准确率,10.4秒内达到95%,46.3秒内达到96%。作为这些训练速度的一个因素,我们提出了水平翻转增强的一种去随机化变体,我们显示,在所有翻转比不翻转更有利的情况下,这种方法都优于标准方法。 代码发布在https://github.com/KellerJordan/cifar10-airbench

1 Introduction

CIFAR-10(Krizhevsky等人,2009年)是机器学习中最受欢迎的数据集之一,每年支持数千个研究项目。如果能够提高在CIFAR-10上训练神经网络的速率,那么可以加快研究进度并降低实验成本。在本文中,我们介绍了一种训练方法,在单个NVIDIA A100 GPU上仅需3.29秒就能达到94%的准确率,这比之前的最佳水平(tysam-code,2023年)提高了1.9倍。为了支持需要更高性能的场景,我们另外开发了针对95%和96%准确率的方法。

我们总共发布了以下方法:

  1. airbench94_compiled.py:3.29秒内达到94.01%的准确率
3.6\times 10^{14}

FLOPs)。

  1. airbench94.py:3.83秒内达到94.01%的准确率(
3.6\times 10^{14}

FLOPs)。

  1. airbench95.py:10.4秒内达到95.01%的准确率(
1.4\times 10^{15}

FLOPs)。

  1. airbench96.py:46.3秒内达到96.05%的准确率(
7.2\times 10^{15}

FLOPs)。

所有运行时间都是在单个NVIDIA A100上测量的。我们注意到前两个脚本在数学上是等价的(即,产生相同的训练网络分布),区别仅在于第一个脚本使用torch.compile来提高GPU利用率。它适用于一次训练许多网络的实验,以便分摊一次性的编译成本。未编译的airbench94变体可以通过以下命令轻松安装和运行。

开发这些训练方法的一个动机是它们可以加快在兼容的CIFAR-10相关研究项目中研究人员的实验迭代时间。另一个动机是它们可以降低涉及大量训练网络的项目成本。这类项目的一个例子是Ilyas等人(2022年)关于数据归属的研究,该研究使用了300万个训练网络来证明,给定测试输入的经过训练的神经网络的输出近似为训练模型时二进制选择向量的线性函数。另一个例子是Jordan(2023年)关于训练方差的研究,该研究使用了18万个训练网络来显示,标准训练在测试分布上的性能方差很小。这些研究基于分别在34个A100秒和72个A100秒内达到93%和94.4%的训练。我们在这篇论文中介绍的训练方法使得可以在较少的计算资源下复制这些研究,或进行类似的研究。

快速训练还使得可以为微妙的超参数比较快速积累统计显著性。例如,如果将某个超参数微妙地改变,使得与基线相比,平均CIFAR-10准确率提高了0.02%,那么(假设运行之间典型的0.14%标准差)平均需要

N=133

次训练来在

p=0.05

的统计显著性水平下确认改进。对于一个标准的5分钟ResNet-18训练,这将需要11.1 GPU小时;而airbench94将这个时间缩短到了更方便的7.3分钟。

我们的工作建立在之前的训练速度项目之上。我们使用了一个修改过的网络、初始化和优化器,它们来自tysam-code(2023年),以及Page的优化技巧和冻结的斑块白化层。我们相对于之前工作的最后约10%的速度提升来自于对标准水平翻转增强的新颖改进(图1,第3.6节,第5.2节)。

3 Methods

3.1 研究设计

本研究设计为前瞻性、平行组、随机对照试验(RCT)。该试验遵循赫尔辛基宣言,并得到了相应机构伦理审查委员会(IRB)的批准。所有参与者在入组前均提供了书面知情同意书。符合条件的参与者为成年人(年龄≥18岁),体重指数(BMI)≥25 kg/m²,且被诊断为2型糖尿病(T2DM),且单独饮食和运动未能产生效果。排除标准包括有过心肌梗死、中风或其他严重心血管疾病的病史,以及对研究药物有任何禁忌症。

Network architecture and baseline training

我们训练了一个带有1.97百万个参数的卷积网络,遵循tysam-code(2023)并进行了一些小改动。它包含七个卷积层,后六个被分为三个包含两个的块。精确的网络架构在附录A中以简单的PyTorch代码给出;在本节中,我们将对主要的设计选择进行一些评论。

这个网络在某种程度上类似于VGG,因为其主体完全由3x3卷积和2x2最大池化层组成,并伴随BatchNorm层和激活函数。遵循tysam-code(2023),第一个层是一个没有填充的2x2卷积,使得内部特征图的形状为31x31

\rightarrow

15x15

\rightarrow

7x7

\rightarrow

3x3,而不是更典型的32x32

\rightarrow

16x16

\rightarrow

8x8

\rightarrow

4x4,这导致在吞吐量和性能之间的权衡稍微更有利。我们使用GELU激活函数。

遵循Page(2019);tysam-code(2023),我们禁用了卷积层和线性层的偏置,并禁用了BatchNorm层的仿射尺度参数。最终线性层的输出通过一个常数因子1/9进行缩放。与tysam-code(2023)相比,我们的网络架构唯一的不同在于我们将第三个块的输出通道数从512减少到256,并且我们在第一个卷积中添加了可学习的偏置。

作为我们的基线,我们使用Nesterov SGD以批次大小1048进行训练,标签平滑率为0.2。我们使用一个三角学习率计划,该计划从最大速率的0.2倍开始,在训练的20%时达到最大值,然后逐渐减少到零。对于数据增强,我们使用随机水平翻转以及2像素的随机平移。对于平移,我们使用反射填充(Zagoruyko和Komodakis,2016),我们发现这比零填充更好。请注意,我们所说的2像素随机平移等同于用2像素填充然后取一个随机的32x32裁剪。在评估期间,我们使用水平翻转的测试时间增强,即网络在给定的测试图像及其镜像上运行,并根据两个输出的平均值进行推断。在学习率、动量和权重衰减的优化选择下,这个基线训练配置在45个周期内达到了94%的平均准确度,耗时18.3 A100秒。

Frozen patch-whitening initialization

遵循Page(2019);tysam-code(2023)的做法,我们将第一个卷积层初始化为斑块白化变换。该层是一个2x2卷积,具有24个通道。按照tysam-code(2023)的做法,前12个滤波器被初始化为训练分布中2x2斑块协方差矩阵的特征向量,使得它们的输出具有单位协方差矩阵。接下来的12个滤波器被初始化为前12个的相反数,这样可以在接下来的激活中保留输入信息。图2展示了这一结果。在训练过程中,我们不更新这一层的权重。

与tysam-code(2023)的做法不同,我们为这一层添加了可学习的偏置,从而带来了一定的性能提升。这些偏置训练3个周期,之后我们禁用它们的梯度以增加反向传递的吞吐量,这在不降低准确性的情况下提高了训练速度。我们还通过在计算斑块白化初始化时减少加到特征值上的常数,相对于tysam-code(2023)获得了轻微的性能提升,目的是防止在斑块协方差矩阵奇异的情况下出现数值问题。

斑块白化初始化是影响最大的单一特征。将其添加到基线上,训练速度提高了一倍多,我们仅用21个周期就达到了8.0 A100秒内的94%准确率。### 身份初始化

dirac:我们将第一个卷积之后的所有卷积初始化为部分身份变换。也就是说,对于具有

M

个输入通道和

N \geq M

个输出的卷积,我们将其前

M

个滤波器初始化为输入的身份变换,而将剩余的

N-M

个滤波器保持默认初始化。在PyTorch代码中,这相当于在每个卷积层的权重w上运行torch.nn.init.dirac_(w[:w.size(1)])。这种方法部分遵循了tysam-code(2023),后者使用了一个更复杂的方案,其中身份权重与原始初始化混合,但我们发现这并没有带来更好的性能。添加这一特性后,训练在18个周期内达到了6.8 A100秒内的94%准确率。

Optimization tricks

比例偏置:我们按照Page(2019)和tysam-code(2023)的方法,将所有BatchNorm层的学习偏置增加了一个因子

64\times

。加入这个特性后,训练在13.5个周期内达到了94%,耗时5.1个A100秒。

前瞻:我们遵循tysam-code(2023)的方法,使用了前瞻(Lookahead)优化。我们注意到,前瞻在之前关于ResNet-18训练速度的研究(Moreau等人,2022)中也已被发现是有效的。加入这个特性后,训练在12.0个周期内达到了94%,耗时4.6个A100秒。

Multi-crop evaluation

多裁剪评估:为了生成预测,我们对每个测试图像的六个增强视图运行训练好的网络:未修改的输入、向上并向左平移一个像素的版本、向下并向右平移一个像素的版本,以及这三个版本的水平镜像。预测是通过所有六个输出的加权平均值来生成的,其中未平移图像的两个视图各加权0.25,其余四个视图各加权0.125。添加此功能后,训练在10.8个周期内达到94%,耗时4.2个A100秒。

我们注意到,多裁剪推理是ImageNet (Deng et al., 2009) 训练中的经典方法(Simonyan and Zisserman, 2014; Szegedy et al., 2014),其中性能会随着评估裁剪数量的增加而提高,甚至可以达到144个裁剪(Szegedy et al., 2014)。在我们的实验中,使用更多的裁剪确实可以提高性能,但推理时间的增加超过了潜在的训练加速。

Alternating flip

为了加快训练速度,我们提出了标准水平翻转增强的一种去随机化变体,我们的动机如下。在训练神经网络时,标准的做法是将训练组织成一系列的周期,在这些周期内每个训练样本恰好被看到一次。这与随机梯度下降(SGD)的教科书定义不同,后者要求从训练集中重复地、有放回地采样数据,导致在训练的短时间内,样本可能会被多次重复看到。

使用随机排序的周期数据进行训练在优化文献中有一个不同的名称,被称为随机重排方法。如果我们的训练数据集包含

N

个独特的样本,那么有放回地采样数据会导致每个包含

N

个采样样本的“周期”平均只包含

(1-(1-1/N)^{N})N\approx(1-1/e)N\approx 0.632N

个独特样本。另一方面,随机重排导致每个周期都能看到所有的

N

个独特样本。鉴于随机重排经验上的成功(表1),

我们认为在训练时间窗口内最大化看到的独特输入的数量是有益的。基于这种推理,我们设计了一种新的水平翻转增强变体,如下所述。我们首先指出,标准随机水平翻转增强可以定义如下。

如果水平翻转是唯一使用的增强方法,那么在训练期间可能看到的确切独特输入有

2N

个。理论上,每对连续周期都可以包含每一个独特输入。但我们主要的观察是,在使用标准随机水平翻转时,一半的图像在两个周期中都会以相同的方式重复翻转,因此平均只能看到

1.5N

个独特输入。altflip:为了解决这个问题,我们提出以下方式修改标准随机水平翻转增强。在第一个周期中,我们像往常一样随机翻转50%的输入。然后在周期

\{2,4,6,\dots\}

中,我们只翻转第一个周期中没有翻转的输入,在周期

\{3,5,7,\dots\}

中,我们只翻转第一个周期中翻转的输入。我们提供了一个实现,它通过使用伪随机函数来决定翻转,从而避免了额外内存的需求。

结果是,每对连续周期都包含所有的

2N

个独特输入,如图1所示。我们在第5.2节中展示了这种方法在各种场景中的有效性。添加这个特性使我们能够将训练缩短到最终的9.9个周期,得到我们最终的训练方法airbench94.py,其完整内容可以在E节中找到。它在NVIDIA A100上用3.83秒达到了94%的准确率。

Compilation

我们采取的最后一步来加速训练是非算法性的:我们使用 torch.compile 来编译我们的训练方法,以便更高效地利用 GPU。这产生了一个训练脚本,它在数学上等价(考虑到浮点运算中小差异)于未编译版本,同时显著加快速度:训练时间减少了14%,降至3.29 A100秒。缺点是,在训练开始之前,一次性的编译过程可能需要几分钟才能完成,因此只有在我们计划一次性执行许多训练运行时才是有益的。我们发布了这个版本,名为 airbench94_compiled.py。

4 95% and 96% targets

为了应对那些需要稍高性能的场景,我们另外开发了针对95%和96%准确率的方法。这两种方法都是对airbench94的直接修改。

为了达到95%的准确率,我们将训练周期从9.9增加到15,并将第一个块的输出通道数从64增加到128,将后两个块的输出通道数从256增加到384。我们将学习率降低到0.87倍。这些修改产生了airbench95,它在10.4 A100秒内达到了95.01%的准确率,消耗了

1.4\times 10^{15}

FLOPs。

为了达到96%的准确率,我们添加了12像素的Cutout增强,并将训练周期提高到40。我们在每个块中添加了第三个卷积,并将第一个块扩展到128个通道,后两个块扩展到512个通道。我们还添加了跨越每个块的后面两个卷积的残差连接,我们发现尽管我们已经使用身份初始化(第3.3节)来简化梯度流动,但这仍然是有益的。最后,我们将学习率降低到0.78倍。这些更改产生了airbench96,它在46.3 A100秒内达到了96.05%的准确率,消耗了

7.2\times 10^{15}

FLOPs。图3展示了我们三种训练方法的FLOPs和错误率。

5 Experiments

Interaction between features

为了更好地理解每个特征对训练速度的影响,我们比较了两个数量。首先,我们测量了将特征添加到白化基准线上(第3.2节)可以节省的纪元数。其次,我们测量了从最终的airbench94中移除该特征需要增加的纪元数(第3.6节)。例如,将身份初始化(第3.3节)添加到白化基准线上,将94%的纪元数从21减少到18,而从最终的airbench94中移除它,则将94%的纪元数从9.9增加到12.8。

图4显示了每个特征的这两个数量。令人惊讶的是,我们发现除了多裁剪测试时间增强(multi-crop TTA)之外,所有特征在这两种情况下的纪元变化是相似的,尽管白化基准线所需的纪元数是最终配置的两倍多。这表明大多数特征之间的相互作用是累加的,而不是乘法的。

Does alternating flip generalize?

在本节中,我们研究了交替翻转(第3.6节)在CIFAR-10和ImageNet的各种训练配置中的有效性。我们发现,在所有情况下,除了交替翻转和随机翻转都不比完全不翻转更好的情况外,它都能提高训练速度。

对于CIFAR-10,我们考虑交替翻转在以下24种训练配置中给出的性能提升:airbench94、带有额外Cutout增强的airbench94以及airbench96,每个配置的训练轮次在

\{10,20,40,80\}

范围内,以及TTA(第3.5节)在

\{\text{是},\text{否}\}

之间。对于每种配置,我们比较交替翻转和随机翻转在

n=400

次训练运行中的平均准确率。

图5显示了结果(原始数据见表6)。从随机翻转切换到交替翻转在每个设置中都能提高性能。为了了解改进的程度,我们估计了每个案例的有效加速,即通过从随机翻转切换到交替翻转,在保持随机翻转准确率水平的同时可以节省的轮次比例。我们从每个基于随机翻转的训练配置的轮次到误差曲线上拟合幂律曲线,形式为

\mathrm{error}=c+b\cdot\mathrm{epochs}^{a}

。我们使用这些曲线来计算从随机翻转切换到交替翻转所提供的有效加速。例如,使用随机翻转且不带TTA的airbench94在运行20轮时达到6.26%的误差,在运行40轮时达到5.99%。同样的配置使用交替翻转在运行20轮时达到6.13%,幂律拟合预测这需要使用随机翻转运行25.3轮才能达到。所以我们报告了27%的加速。需要注意的是,使用幂律相比于使用观察到的轮次与误差数据点之间的线性插值来预测,会给出一个更保守的估计,后者会预测52%的加速。

表2显示了结果。我们观察到以下模式。首先,额外增强(Cutout)的加入在一定程度上缩小了随机翻转和交替翻转之间的差距。为了解释这一点,我们注意到交替翻转的主要效果是消除了图像连续多轮以相同方式重复翻转的情况;我们推测,增加额外的增强减少了这些情况带来的负面影响,因为它增加了数据多样性。接下来,TTA缩小了随机翻转和交替翻转之间的差距。它还缩小了随机翻转与完全不翻转之间的差距(表6),这表明TTA只是降低了翻转增强的重要性。最后,训练时间越长,交替翻转所提供的效果加速越明显。

我们接下来通过以下实验研究ImageNet训练:我们用各种训练和测试裁剪方法训练了一个ResNet-18模型,比较了三种翻转选项:交替翻转、随机翻转以及完全不翻转。我们考虑了两种测试裁剪:256x256中心裁剪,裁剪比例为0.875,以及192x192中心裁剪,裁剪比例为1.0。我们用CC(256, 0.875)来表示前者,用CC(192, 1.0)来表示后者。我们还考虑了两种训练裁剪:192x192 inception风格的随机调整大小裁剪(Szegedy等人,2014年),其长宽比介于0.75到1.33之间,覆盖的面积从图像的8%到100%,以及一种较为温和的随机裁剪,它首先将图像的短边调整到192像素,然后选择一个192x192的随机正方形裁剪。我们用Heavy RRC来表示前者,用Light RRC来表示后者。完整的训练细节在附录C中提供。

表3报告了每种情况的平均top-1验证准确度。我们首先注意到,当网络用CC(256, 0.875)裁剪进行评估时,Heavy RRC效果更好,而使用CC(192, 1.0)时,Light RRC略好。考虑到标准的训练-测试分辨率差异理论,这是相当不令人惊讶的。

对于使用Light RRC的训练,我们发现从随机翻转切换到交替翻转可以显著提升性能,训练速度提高了25%以上。在图5中,我们可视化了使用Light RRC进行短期训练时的改进,切换到交替翻转比将训练时长从16个周期增加到20个周期更能提高性能。当关闭水平翻转TTA时,提升更为明显,这与我们在CIFAR-10上的结果一致。另一方面,使用Heavy RRC的训练从交替翻转中没有看到显著的好处。实际上,即使完全关闭翻转,也不会显著降低这些训练的性能。我们得出结论,在后者比完全不翻转有所改进的每个训练场景中,交替翻转都优于随机翻转。### 方差和类别的校准

前面的部分主要关注了影响准确度第一矩(平均值)的因素。在本节中,我们研究了第二矩,发现TTA以牺牲校准为代价降低了方差。

我们的实验是执行10,000次airbench94训练,使用几种超参数设置。对于每个设置,我们报告了测试集准确度的方差以及分布方差的估计(Jordan,2023年)。图6显示了原始准确度分布。

表4显示了结果。每个案例的分布方差至少是测试集方差的1/5,复制了Jordan(2023年)的主要发现。这是一个令人惊讶的结果,因为这些训练最多只有20个周期,而Jordan(2023年)研究的更标准的训练在类似时长内分布方差是其5倍,只有在运行64个周期后才达到低方差。从这次比较中,我们得出结论,分布方差与训练的收敛速度有更强烈的关联,而不是其时长本身。我们还注意到,airbench94的低分布方差表明其训练稳定性很高。

使用TTA(测试时间增强)显著降低了测试集的方差,以至于所有三种使用TTA的设置的测试集方差都低于任何不使用TTA的设置。然而,测试集方差与类别的校准性质有关,因此从反证法的角度,我们假设这种测试集方差的降低必然以类别校准的代价为代价。为了测试这一假设,我们计算了每个设置的类别聚合校准误差(CACE),该误差衡量了与类别校准的偏差。表4展示了结果。每个使用TTA的设置的CACE都高于每个不使用TTA的设置,证实了这一假设。

附录B 额外的数据集实验

我们开发airbench的单一目标是最大限度地提高CIFAR-10的训练速度。为了找出这是否导致它对CIFAR-10“过度拟合”,在本节中,我们评估其在CIFAR-100、SVHN和CINIC-10上的性能。

在CIFAR-10上,airbench96在不使用Cutout和都使用Cutout的情况下,其准确度与标准ResNet-18训练相当(表5)。因此,如果我们评估airbench96在其他任务上的表现,并发现它的准确度不如ResNet-18,那么我们可以说airbench96必须过度拟合了CIFAR-10;否则,我们可以说它具有泛化能力。

我们将airbench96与文献中能找到的每个任务上ResNet-18的最佳准确度进行了比较。我们没有对airbench96的超参数进行任何调整:我们使用的是在CIFAR-10上最优的值。表5显示了结果。事实证明,在所有情况下,airbench96的性能都优于ResNet-18训练。特别是在CIFAR-100上的结果令人印象深刻,airbench96在使用和不使用Cutout的情况下,其准确度比ResNet-18训练高出1.7%。我们得出结论,airbench并没有过度拟合CIFAR-10,因为它在其他任务上显示出强大的泛化能力。

我们注意到,这种airbench96与ResNet-18训练的比较在一方面是公平的,因为它确实表明前者具有良好的泛化能力;但在另一方面又是不公平的,因为它并不表明airbench96本身就是更优秀的训练方法。特别是,airbench96使用测试时增强,而标准ResNet-18训练则不使用。如果使用测试时增强,ResNet-18训练可能优于airbench96。然而,它完成训练所需的时间也要长5-10倍。选择使用哪一种可能取决于具体情况。

我们报告的ResNet-18训练的准确度值来自以下来源。我们试图为每个设置选择尽可能高的值。Moreau等人(2022年)报道在没有Cutout的情况下在CIFAR-10上达到95.55%,在SVHN上达到97.35%。DeVries & Taylor(2017年)报道在CIFAR-10上使用Cutout达到96.01%,在CIFAR-100上不使用Cutout达到77.54%,使用Cutout达到78.04%。Darlow等人(2018年)报道在CINIC-10上不使用Cutout达到87.58%。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-04-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 集智书童 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 Introduction
  • 3 Methods
    • 3.1 研究设计
    • Network architecture and baseline training
    • Frozen patch-whitening initialization
    • Optimization tricks
    • Multi-crop evaluation
    • Alternating flip
    • Compilation
    • 4 95% and 96% targets
    • 5 Experiments
    • Interaction between features
    • Does alternating flip generalize?
      • 附录B 额外的数据集实验
      相关产品与服务
      腾讯云服务器利旧
      云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档