Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >多任务学习——【ICML 2018】GradNorm

多任务学习——【ICML 2018】GradNorm

作者头像
小爷毛毛_卓寿杰
发布于 2021-09-10 07:02:34
发布于 2021-09-10 07:02:34
1.9K00
代码可运行
举报
文章被收录于专栏:Soul Joy HubSoul Joy Hub
运行总次数:0
代码可运行

论文链接:https://arxiv.org/pdf/1711.02257.pdf

之前讲过了多任务学习,如简单的shared bottom,都存在一个问题:多个任务的loss如何融合?简单的方式,就是将多个任务的loss直接相加:

但实际情况是,不同任务loss梯度的量级不同,造成有的task在梯度反向传播中占主导地位,模型过分学习该任务而忽视其它任务。此外,不同任务收敛速度不一致的,可能导致有些任务还处于欠拟合,可有些任务已经过拟合了。当然,我们可以人工的设置超参数,如:

由于各任务在训练过程中自己的梯度量级和收敛速度也是动态变化的,所以很显然这样定值的w做并没有很好的解决问题。作者提出了一种可以动态调整loss的w的算法——GradNorm

从上图可知,GradNorm 是以平衡的梯度作为目标,优化Grad Loss,从而动态调整各个任务的w。

那下面我们就来看看Grad Loss是怎么样的:

要注意的是,上式中,减号后面的项,是基于当轮各任务的梯度所计算出来的常量。其中:

G调节着梯度的量级。r调节着任务收敛速度:收敛速度越快,ri​就越小,从而 Gw(i)​(t)应该被优化的变小。

算法步骤如下:

实现如下(引自 GitHub):

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
            # switch for each weighting algorithm:
            # --> grad norm
            if args.mode == 'grad_norm':
                
                # get layer of shared weights
                W = model.get_last_shared_layer()

                # get the gradient norms for each of the tasks
                # G^{(i)}_w(t) 
                norms = []
                for i in range(len(task_loss)):
                    # get the gradient of this task loss with respect to the shared parameters
                    gygw = torch.autograd.grad(task_loss[i], W.parameters(), retain_graph=True)
                    # compute the norm
                    norms.append(torch.norm(torch.mul(model.weights[i], gygw[0])))
                norms = torch.stack(norms)
                #print('G_w(t): {}'.format(norms))


                # compute the inverse training rate r_i(t) 
                # \curl{L}_i 
                if torch.cuda.is_available():
                    loss_ratio = task_loss.data.cpu().numpy() / initial_task_loss
                else:
                    loss_ratio = task_loss.data.numpy() / initial_task_loss
                # r_i(t)
                inverse_train_rate = loss_ratio / np.mean(loss_ratio)
                #print('r_i(t): {}'.format(inverse_train_rate))


                # compute the mean norm \tilde{G}_w(t) 
                if torch.cuda.is_available():
                    mean_norm = np.mean(norms.data.cpu().numpy())
                else:
                    mean_norm = np.mean(norms.data.numpy())
                #print('tilde G_w(t): {}'.format(mean_norm))


                # compute the GradNorm loss 
                # this term has to remain constant
                constant_term = torch.tensor(mean_norm * (inverse_train_rate ** args.alpha), requires_grad=False)
                if torch.cuda.is_available():
                    constant_term = constant_term.cuda()
                #print('Constant term: {}'.format(constant_term))
                # this is the GradNorm loss itself
                grad_norm_loss = torch.tensor(torch.sum(torch.abs(norms - constant_term)))
                #print('GradNorm loss {}'.format(grad_norm_loss))

                # compute the gradient for the weights
                model.weights.grad = torch.autograd.grad(grad_norm_loss, model.weights)[0]
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/07/25 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
【动手实现】Metabalance缓解多任务梯度不平衡
此次我们一起来复现Meta AI(Facebook)提出的MetaBalance方法,该方法主要用于平衡多任务模型中不同任务的梯度。在多任务学习中,不同的任务构建的损失函数在梯度回传时其大小可能存在较大差异,MetaBalance对不同的梯度进行相应的缩放使得不同任务的梯度相近。复现的代码如下,需要注意的是,这部分的梯度缩放只针对共享层,对于任务独立的tower部分不影响。
秋枫学习笔记
2022/09/19
5270
pytorch 学习笔记(一)
pytorch是一个动态的建图的工具。不像Tensorflow那样,先建图,然后通过feed和run重复执行建好的图。相对来说,pytorch具有更好的灵活性。 编写一个深度网络需要关注的地方是:
ke1th
2018/01/02
1.7K0
01.神经网络和深度学习 W2.神经网络基础(作业:逻辑回归 图片识别)
文章目录 编程题 1 1. numpy 基本函数 1.1 编写 sigmoid 函数 1.2 编写 sigmoid 函数的导数 1.3 reshape操作 1.4 标准化 1.5 广播机制 2.
Michael阿明
2021/02/19
4460
《自然语言处理实战入门》第二章:NLP 前置技术(深度学习) ---- pytorch
An open source machine learning framework that accelerates the path from research prototyping to production deployment.
流川疯
2021/12/06
4960
YOLOv8优化策略:Adam该换了!斯坦福最新Sophia优化器,比Adam快2倍 | 2023.5月斯坦福最新成果
斯坦福2023.5月发表的最新研究成果,他们提出了「一种叫Sophia的优化器,相比Adam,它在LLM上能够快2倍,可以大幅降低训练成本」。
AI小怪兽
2023/11/03
2.2K0
机器学习|从0开发大模型之模型预训练
继续写《从0开发大模型》系列文章,本文主要介绍预训练过程。 预训练是目的是让模型学习知识,需要将预处理的数据(《机器学习|从0开发大模型之数据预处理》)中生成的 pretrain_data.bin 文件的上下文全部学习到,那预训练怎么做呢?
用户1904552
2025/02/27
2010
机器学习|从0开发大模型之模型预训练
DGL & RDKit | 基于Attentive FP可视化训练模型原子权重
DGL开发人员提供了用于可视化训练模型原子权重的代码。使用Attentive FP构建模型后,可以可视化给定分子的原子权重,意味着每个原子对目标值的贡献量。
DrugAI
2021/02/01
1.2K0
人工智能创新挑战赛:助力精准气象和海洋预测Baseline3:TCNN+RNN模型、SA-ConvLSTM模型
本次任务我们将学习来自TOP选手“swg-lhl”的冠军建模方案,该方案中采用的模型是TCNN+RNN。
汀丶人工智能
2023/06/06
8350
人工智能创新挑战赛:助力精准气象和海洋预测Baseline3:TCNN+RNN模型、SA-ConvLSTM模型
NumPyML 源码解析(三)
The losses.py module implements several common loss functions, including:
ApacheCN_飞龙
2024/02/17
2440
大厂技术实现 | 多目标优化及应用(含代码实现)@推荐与计算广告系列
推荐,搜索,计算广告是互联网公司最普及最容易商业变现的方向,也是算法发挥作用最大的一些方向,前沿算法的突破和应用可以极大程度驱动业务增长,这个系列咱们就聊聊这些业务方向的技术和企业实践。本期主题为多目标学习优化落地(附『实现代码』和『微信数据集』)
ShowMeAI
2021/10/21
2.2K1
大厂技术实现 | 多目标优化及应用(含代码实现)@推荐与计算广告系列
从零开始学习 PyTorch:多层全连接神经网络
本文引自博文视点新书《深度学习入门之PyTorch》第3 章——多层全连接神经网络 内容提要:深度学习如今已经成为科技领域最炙手可热的技术,在《深度学习入门之PyTorch》中,我们将帮助你入门深度学习。《深度学习入门之PyTorch》将从机器学习和深度学习的基础理论入手,从零开始学习 PyTorch,了解 PyTorch 基础,以及如何用 PyTorch 框架搭建模型。通过阅读《深度学习入门之PyTorch》,你将学到机器学习中的线性回归和 Logistic 回归、深度学习的优化方法、多层全连接神经
AI科技大本营
2018/04/27
5.6K0
从零开始学习 PyTorch:多层全连接神经网络
【RL Base】强化学习:信赖域策略优化(TRPO)算法
在强化学习(RL)领域,如何稳定地优化策略是一个核心挑战。2015 年,由 John Schulman 等人提出的信赖域策略优化(Trust Region Policy Optimization, TRPO)算法为这一问题提供了优雅的解决方案。TRPO 通过限制策略更新的幅度,避免了策略更新过大导致的不稳定问题,是强化学习中经典的策略优化方法之一。
不去幼儿园
2024/12/03
3210
【RL Base】强化学习:信赖域策略优化(TRPO)算法
生成对抗网络(GAN)系列:WGAN与金融时序(附代码)
过拟合是我们试图将机器学习技术应用于时间序列时遇到的问题之一。出现这个问题是因为我们使用我们所知道的唯一时间序列路径来训练我们的模型:已实现的历史。
量化投资与机器学习微信公众号
2020/06/29
4.3K1
生成对抗网络(GAN)系列:WGAN与金融时序(附代码)
PyTorch中神经网络的对抗性攻击和防御
深度学习和神经网络的兴起为现代社会带来了各种机会和应用,例如对象检测和文本转语音。然而,尽管看似准确性很高,但神经网络(以及几乎所有机器学习模型)实际上都可能受到数据(即对抗性示例)的困扰,而这些数据是从原始训练样本中进行的非常轻微的操纵。实际上,过去的研究表明,只要您知道更改数据的“正确”方法,就可以迫使您的网络在数据上表现不佳,而这些数据在肉眼看来似乎并没有什么不同!这些对数据进行有意操纵以降低模型精度的方法称为对抗性攻击,而攻击与防御之战是机器学习领域中持续流行的研究主题。
代码医生工作室
2020/09/14
2.2K0
点赞收藏:PyTorch常用代码段整理合集
PyTorch 将被安装在 anaconda3/lib/python3.7/site-packages/torch/目录下。
机器之心
2019/05/10
1.7K0
PyTorch常用代码段合集
本文是PyTorch常用代码段合集,涵盖基本配置、张量处理、模型定义与操作、数据处理、模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常全面。
小白学视觉
2022/02/14
1.2K0
PyTorch常用代码段合集
深度学习实战:2.AlexNet实现花图像分类
AlexNet在我之前的博客中已经做过详解,详情见:https://blog.csdn.net/muye_IT/article/details/123602605?spm=1001.2014.3001
AiCharm
2023/05/15
9710
深度学习实战:2.AlexNet实现花图像分类
PyTorch深度学习模型训练加速指南2021
简要介绍在PyTorch中加速深度学习模型训练的一些最小改动、影响最大的方法。我既喜欢效率又喜欢ML,所以我想我也可以把它写下来。
AI算法与图像处理
2021/01/20
1.4K0
PyTorch深度学习模型训练加速指南2021
深度学习
x_norm = np.linalg.norm(x,ord=2,axis=1,keepdims=True)
freesan44
2021/10/12
4700
PyTorch 2.2 中文官方教程(三)
介绍 || 张量 || 自动微分 || 构建模型 || TensorBoard 支持 || 训练模型 || 模型理解
ApacheCN_飞龙
2024/02/05
4610
PyTorch 2.2 中文官方教程(三)
推荐阅读
相关推荐
【动手实现】Metabalance缓解多任务梯度不平衡
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验