import torch
import torch.nn as nn
import torch.nn.functional as F
# ==========================================
# 1. Feed-Forward Network (FFN)
# ==========================================
class FeedForward(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        # 4x expansion in the hidden dimension
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),  # Smooth activation commonly used in LLMs
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
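
# The TransformerBlock below expects the MultiHeadAttention module built in the
# previous lesson. A minimal sketch is included here so this file runs on its
# own; the interface (init args d_model and n_head, forward(x, mask=None) with a
# broadcastable (1, 1, t, t) mask) is assumed rather than copied from that lesson.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.qkv = nn.Linear(d_model, 3 * d_model)  # Fused Q, K, V projection
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Split heads: (b, t, d) -> (b, n_head, t, head_dim)
        q = q.view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores: (b, n_head, t, t)
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            att = att.masked_fill(mask == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        # Merge heads back: (b, n_head, t, head_dim) -> (b, t, d)
        out = (att @ v).transpose(1, 2).contiguous().view(b, t, d)
        return self.proj(out)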
# ==========================================
# 2. Complete Transformer Block
# ==========================================
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_head)  # From the previous lesson (sketch above)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, x, mask=None):
        # Residual connections with pre-normalization (LayerNorm before each sublayer)
        x = x + self.attn(self.ln1(x), mask=mask)
        x = x + self.ffn(self.ln2(x))
        return x
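
# Illustrative shape check (arbitrary sizes): a block maps (b, t, d_model) to
# (b, t, d_model), which is exactly what allows n_layer blocks to be stacked:
#   >>> block = TransformerBlock(d_model=128, n_head=4)
#   >>> x = torch.randn(2, 16, 128)
#   >>> block(x).shape
#   torch.Size([2, 16, 128])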
# ==========================================
# 3. Language Model (GPT-style)
# ==========================================
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, n_head, n_layer, max_seq_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        # Stack of Transformer blocks
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(d_model)  # Final layer norm
        self.head = nn.Linear(d_model, vocab_size)
        # Weight tying: the embedding and the output head share one weight matrix.
        # Reduces parameter count and helps generalization.
        self.token_embedding.weight = self.head.weight

    def forward(self, idx, targets=None):
        b, t = idx.shape
        pos = torch.arange(0, t, device=idx.device)
        x = self.token_embedding(idx) + self.position_embedding(pos)
        # Causal mask for autoregression: position i attends only to positions <= i
        mask = torch.tril(torch.ones(t, t, device=idx.device)).view(1, 1, t, t)
        for block in self.blocks:
            x = block(x, mask=mask)
        logits = self.head(self.ln_f(x))
        if targets is None:
            return logits, None
        # Next-token loss: flatten (b, t, vocab_size) -> (b * t, vocab_size)
        loss = F.cross_entropy(logits.view(b * t, -1), targets.view(b * t))
        return logits, loss
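
# Quick smoke test. All hyperparameters below are illustrative toy values, not
# settings taken from the lesson.
if __name__ == "__main__":
    model = LanguageModel(vocab_size=1000, d_model=128, n_head=4, n_layer=2, max_seq_len=64)
    idx = torch.randint(0, 1000, (2, 16))      # batch of 2 sequences, 16 tokens each
    targets = torch.randint(0, 1000, (2, 16))  # next-token targets, same shape
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())           # torch.Size([2, 16, 1000]) and a scalar loss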