前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >einsum is all you needed

einsum is all you needed

作者头像
lyhue1991
发布2023-02-23 13:22:36
1.9K0
发布2023-02-23 13:22:36
举报

如果问pytorch中最强大的一个数学函数是什么?

我会说是torch.einsum:爱因斯坦求和函数。

它几乎是一个"万能函数":能实现超过一万种功能的函数。

不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。

einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。

尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。

套用一句深度学习paper标题当中非常时髦的话术,torch.einsum is all you needed 😋!

公众号后台回复关键词:einsum,获取本文源代码链接。

一,einsum规则原理

顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦~。

很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。

比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。

在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。

有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的

\sum

求和符号呢?

小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。

例如在我们熟悉的矩阵乘法中

C_{ij} = \sum_{k} A_{ik} B_{kj}

k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。

这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。

小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的

\sum

求和符号省略得了。

这就是爱因斯坦求和约定:

只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略。

公式立刻清爽了很多。

C_{ij} = A_{ik} B_{kj}

这个公式表达的含义如下:

C这个张量的第i行第j列由

A

这个张量的第i行第k列和

B

这个张量的第k行第j列相乘,这样得到的是一个三维张量

D

, 其元素为

D_{ikj}

,然后对

D

在维度k上求和得到。

公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。

借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。

上述矩阵乘法可以被einsum这个函数表述成

代码语言:javascript
复制
C = torch.einsum("ik,kj->ij",A,B)

这个函数的规则原理非常简洁,3句话说明白。

  • 1,用元素计算公式来表达张量运算。
  • 2,只出现在元素计算公式箭头左边的指标叫做哑指标。
  • 3,省略元素计算公式中对哑指标的求和符号。
代码语言:javascript
复制
import torch 

A = torch.tensor([[1,2],[3,4.0]])
B = torch.tensor([[5,6],[7,8.0]])

C1 = A@B
print(C1)

C2 = torch.einsum("ik,kj->ij",[A,B])
print(C2)
代码语言:javascript
复制
tensor([[19., 22.],
        [43., 50.]])
tensor([[19., 22.],
        [43., 50.]])

二,einsum基础范例

einsum这个函数的精髓实际上是第一条:

用元素计算公式来表达张量运算。

而绝大部分张量运算都可以用元素计算公式很方便地来表达,这也是它为什么会那么神通广大。

例1,张量转置

代码语言:javascript
复制
#例1,张量转置
A = torch.randn(3,4,5)

#B = torch.permute(A,[0,2,1])
B = torch.einsum("ijk->ikj",A) 

print("before:",A.shape)
print("after:",B.shape)
代码语言:javascript
复制
before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])

例2,取对角元

代码语言:javascript
复制
#例2,取对角元
A = torch.randn(5,5)
#B = torch.diagonal(A)
B = torch.einsum("ii->i",A)
print("before:",A.shape)
print("after:",B.shape)
代码语言:javascript
复制
before: torch.Size([5, 5])
after: torch.Size([5])

例3,求和降维

代码语言:javascript
复制
#例3,求和降维
A = torch.randn(4,5)
#B = torch.sum(A,1)
B = torch.einsum("ij->i",A)
print("before:",A.shape)
print("after:",B.shape)
代码语言:javascript
复制
before: torch.Size([4, 5])
after: torch.Size([4])

例4,哈达玛积

代码语言:javascript
复制
#例4,哈达玛积
A = torch.randn(5,5)
B = torch.randn(5,5)
#C=A*B
C = torch.einsum("ij,ij->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
代码语言:javascript
复制
before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])

例5,向量内积

代码语言:javascript
复制
#例5,向量内积
A = torch.randn(10)
B = torch.randn(10)
#C=torch.dot(A,B)
C = torch.einsum("i,i->",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
代码语言:javascript
复制
before: torch.Size([10]) torch.Size([10])
after: torch.Size([])

例6,向量外积

代码语言:javascript
复制
#例6,向量外积
A = torch.randn(10)
B = torch.randn(5)
#C = torch.outer(A,B)
C = torch.einsum("i,j->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
代码语言:javascript
复制
before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])

例7,矩阵乘法

代码语言:javascript
复制
#例7,矩阵乘法
A = torch.randn(5,4)
B = torch.randn(4,6)
#C = torch.matmul(A,B)
C = torch.einsum("ik,kj->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
代码语言:javascript
复制
before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])

例8,张量缩并

代码语言:javascript
复制
#例8,张量缩并
A = torch.randn(3,4,5)
B = torch.randn(4,3,6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum("ijk,jih->kh",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
代码语言:javascript
复制
before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])

三,einsum高级范例

einsum可用于超过两个张量的计算。

例9,bilinear注意力机制

例如:双线性变换。这是向量内积的一种扩展,一种常用的注意力机制实现方式)

不考虑batch维度时,双线性变换的公式如下:

A = qWk^T

考虑batch维度时,无法用矩阵乘法表示,可以用元素计算公式表达如下:

A_{ij} = \sum_{k}\sum_{l}Q_{ik}W_{jkl}K_{il} = Q_{ik}W_{jkl}K_{il}
代码语言:javascript
复制
#例9,bilinear注意力机制

#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(10) #key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features

#a = q@W@k.t()+b  
a = torch.bilinear(q,k,W,b)
print("a.shape:",a.shape)


#=====考虑batch维度====
Q = torch.randn(8,10)    #batch_size,query_features
K = torch.randn(8,10)    #batch_size,key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5)       #out_features

#A = torch.bilinear(Q,K,W,b)
A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)

代码语言:javascript
复制
a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])

例10,scaled-dot-product注意力机制

我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.

不考虑batch维度时,scaled-dot-product形式的Attention用矩阵乘法公式表示如下:

a = softmax(\frac{q k^T}{d_k})

考虑batch维度时,无法用矩阵乘法表示,可以用元素计算公式表达如下:

A_{ij} = softmax(\frac{Q_{in}K_{ijn}}{d_k})
代码语言:javascript
复制
#例10,scaled-dot-product注意力机制

#====不考虑batch维度====
q = torch.randn(10)  #query_features
k = torch.randn(6,10) #key_size, key_features

d_k = k.shape[-1]
a = torch.softmax(q@k.t()/d_k,-1) 

print("a.shape=",a.shape )

#====考虑batch维度====
Q = torch.randn(8,10)  #batch_size,query_features
K = torch.randn(8,6,10) #batch_size,key_size,key_features

d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij",Q,K)/d_k,-1) 

print("A.shape=",A.shape )

代码语言:javascript
复制
a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])

以上。

万水千山总是情,点个在看行不行?😋

公众号后台回复关键词:einsum,获取本文源代码链接。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-07-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法美食屋 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一,einsum规则原理
  • 二,einsum基础范例
    • 例1,张量转置
      • 例2,取对角元
        • 例3,求和降维
          • 例4,哈达玛积
            • 例5,向量内积
              • 例6,向量外积
                • 例7,矩阵乘法
                  • 例8,张量缩并
                  • 三,einsum高级范例
                    • 例9,bilinear注意力机制
                      • 例10,scaled-dot-product注意力机制
                      相关产品与服务
                      批量计算
                      批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
                      领券
                      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档