掘金 人工智能 05月06日
从0开始LLM-注意力机制-3
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文深入探讨了因果注意力机制,它是自注意力机制的一种特殊形式,通过遮蔽未来 Token 的信息,确保模型在处理序列时仅依赖于过去和当前的信息。文章详细介绍了两种实现因果注意力的方法:一种是在 softmax 后将对角线以上的权重归零并重新标准化,另一种是在 softmax 前用负无穷大遮蔽注意力得分。此外,还讨论了 Dropout 技术在防止过拟合中的应用,并展示了如何将这些技术集成到一个紧凑的 `CausalAttention` 类中,以便在大型语言模型中使用。

🎭 因果注意力机制是自注意力的一种变体,通过遮蔽(mask)的方式,限制模型在处理序列中的每个 Token 时,只能关注到该 Token 及其之前的 Token,避免了访问未来信息。

♾️ 两种实现因果注意力遮蔽的关键方法:一种是在计算 softmax 后的注意力权重矩阵中,将对角线以上的值归零,并重新标准化;另一种更高效的方法是在应用 softmax 函数之前,将注意力得分矩阵对角线以上的值替换为负无穷大。

🚫 Dropout技术通过在训练过程中随机忽略部分神经元,减少模型对特定神经元的依赖,防止过拟合。在 Transformer 架构中,Dropout 常被应用于注意力得分计算之后,或注意力权重应用于值向量之后。

💻 文中提供了一个紧凑的 `CausalAttention` 类实现,集成了因果注意力和 Dropout 技术,并使用 `register_buffer` 确保在大型语言模型中使用时,遮蔽张量能自动移动到合适的设备上。

使用因果注意力机制隐藏后续词

修改标准的自注意力机制,创建一个因果注意力(Causal Attention)机制。因果注意力,也称为遮蔽注意力(masked attention),是自注意力的一种特殊形式。它限制模型在处理任何给定 Token 时,只考虑序列中之前和当前的输入。这与标准的自注意力机制形成对比,后者允许一次访问整个输入序列。

因此,在计算注意力得分时,因果注意力机制确保模型只考虑序列中当前 Token 或之前出现的 Token。

在因果注意力中,遮蔽掉对角线以上的注意力权重,以便在计算上下文向量时,大语言模型无法访问后续的Token。例如,在第二行中,对于单词“journey”,我们只保留“Your”(之前的单词)和“journey”(当前位置)的注意力权重。

遮蔽了对角线以上的注意力权重,并标准化未遮蔽的注意力权重,使得每一行的注意力权重之和为 1。

应用因果注意力遮蔽

因果注意力机制中获取遮蔽的注意力权重矩阵的一种方式是对注意力得分应用 softmax 函数,将对角线以上的元素归零并标准化结果矩阵。

    使用 softmax 函数计算注意力权重,如前几节所做的那样:
queries = sa_v2.W_query(inputs)  #Akeys = sa_v2.W_key(inputs) attn_scores = queries @ keys.Tattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)print(attn_weights)#结果tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],       grad_fn=<SoftmaxBackward>)

2. 使用 PyTorch 的 tril 函数,创建一个遮蔽,使得对角线以上的值为零:

context_length = attn_scores.shape[0]mask_simple = torch.tril(torch.ones(context_length, context_length))print(mask_simple)#结果tensor([[1., 0., 0., 0., 0., 0.],        [1., 1., 0., 0., 0., 0.],        [1., 1., 1., 0., 0., 0.],        [1., 1., 1., 1., 0., 0.],        [1., 1., 1., 1., 1., 0.],        [1., 1., 1., 1., 1., 1.]])

将这个遮蔽与注意力权重相乘,将对角线以上的值归零:

masked_simple = attn_weights*mask_simpleprint(masked_simple)#结果tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],       grad_fn=<MulBackward0>)

3.重新标准化注意力权重,使每一行的和再次为 1。

row_sums = masked_simple.sum(dim=1, keepdim=True)masked_simple_norm = masked_simple / row_sumsprint(masked_simple_norm)#结果tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],       grad_fn=<DivBackward0>)

信息泄露

应用遮蔽然后重新标准化注意力权重时,可能会出现后续的 Token(我们打算遮蔽的)的信息仍影响当前的 Token 的情况,这是因为它们的值是 softmax 函数计算的一部分。然而,关键点在于,当我们在遮蔽后重新标准化注意力权重时,我们实际上是在一个更小的子集上重新计算 softmax 函数(因为遮蔽位置不会对 softmax 值有任何贡献)。

softmax的数学优雅之处在于,尽管在最初的计算中分母包含了所有位置,但在遮蔽和重新归一化之后,被遮蔽的位置的影响被消除了————它们不会以任何有意义的方式对 softmax 得分产生影响。

简而言之,经过遮蔽和重新标准化后,注意力权重的分布就好像一开始只在未遮蔽位置上计算一样。这确保了后续(或其他遮蔽)Token 的信息不会像我们想象的那样泄露。

在因果注意力中获得掩蔽注意力权重矩阵的更高效方法,是在应用 softmax 函数之前,用负无穷大值遮蔽注意力得分。

softmax 函数将其输入转换为概率分布。当一行中存在负无穷大(-∞)值时,softmax 函数将其概率视为零。(数学上,这是因为 e^-∞ 趋近于 0。)

可以通过创建一个对角线以上是 1 的遮蔽,然后将这些 1 替换为负无穷大(-inf)值来实现这种更高效的遮蔽技巧:

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)masked = attn_scores.masked_fill(mask.bool(), -torch.inf)print(masked)#结果tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],       grad_fn=<MaskedFillBackward0>)

只需对这些遮蔽结果应用 softmax 函数即可完成:

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)print(attn_weights)#结果tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],       grad_fn=<SoftmaxBackward>)

 通过 Dropout 遮蔽额外的注意力权重

在深度学习中,Dropout 是一种技术,即在训练过程中随机忽略选定的隐藏层单元,有效地将它们“丢弃”。这种方法有助于防止过拟合,确保模型不会过度依赖任何特定的隐藏层单元组。需要强调的是,Dropout 仅在训练期间使用,在之后不可以使用。

在包括 GPT 在内的 Transformer 架构中,注意力机制中的 Dropout 通常应用于两个特定区域:计算注意力得分之后,或将注意力权重应用于值向量之后。

计算注意力权重后应用 Dropout 遮蔽,如图所示,这是实践中更常见的变体。利用因果注意力遮蔽(左上角),应用额外的 Dropout 遮蔽(右上角)来归零额外的注意力权重,以减少训练期间的过拟合。

使用了 50% 的 Dropout 率,这意味着遮蔽掉一半的注意力权重。(在训练 GPT 模型时,将使用较低的 Dropout 率,例如 0.1 或 0.2。)

torch.manual_seed(123)dropout = torch.nn.Dropout(0.5) #Aexample = torch.ones(6, 6) #Bprint(dropout(example))#结果tensor([[2., 2., 2., 2., 2., 2.],        [0., 2., 0., 0., 0., 0.],        [0., 0., 2., 0., 2., 0.],        [2., 2., 0., 0., 0., 2.],        [2., 0., 0., 0., 0., 2.],        [0., 2., 0., 0., 0., 0.]])

当对注意力权重矩阵应用 50% 的 Dropout 率时,矩阵中一半的元素被随机设为零。为了补偿活跃元素的减少,矩阵中剩余元素的值被放大了 1/0.5 = 2 倍。这种放大对于保持注意力权重的整体平衡至关重要,它能确保在训练和推理阶段,注意力机制的平均影响保持一致。

将 Dropout 应用于注意力权重矩阵本身:

torch.manual_seed(123)print(dropout(attn_weights))#结果tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],       grad_fn=<MulBackward0>)

实现一个紧凑的 causal attention 类

将因果注意力和 Dropout 技术集成到之前的 SelfAttention Python 类中。这个类随后将作为在即将到来的章节中开发多头注意力( multi-head attention )的模板,这是将实现的最后的attention类。

但在开始之前,还有一件事要确保,那就是代码能够处理由多个输入组成的批次,以便 CausalAttention 类支持实现的数据加载器生成的批次输出。

为简化模拟这种批次输入:

batch = torch.stack((inputs, inputs), dim=0)print(batch.shape) #A #结果torch.Size([2, 6, 3])

一个紧凑的因果注意力类

class CausalAttention(nn.Module):    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):        super().__init__()        self.d_out = d_out        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)        self.dropout = nn.Dropout(dropout) # New        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New    def forward(self, x):        b, num_tokens, d_in = x.shape # New batch dimension b        keys = self.W_key(x)        queries = self.W_query(x)        values = self.W_value(x)        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose        attn_scores.masked_fill_(  # New, _ ops are in-place            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)        attn_weights = self.dropout(attn_weights) # New        context_vec = attn_weights @ values        return context_vec

在 ‘init’ 方法中添加了一个 self.register_buffer() 调用。在 PyTorch 中使用 register_buffer 并不是所有情况下都必须的,但在这里有几个优点。例如,当我们在大型语言模型中使用 CausalAttention 类时,缓冲区会随着模型自动移动到适当的设备(CPU或GPU)上,这在后续章节中训练大语言模型时会很有用。这意味着我们不需要手动确保这些张量与模型参数在同一设备上,从而避免设备不匹配错误。

使用 CausalAttention 类:

torch.manual_seed(123)context_length = batch.shape[1]ca = CausalAttention(d_in, d_out, context_length, 0.0)context_vecs = ca(batch)print("context_vecs.shape:", context_vecs.shape)#结果context_vecs.shape: torch.Size([2, 6, 2])

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

因果注意力 自注意力机制 Dropout Transformer PyTorch
相关文章