1
resposta

[Projeto] 06 Faça como eu fiz: calculando a Self-Attention e implementando a Multi-Head Attention

  1. A Matemática da Intuição: Q, K e V
    Query (Q): "O que eu estou procurando?"
    Key (K): "O que eu ofereço de conteúdo?"
    Value (V): "Se eu for relevante, qual informação eu entrego?"

  2. Código Fonte: Self-Attention e Multi-Head Attention (MHA)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
def init(self, d_model, num_heads):
super(MultiHeadAttention, self).init()
assert d_model % num_heads == 0, "d_model precisa ser divisível por num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads # Dimensão de cada cabeça
    
    # Projeções Lineares para Q, K e V
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    
    # Projeção final de saída
    self.w_o = nn.Linear(d_model, d_model)

def forward(self, q, k, v, mask=None):
    batch_size = q.size(0)
    
    # 1. Projeção Linear e divisão em 'n' cabeças
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    k = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    v = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    
    # 2. Scaled Dot-Product Attention
    # scores = (Q * K^T) / sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
    
    # 3. Aplicação da Máscara Causal (impede olhar para o futuro)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 4. Softmax para obter pesos (0 a 1) e multiplicação pelo Value
    attention_weights = F.softmax(scores, dim=-1)
    x = torch.matmul(attention_weights, v)
    
    # 5. Concatenação das cabeças e Projeção Final
    x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
    return self.w_o(x), attention_weights

--- Exemplo de Uso ---

d_model = 128
heads = 8
mha = MultiHeadAttention(d_model, heads)

Simulando entrada: batch de 1, 10 tokens, embedding de 128

x = torch.randn(1, 10, d_model)

Máscara Causal (Triangular Inferior)

mask = torch.tril(torch.ones(10, 10)).unsqueeze(0).unsqueeze(0)

output, weights = mha(x, x, x, mask=mask)
print(f"Saída do MHA: {output.shape}") # [1, 10, 128]

1 resposta

Olá! Como vai?

Parabéns pela resolução das atividades!

E para compartilhar códigos de maneira ainda mais organizada aqui no fórum, você pode utilizar a opção abaixo:

Opção inserir bloco de código, da caixa de perguntas do fórum da alura

Após clicar, irá aparecer uma estrutura da seguinte maneira:

Opção de bloco de código sendo illustrada
O resultado será o seguinte:

Copie o seu código aqui

Fico à disposição! E se precisar, conte sempre com o apoio do fórum.

Abraço e bons estudos!

AluraConte com o apoio da comunidade Alura na sua jornada. Abraços e bons estudos!