对于数据科学或机器学习研究者而言,当解决任何机器学习问题时,可能面临的最大问题之一就是训练数据不平衡的问题。本文将尝试使用图像分类问题来揭示训练数据中不平衡类别的奥秘。
在一个分类问题中,当你想要预测一个或多个类中的样本数量极少时,可能会遇到数据中类不平衡的问题,即部分类的样本数量远远大于其它类中的样本数量。
不平衡课程造成问题主要是由于以下两个原因:
解决这个问题的方法主要有三种,三种各有各自的优缺点:
在本节中,将分析一个图像分类问题(其中存在不平衡类问题),然后使用一种简单有效的技术来解决它。 问题:在kaggle上选择了“驼背鲸识别挑战”任务,期望解决不平衡类别的挑战(理想情况下,所分类的鲸鱼数量少于未分类的鲸类)。 Kagele上任务说明:在这场比赛中,面临的挑战是要建立一个算法来识别图像中的鲸鱼种类。将分析Happy Whale数据库(包含25,000多张图像),这些数据来自研究机构和公共贡献者。通过竞赛,你将有助于为全球海洋哺乳动物种群动态开启丰富的理解领域。
由于这是一个多标签图像分类问题,首先想要检查数据是如何在类中分布的。
上图表明,在4251张训练图像中,每个类只有一张图像的超过了2000张。还有一些类只有2~5张图像。可见这是一个严重的不平衡类问题。我们不能期望深度学习模型每个类别仅使用一张图像进行训练。这也会产生一个问题,即如何在训练和验证样本之间创建一个分界线,理想情况下希望每个类都在训练样本和验证样本中都有表示。 接下来应该做什么? 本文考虑了两个特别的选项:
从图像中可以看到,图像是特定于鲸鱼的尾巴,因此,识别将可能与图像的方向有关。同时注意到数据中有很多图像是特定的黑白或只有R/G/B通道。 根据这些观察结果,使用以下代码对训练样本中不平衡类的图像进行小幅改动并保存:
import osfrom PIL import Imagefrom PIL import ImageFilter
filelist = train['Image'].loc[(train['cnt_freq']<10)].tolist()for count in range(0,2):
for imagefile in filelist:
os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/train')
im=Image.open(imagefile)
im=im.convert("RGB")
r,g,b=im.split()
r=r.convert("RGB")
g=g.convert("RGB")
b=b.convert("RGB")
im_blur=im.filter(ImageFilter.GaussianBlur)
im_unsharp=im.filter(ImageFilter.UnsharpMask)
os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/copy')
r.save(str(count)+'r_'+imagefile)
g.save(str(count)+'g_'+imagefile)
b.save(str(count)+'b_'+imagefile)
im_blur.save(str(count)+'bl_'+imagefile)
im_unsharp.save(str(count)+'un_'+imagefile)
以上代码对不平衡类中的每张图像(频率小于10)都进行如下处理:
图像增强:只想确保模型能够获得鲸鱼fluke的详细视图。为此,将缩放合并成图像增强。
学习率设定:从图中可以看到,将学习率定为0.01时效果最好。
使用Resnet50模型(第一层参数不变)进行了很少的迭代训练就能取得很好的效果,这是由于imagenet数据库中也有鲸鱼图像。
epoch trn_loss val_loss accuracy
0 1.827677 0.492113 0.895976
1 0.93804 0.188566 0.964128
2 0.844708 0.175866 0.967555
3 0.571255 0.126632 0.977614
4 0.458565 0.116253 0.979991
5 0.410907 0.113607 0.980544
6 0.42319 0.109893 0.981097
在kaggle排行榜上可以看到模型在测试集上的效果,本文提出的解决方案在本次比赛中排名34,平均精度均值(MAP)为0.41928。
有时候,最简单的方法是最合乎逻辑的(如果你没有更多的数据,只需要复制现有的数据,并有轻微的变化即可),也是最有效的。