【pytorch】手把手实现自注意力机制 |
您所在的位置:网站首页 › 注意力机制pytorch › 【pytorch】手把手实现自注意力机制 |
背景: 不仅在NLP领域,自注意力机制也在CV领域有着广泛的应用。所以,如何很好地实现自注意力机制成为比较关键的问题。下面我们来对于该机制进行简单实现。 先总结一下思路: 1. 我们的输入是一个(B,N,C)形状的矩阵,其中B代表Batch Size,N代表Time Step,C代表每个Time Step的维度。 2. 我们想做的是,根据输入得到多头的qkv。qkv分别代表query,key,value。我们想用query来查询key而得到一个关联度矩阵A。 3. 由于是多头注意力,我们得到了多个关联度矩阵,我们要将多个关联度矩阵合并为一个。 4. 最后的关联度矩阵和value矩阵相乘,等到最后的输出。 最后的代码如下: import torch,math import torch.nn as nn class MultiHead_SelfAttention(nn.Module): def __init__(self, dim, num_head): ''' Args: dim: dimension for each time step num_head:num head for multi-head self-attention ''' super().__init__() self.dim=dim self.num_head=num_head self.qkv=nn.Linear(dim, dim*3) # extend the dimension for later spliting def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_head, C//self.num_head).permute(2, 0, 3, 1, 4) q, k, v= qkv[0], qkv[1], qkv[2] att = [email protected](-1, -2)/ math.sqrt(C) att = att.softmax(dim=1) # 将多个注意力矩阵合并为一个 x = (att@v).transpose(1, 2) x=x.reshape(B, N, C) return x if __name__=='__main__': B = 10 N = 20 C = 32 num_head=8 x = torch.rand((B, N, C)) MHSA=Multihead_SelfAttention(C, num_head) print(MHSA(x).shape) |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |