import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from collections import defaultdict
==========================================
1. Bloco Residual com Contexto Temporal
==========================================
class ResidualBlock(nn.Module):
def init(self, in_ch, out_ch, time_emb_dim):
super().init()
self.time_mlp = nn.Linear(time_emb_dim, out_ch)
self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_ch)
self.bn2 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU()
# Atalho caso as dimensões de entrada/saída sejam diferentes
self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
def forward(self, x, t):
# Primeiro estágio
h = self.relu(self.bn1(self.conv1(x)))
# Injeção de Tempo: Transformamos t para casar com os canais de h
time_emb = self.relu(self.time_mlp(t))
time_emb = time_emb[(...,) + (None,) * 2] # Reshape para [B, C, 1, 1]
h = h + time_emb
# Segundo estágio
h = self.bn2(self.conv2(h))
# Soma residual (Shortcut + Transformação)
return self.relu(h + self.shortcut(x))
==========================================
2. Arquitetura U-Net Avançada
==========================================
class AdvancedUNet(nn.Module):
def init(self, in_channels=3, time_dim=32):
super().init()
# Mapeamento de ativações para visualização (Hooks)
self.activations = defaultdict(list)
# Encoder (Downsampling)
self.enc1 = ResidualBlock(in_channels, 64, time_dim)
self.pool1 = nn.MaxPool2d(2)
self.enc2 = ResidualBlock(64, 128, time_dim)
self.pool2 = nn.MaxPool2d(2)
# Bottleneck
self.bottleneck = ResidualBlock(128, 256, time_dim)
# Decoder (Upsampling)
self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.dec1 = ResidualBlock(256, 128, time_dim) # 128 (up) + 128 (skip)
self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.dec2 = ResidualBlock(128, 64, time_dim) # 64 (up) + 64 (skip)
self.final_conv = nn.Conv2d(64, in_channels, kernel_size=1)
def forward(self, x, t):
# Encoder
s1 = self.enc1(x, t)
p1 = self.pool1(s1)
s2 = self.enc2(p1, t)
p2 = self.pool2(s2)
# Bottleneck
b = self.bottleneck(p2, t)
# Decoder com Skip Connections
d1 = self.up1(b)
d1 = torch.cat((d1, s2), dim=1) # Concatena canal
d1 = self.dec1(d1, t)
d2 = self.up2(d1)
d2 = torch.cat((d2, s1), dim=1)
d2 = self.dec2(d2, t)
return self.final_conv(d2)
==========================================
3. Hooks para Visualização de Ativações
==========================================
def register_visual_hooks(model):
def hook_fn(module, input, output):
# Armazena apenas uma amostra do batch para visualização
model.activations[module].append(output.detach().cpu())
# Registra em todos os blocos residuais
for name, layer in model.named_modules():
if isinstance(layer, ResidualBlock):
layer.register_forward_hook(hook_fn)
==========================================
4. Configuração de Treinamento (AdamW + AMP)
==========================================
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AdvancedUNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scaler = GradScaler() # Para precisão mista (AMP)
Exemplo de loop de treinamento simplificado
def train_step(images, timesteps, targets):
images, timesteps, targets = images.to(device), timesteps.to(device), targets.to(device)
with autocast(): # Executa em float16 onde possível
output = model(images, timesteps)
loss = nn.MSELoss()(output, targets)
optimizer.zero_grad()
scaler.scale(loss).backward() # Escala o loss para evitar underflow
scaler.step(optimizer)
scaler.update()
return loss.item()