MHA、MQA、GQA的区别
一、MHA定义
2、代码实现
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
self.wk = nn.Linear(embed_dim, embed_dim)
self.wv = nn.Linear(embed_dim, embed_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def split(self,hidden):
batch_size = hidden.shape[0]
x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return x
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 线性变换
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多头切分
q, k ,v = self.split(q), self.split(k), self.split(v)
# 注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 拼接多头
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 线性变换
output = self.wo(output)
return output
x = torch.ones(3, 6, 32)
atten = MultiHeadAttention(32, 8)
y = atten(x)
print(y.shape)
伪代码:
def attention(Q, K, V, num_heads):
"""
标准的多头注意力机制
Q, K, V: 上一层的输出
num_heads: 多头的数量
"""
batch_size, seq_len, d_model = Q.shape
d_k = d_model // num_heads # 每个头的维度
# 1. 新建KQV三个矩阵
Q = linear(Q).view(batch_size, -1, num_heads, d_k)
K = linear(K).view(batch_size, -1, num_heads, d_k)
V = linear(V).view(batch_size, -1, num_heads, d_k)
# 2. 分头处理
Q = Q.transpose(1, 2) # [batch_size, num_heads, seq_len, d_k]
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# 3. 计算注意力分数并归一化
scores = matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)
attention_weights = softmax(scores, dim=-1)
# 4. 查询
output = matmul(attention_weights, V)
# 5. 合并头并输出到下一层
output = output.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
output = linear(output)
return output
流程
1、线性层和分头:首先,对Q、K、V进行线性变换,然后将它们分为多个头。
2、缩放点积注意力:对于每个头,计算Q和K的点积,然后除以根号下的头的维度大小(为了控制梯度的大小),接着应用softmax函数来获取注意力权重。
3、应用注意力权重:使用得到的注意力权重和V进行加权求和,得到输出。
4、合并头:将所有头的输出合并回一个大的张量。
5、最后一个线性层:对合并后的输出应用另一个线性变换。
详细过程
在每次的计算注意过程中,一般是有下面的步骤
在多头注意力机制中的“独立处理”步骤,每个头实际上执行的是一个缩放点积注意力(Scaled Dot-Product Attention)操作。这个过程可以分为以下几个步骤:
1、计算点积:首先,对于每个头,计算查询(Q)和键(K)之间的点积。这是为了评估序列中每个位置与其他位置之间的相似度。
2、缩放:然后,将这些点积除以一个缩放因子,通常是sqrt(d_k),其中d_k是每个头的维度。这样做是为了防止点积的结果过大,导致在后续应用softmax时出现梯度消失的问题。
3、应用Softmax:接下来,对缩放后的点积结果应用softmax函数,将其转换为概率分布。这个概率分布表示了每个位置对序列中其他位置的关注程度。
4、加权和:最后,使用这个概率分布对值(V)进行加权求和。这样就得到了每个位置的加权表示,其中包含了序列中其他位置的信息。
5、输出:每个头的输出是一个包含了加权和信息的矩阵,其维度是[batch_size, seq_len, d_k]。
将以上步骤形式化地表示为伪代码:
def scaled_dot_product_attention(Q, K, V):
"""
计算缩放点积注意力。
Q: 查询矩阵
K: 键矩阵
V: 值矩阵
"""
d_k = Q.size(-1)
scores = matmul(Q, K.transpose(-2, -1)) / sqrt(d_k) # 计算点积并缩放
attention_weights = softmax(scores, dim=-1) # 应用softmax获取权重
output = matmul(attention_weights, V) # 加权求和
return output
QKV在计算过程中的维度变化
Q(查询),K(键)和V(值)在计算过程中的维度变化如下:
1、初始维度:Q、K、V的输入维度是 [batch_size, seq_len, embed_dim],其中embed_dim 是嵌入维度。
2、线性变换后:应用线性变换后,Q、K、V的维度仍然是 [batch_size, seq_len, embed_dim]。
3、分头处理:为了分头处理,Q、K、V被重塑为 [batch_size, seq_len, num_heads, head_dim],然后转置为 [batch_size, num_heads, seq_len, head_dim],其中 head_dim = embed_dim / num_heads。
4、注意力计算:在计算注意力时,Q和K的最后两个维度进行矩阵乘法,得到的分数维度是 [batch_size, num_heads, seq_len, seq_len],然后与V相乘后,输出维度是 [batch_size, num_heads, seq_len, head_dim]。
5、输出合并:在合并多头输出之前,首先将输出转置为 [batch_size, seq_len, num_heads, head_dim],然后重塑为 [batch_size, seq_len, embed_dim]。
6、最终输出:通过最后一个线性层后,输出维度仍然是 [batch_size, seq_len, embed_dim]。
二、MQA定义
2、代码实现
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
# 这是MHA的
#self.wk = nn.Linear(embed_dim, embed_dim)
#self.wv = nn.Linear(embed_dim, embed_dim)
# 这是MQA的
self.wk = nn.Linear(embed_dim, self.head_dim)
self.wv = nn.Linear(embed_dim, self.head_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def split(self, hidden, head_num=None):
batch_size,seq_len = hidden.size()[:2]
# 这是q需要拆分多头的
if head_num ==None:
x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
return x
else:
# 这是MQA: 需要拆分k和v,这里面的head_num =1 的
# 最终返回维度(batch_size, 1, seq_len, head_dim)
return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 线性变换
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多头切分
# 这是MHA的
#q, k ,v = self.split(q), self.split(k), self.split(v)
# 这是MQA的
q, k ,v = self.split(q), self.split(k, 1), self.split(v, 1)
# 注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
print("scores:",scores.shape)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 拼接多头
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 线性变换
output = self.wo(output)
return output
x = torch.ones(3, 15, 512)
atten = MultiHeadAttention(512, 8)
y = atten(x)
print(y.shape)
3、解释
于多头注意力相比,多查询注意力需要修改的地方有三处(代码上是两处):
(1).的在wk和wv的维度映射上有所不同
原因是: self.wq = nn.Linear(embed_dim, embed_dim) 在多头拆分的时候是有多个头,
而 self.wk = nn.Linear(embed_dim, self.head_dim)在多头拆分的时候只有一个头(需要共享的),然后通过matmul的广播机制进行复制,如果一样都是(embed_dim, embed_dim),在多头注意力拆分的时候就会出错。
(2).多头拆分的时候不一样
这是MHA:在拆分的时候q,k,v都是一样的,都拆成同样的头,后续的attention计算可以正常进行
return hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
这是MQA: 需要拆分k和v,这里面的head_num =1 的,q需要和MHA一样拆分,k,v需要新的拆分方式
return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)
(3).还有就是计算注意力分数采用的是广播机制
代码上MHA和MQA是一样的在attention计算的时候,q*k转置的时候都是用的matmul方法,
但是 MHA用的是相同维度的矩阵计算多头拆分之后k,v是(batch_size, num_heads, seq_len, head_dim)
但是 MQA用的是相同维度的矩阵计算多头拆分之后k,v是(batch_size, 1, seq_len, head_dim)
MHA在计算q*k转置的时候是矩阵计算,不需要广播
MHA在计算q*k转置的时候是先需要广播将k,v原始维度(batch_size, 1, seq_len, head_dim)扩张到(复制)(batch_size, num_heads, seq_len, head_dim),在进行矩阵计算,这两部操作,matmul可以通过广播机制实现。
三、GQA
2、代码实现
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
# 这是MHA的
#self.wk = nn.Linear(embed_dim, embed_dim)
#self.wv = nn.Linear(embed_dim, embed_dim)
# 这是MQA的
#self.wk = nn.Linear(embed_dim, self.head_dim)
#self.wv = nn.Linear(embed_dim, self.head_dim)
# 这是GQA的
self.group_num = 4 # 这是4个组
self.wk = nn.Linear(embed_dim, self.group_num*self.head_dim)
self.wv = nn.Linear(embed_dim, self.group_num*self.head_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def split(self, hidden, group_num=None):
batch_size,seq_len = hidden.size()[:2]
# 这是q需要拆分多头的
if group_num == None:
x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
return x
else:
# 这是kv需要拆分的多头
x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
return x
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 线性变换
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多头切分
# 这是MHA的
#q, k ,v = self.split(q), self.split(k), self.split(v)
# 这是MQA的
#q, k ,v = self.split(q), self.split(k, 1), self.split(v, 1)
# 这是GQA的
q, k ,v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)
# 注意力计算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
print("scores:",scores.shape)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 拼接多头
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 线性变换
output = self.wo(output)
return output
x = torch.ones(3, 15, 512)
atten = MultiHeadAttention(512, 8)
y = atten(x)
print(y.shape)
3、解释
MGA是GQA的一种特例,也就是分组为1,
于多头注意力相比和多查询注意力,分组查询注意力需要修改的地方有三处(代码上是两处):
(1).的在wk和wv的维度映射上有所不同
self.group_num = 4 # 这是4个组
self.wk = nn.Linear(embed_dim, self.group_num*self.head_dim)
self.wv = nn.Linear(embed_dim, self.group_num*self.head_dim)
(2).多头拆分的时候不一样
不同于MQA,MQA只有一个分组,所以k,v在拆分后维度是(batch_size, 1, seq_len, head_dim),因为有1,所有matmul可以自动广播,而GQA,是多个组,需要手动复制,
x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
此时x维度为(batch_size, group_num, seq_len, self.head_dim)
x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len, self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
return x
x[:, :, None, :, :]是增加了1维度空的,然后按照expand进行维度复制扩展,最后在进行维度转换
博主真是太厉害了!!!
不错不错,我喜欢看 https://www.jiwenlaw.com/