前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >矩阵乘积 MatMul 的反向传播

矩阵乘积 MatMul 的反向传播

原创
作者头像
王白石
修改2024-10-03 18:46:24
修改2024-10-03 18:46:24
2250
举报

有公式 \mathbf{y} = \mathbf{x}W ,其中 \mathbf{x} 是 D * M 矩阵,W 是 M * N 权重矩阵;另有损失函数 L 是对 \mathbf{y} 的函数,假设 Ly 的偏导已知(反向传播时是这样的),求 L 关于矩阵 \mathbf{x} 的偏导

答案见下式,非常简洁;求一个标量对于矩阵的偏导,这个问题一度困惑了我很长一段时间;在学微积分的时候,求的一直都是 y 对标量 x 的导数或者偏导(多个自变量),对矩阵的偏导该如何算,不知啊;看了普林斯顿的微积分读本,托马斯微积分也看了,都没提到

\frac{\partial L}{\partial \mathbf{x}}=\frac{\partial L}{\partial \mathbf{y}}W^T

这里的关键在于如何理解 \frac{\partial L}{\partial \mathbf{x}} ,其实就是一种记法,也就是分别计算 Lx 中所有项的偏导,然后写成矩阵形式;为了表述方便,我们令上式右边为 A , 那么对于 \mathbf{x} 中的第 ij 项(第 i 行第 j 列), 则必有\frac{\partial L}{\partial x_{ij}} = A_{ij} ,我们只要能证明这一点就可以了

根据链式法则(可参考附录), 要计算 \frac{\partial L}{\partial x_{ij}} ,我们先计算 Ly 的偏导(已知项),然后乘以 yx 的偏导;注意并不需要考虑 y 中的所有项,因为按照矩阵乘法定义,x_{ij} 只参与了 yi(y_{i1}, y_{i2},...y_{in}) 的计算,其中 y_{ik} = \sum\limits_{l=1}^Mx_{il}W_{lk}

\begin{split} \frac{\partial L}{\partial x_{ij}}&=\sum_{k=1}^N\frac{\partial L}{\partial y_{ik}}\frac{\partial y_{ik}}{\partial x_{ij}}\\ &=\sum_{k=1}^N\frac{\partial L}{\partial y_{ik}}W_{jk} \text{$\qquad (\frac{\partial y_{ik}}{\partial x_{ij}}=W_{jk})$}\\ &=\sum_{k=1}^N\frac{\partial L}{\partial y_{ik}}W^T_{kj} \text { $\qquad(W_{jk}=W^T_{kj}$)} \end{split}

也就是 Lx_{ij} 的偏导等于 Lyi 行的偏导(可视为向量)与 W^Tj 列(向量)的点积,根据矩阵乘法定义(矩阵 AB的第 ij 项等于A的第 i 行与 B 的第 j 列的点积),可得上述答案

现在我们来计算 L 关于权重矩阵 W 的偏导

同样按照链式法则,我们先计算 Ly 的偏导(已知项),然后乘以 yw 的偏导;按照矩阵乘法 w_{ij} 参与了 yj 列所有项的计算,其中 y_{kj} = \sum\limits_{l=1}^Mx_{kl}W_{lj}

\begin{split} \frac{\partial L}{\partial w_{ij}}&=\sum_{k=1}^D\frac{\partial L}{\partial y_{kj}}\frac{\partial y_{kj}}{\partial w_{ij}}\\ &=\sum_{k=1}^D\frac{\partial L}{\partial y_{kj}}x_{ki} \text{$\qquad (\frac{\partial y_{kj}}{\partial w_{ij}}=x_{ki})$}\\ &=\sum_{k=1}^Dx^T_{ik}\frac{\partial L}{\partial y_{kj}} \end{split}

也就是 LW_{ij} 的偏导等于 x^Ti 行与Lyj 列项的偏导的点积,按照矩阵乘法定义可得

\frac{\partial L}{\partial W} = x^T\frac{\partial L}{\partial y}

附录:

链式法则 如果函数 w = f(x, y) 有连续的偏导数 f_xf_y 并且 x = x(t) , y = y(t) 可微,那么有

\frac{dw}{dt}=\frac{\partial f}{\partial x}\frac{dx}{dt}+\frac{\partial f}{\partial y}\frac{dy}{dt}

参考 托马斯微积分第 11 版,14.4 节 链式法则 Chain Rule

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档