A Matemática da Intuição: Q, K e V
Query (Q): "O que eu estou procurando?"
Key (K): "O que eu ofereço de conteúdo?"
Value (V): "Se eu for relevante, qual informação eu entrego?"Código Fonte: Self-Attention e Multi-Head Attention (MHA)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def init(self, d_model, num_heads):
super(MultiHeadAttention, self).init()
assert d_model % num_heads == 0, "d_model precisa ser divisível por num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimensão de cada cabeça
# Projeções Lineares para Q, K e V
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
# Projeção final de saída
self.w_o = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 1. Projeção Linear e divisão em 'n' cabeças
# (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
k = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
v = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 2. Scaled Dot-Product Attention
# scores = (Q * K^T) / sqrt(d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
# 3. Aplicação da Máscara Causal (impede olhar para o futuro)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 4. Softmax para obter pesos (0 a 1) e multiplicação pelo Value
attention_weights = F.softmax(scores, dim=-1)
x = torch.matmul(attention_weights, v)
# 5. Concatenação das cabeças e Projeção Final
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
return self.w_o(x), attention_weights
--- Exemplo de Uso ---
d_model = 128
heads = 8
mha = MultiHeadAttention(d_model, heads)
Simulando entrada: batch de 1, 10 tokens, embedding de 128
x = torch.randn(1, 10, d_model)
Máscara Causal (Triangular Inferior)
mask = torch.tril(torch.ones(10, 10)).unsqueeze(0).unsqueeze(0)
output, weights = mha(x, x, x, mask=mask)
print(f"Saída do MHA: {output.shape}") # [1, 10, 128]
