一、MHA定义

2、代码实现

import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.wq = nn.Linear(embed_dim, embed_dim)
        self.wk = nn.Linear(embed_dim, embed_dim)
        self.wv = nn.Linear(embed_dim, embed_dim)
        self.wo = nn.Linear(embed_dim, embed_dim)

    def split(self,hidden):    
        batch_size = hidden.shape[0]
        x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return x

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)
        
        # 线性变换
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
        
        # 多头切分
        q, k ,v  = self.split(q), self.split(k), self.split(v)
        
        # 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)
        
        # 拼接多头
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        
        # 线性变换
        output = self.wo(output)
        
        return output
 
x = torch.ones(3, 6, 32)
atten = MultiHeadAttention(32, 8)
y = atten(x)
print(y.shape)

伪代码:

def attention(Q, K, V, num_heads):
    """
    标准的多头注意力机制
 
    Q, K, V: 上一层的输出
    num_heads: 多头的数量
    """
 
    batch_size, seq_len, d_model = Q.shape
    d_k = d_model // num_heads  # 每个头的维度
 
    # 1. 新建KQV三个矩阵
    Q = linear(Q).view(batch_size, -1, num_heads, d_k)
    K = linear(K).view(batch_size, -1, num_heads, d_k)
    V = linear(V).view(batch_size, -1, num_heads, d_k)
 
    # 2. 分头处理
    Q = Q.transpose(1, 2)  # [batch_size, num_heads, seq_len, d_k]
    K = K.transpose(1, 2)
    V = V.transpose(1, 2)
 
    # 3. 计算注意力分数并归一化
    scores = matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)
    attention_weights = softmax(scores, dim=-1)
 
    # 4. 查询
    output = matmul(attention_weights, V)
 
    # 5. 合并头并输出到下一层
    output = output.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
    output = linear(output)
 
    return output

流程

1、线性层和分头:首先,对Q、K、V进行线性变换,然后将它们分为多个头。

2、缩放点积注意力:对于每个头,计算Q和K的点积,然后除以根号下的头的维度大小(为了控制梯度的大小),接着应用softmax函数来获取注意力权重。

3、应用注意力权重:使用得到的注意力权重和V进行加权求和,得到输出。

4、合并头:将所有头的输出合并回一个大的张量。

5、最后一个线性层:对合并后的输出应用另一个线性变换。

详细过程

在每次的计算注意过程中,一般是有下面的步骤

在多头注意力机制中的“独立处理”步骤,每个头实际上执行的是一个缩放点积注意力(Scaled Dot-Product Attention)操作。这个过程可以分为以下几个步骤:

1、计算点积:首先,对于每个头,计算查询(Q)和键(K)之间的点积。这是为了评估序列中每个位置与其他位置之间的相似度。

2、缩放:然后,将这些点积除以一个缩放因子,通常是sqrt(d_k),其中d_k是每个头的维度。这样做是为了防止点积的结果过大,导致在后续应用softmax时出现梯度消失的问题。

3、应用Softmax:接下来,对缩放后的点积结果应用softmax函数,将其转换为概率分布。这个概率分布表示了每个位置对序列中其他位置的关注程度。

4、加权和:最后,使用这个概率分布对值(V)进行加权求和。这样就得到了每个位置的加权表示,其中包含了序列中其他位置的信息。

5、输出:每个头的输出是一个包含了加权和信息的矩阵,其维度是[batch_size, seq_len, d_k]。

将以上步骤形式化地表示为伪代码:

def scaled_dot_product_attention(Q, K, V):
    """
    计算缩放点积注意力。
 
    Q: 查询矩阵
    K: 键矩阵
    V: 值矩阵
    """
    d_k = Q.size(-1)
    scores = matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)  # 计算点积并缩放
    attention_weights = softmax(scores, dim=-1)         # 应用softmax获取权重
    output = matmul(attention_weights, V)               # 加权求和
    return output

QKV在计算过程中的维度变化

Q(查询),K(键)和V(值)在计算过程中的维度变化如下:

1、初始维度:Q、K、V的输入维度是 [batch_size, seq_len, embed_dim],其中embed_dim 是嵌入维度。

2、线性变换后:应用线性变换后,Q、K、V的维度仍然是 [batch_size, seq_len, embed_dim]。

3、分头处理:为了分头处理,Q、K、V被重塑为 [batch_size, seq_len, num_heads, head_dim],然后转置为 [batch_size, num_heads, seq_len, head_dim],其中 head_dim = embed_dim / num_heads。

4、注意力计算:在计算注意力时,Q和K的最后两个维度进行矩阵乘法,得到的分数维度是 [batch_size, num_heads, seq_len, seq_len],然后与V相乘后,输出维度是 [batch_size, num_heads, seq_len, head_dim]。

5、输出合并:在合并多头输出之前,首先将输出转置为 [batch_size, seq_len, num_heads, head_dim],然后重塑为 [batch_size, seq_len, embed_dim]。

6、最终输出:通过最后一个线性层后,输出维度仍然是 [batch_size, seq_len, embed_dim]。

二、MQA定义

2、代码实现

import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.wq = nn.Linear(embed_dim, embed_dim)

        # 这是MHA的
        #self.wk = nn.Linear(embed_dim, embed_dim)
        #self.wv = nn.Linear(embed_dim, embed_dim)

        # 这是MQA的
        self.wk = nn.Linear(embed_dim, self.head_dim)
        self.wv = nn.Linear(embed_dim, self.head_dim)

        self.wo = nn.Linear(embed_dim, embed_dim)

    def split(self, hidden, head_num=None):    
        batch_size,seq_len = hidden.size()[:2]
        # 这是q需要拆分多头的
        if head_num ==None:
            x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            return x
        else:
            # 这是MQA: 需要拆分k和v,这里面的head_num =1 的
            # 最终返回维度(batch_size, 1, seq_len, head_dim)
            return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)
            

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)
        
        # 线性变换
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
        
        # 多头切分
        # 这是MHA的
        #q, k ,v  = self.split(q), self.split(k), self.split(v)
        # 这是MQA的
        q, k ,v  = self.split(q), self.split(k, 1), self.split(v, 1)

        # 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        print("scores:",scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 拼接多头
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        
        # 线性变换
        output = self.wo(output)
        
        return output
 
x = torch.ones(3, 15, 512)
atten = MultiHeadAttention(512, 8)
y = atten(x)
print(y.shape)

3、解释

于多头注意力相比,多查询注意力需要修改的地方有三处(代码上是两处):

(1).的在wk和wv的维度映射上有所不同

原因是: self.wq = nn.Linear(embed_dim, embed_dim) 在多头拆分的时候是有多个头,
而 self.wk = nn.Linear(embed_dim, self.head_dim)在多头拆分的时候只有一个头(需要共享的),然后通过matmul的广播机制进行复制,如果一样都是(embed_dim, embed_dim),在多头注意力拆分的时候就会出错。

(2).多头拆分的时候不一样

这是MHA:在拆分的时候q,k,v都是一样的,都拆成同样的头,后续的attention计算可以正常进行
return hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

这是MQA: 需要拆分k和v,这里面的head_num =1 的,q需要和MHA一样拆分,k,v需要新的拆分方式
return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)

(3).还有就是计算注意力分数采用的是广播机制

代码上MHA和MQA是一样的在attention计算的时候,q*k转置的时候都是用的matmul方法,
但是 MHA用的是相同维度的矩阵计算多头拆分之后k,v是(batch_size, num_heads, seq_len, head_dim)
但是 MQA用的是相同维度的矩阵计算多头拆分之后k,v是(batch_size, 1, seq_len, head_dim)

MHA在计算q*k转置的时候是矩阵计算,不需要广播

MHA在计算q*k转置的时候是先需要广播将k,v原始维度(batch_size, 1, seq_len, head_dim)扩张到(复制)(batch_size, num_heads, seq_len, head_dim),在进行矩阵计算,这两部操作,matmul可以通过广播机制实现。

三、GQA

2、代码实现

import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.wq = nn.Linear(embed_dim, embed_dim)

        # 这是MHA的
        #self.wk = nn.Linear(embed_dim, embed_dim)
        #self.wv = nn.Linear(embed_dim, embed_dim)

        # 这是MQA的
        #self.wk = nn.Linear(embed_dim, self.head_dim)
        #self.wv = nn.Linear(embed_dim, self.head_dim)

        # 这是GQA的
        self.group_num = 4  # 这是4个组
        self.wk = nn.Linear(embed_dim, self.group_num*self.head_dim)
        self.wv = nn.Linear(embed_dim, self.group_num*self.head_dim)

        self.wo = nn.Linear(embed_dim, embed_dim)

    def split(self, hidden, group_num=None):    
        batch_size,seq_len = hidden.size()[:2]
        # 这是q需要拆分多头的 
        if group_num == None:
            x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            return x
        else:
            # 这是kv需要拆分的多头
            x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
            return x

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)
        
        # 线性变换
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
        
        # 多头切分
        # 这是MHA的
        #q, k ,v  = self.split(q), self.split(k), self.split(v)
        # 这是MQA的
        #q, k ,v  = self.split(q), self.split(k, 1), self.split(v, 1)
        # 这是GQA的
        q, k ,v  = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)
        
        # 注意力计算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        print("scores:",scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 拼接多头
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        
        # 线性变换
        output = self.wo(output)
        
        return output
 
x = torch.ones(3, 15, 512)
atten = MultiHeadAttention(512, 8)
y = atten(x)
print(y.shape)

3、解释

MGA是GQA的一种特例,也就是分组为1,
于多头注意力相比和多查询注意力,分组查询注意力需要修改的地方有三处(代码上是两处):

(1).的在wk和wv的维度映射上有所不同

self.group_num = 4 # 这是4个组
self.wk = nn.Linear(embed_dim, self.group_num*self.head_dim)
self.wv = nn.Linear(embed_dim, self.group_num*self.head_dim)

(2).多头拆分的时候不一样

不同于MQA,MQA只有一个分组,所以k,v在拆分后维度是(batch_size, 1, seq_len, head_dim),因为有1,所有matmul可以自动广播,而GQA,是多个组,需要手动复制,

x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)

此时x维度为(batch_size, group_num, seq_len, self.head_dim)

x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
return x
x[:, :, None, :, :]是增加了1维度空的,然后按照expand进行维度复制扩展,最后在进行维度转换

参考:
1.https://zhuanlan.zhihu.com/p/659238103

标签: none

已有 2 条评论

  1. 博主真是太厉害了!!!

  2. 不错不错,我喜欢看 https://www.jiwenlaw.com/

添加新评论