1
resposta

[Projeto] 06 Faça como eu fiz: modelo UNet

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from collections import defaultdict

==========================================

1. Bloco Residual com Contexto Temporal

==========================================

class ResidualBlock(nn.Module):
def init(self, in_ch, out_ch, time_emb_dim):
super().init()
self.time_mlp = nn.Linear(time_emb_dim, out_ch)

    self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(out_ch)
    self.bn2 = nn.BatchNorm2d(out_ch)
    self.relu = nn.ReLU()
    
    # Atalho caso as dimensões de entrada/saída sejam diferentes
    self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

def forward(self, x, t):
    # Primeiro estágio
    h = self.relu(self.bn1(self.conv1(x)))
    
    # Injeção de Tempo: Transformamos t para casar com os canais de h
    time_emb = self.relu(self.time_mlp(t))
    time_emb = time_emb[(...,) + (None,) * 2] # Reshape para [B, C, 1, 1]
    
    h = h + time_emb
    
    # Segundo estágio
    h = self.bn2(self.conv2(h))
    
    # Soma residual (Shortcut + Transformação)
    return self.relu(h + self.shortcut(x))

==========================================

2. Arquitetura U-Net Avançada

==========================================

class AdvancedUNet(nn.Module):
def init(self, in_channels=3, time_dim=32):
super().init()

    # Mapeamento de ativações para visualização (Hooks)
    self.activations = defaultdict(list)

    # Encoder (Downsampling)
    self.enc1 = ResidualBlock(in_channels, 64, time_dim)
    self.pool1 = nn.MaxPool2d(2)
    self.enc2 = ResidualBlock(64, 128, time_dim)
    self.pool2 = nn.MaxPool2d(2)
    
    # Bottleneck
    self.bottleneck = ResidualBlock(128, 256, time_dim)
    
    # Decoder (Upsampling)
    self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
    self.dec1 = ResidualBlock(256, 128, time_dim) # 128 (up) + 128 (skip)
    
    self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
    self.dec2 = ResidualBlock(128, 64, time_dim) # 64 (up) + 64 (skip)
    
    self.final_conv = nn.Conv2d(64, in_channels, kernel_size=1)

def forward(self, x, t):
    # Encoder
    s1 = self.enc1(x, t)
    p1 = self.pool1(s1)
    s2 = self.enc2(p1, t)
    p2 = self.pool2(s2)
    
    # Bottleneck
    b = self.bottleneck(p2, t)
    
    # Decoder com Skip Connections
    d1 = self.up1(b)
    d1 = torch.cat((d1, s2), dim=1) # Concatena canal
    d1 = self.dec1(d1, t)
    
    d2 = self.up2(d1)
    d2 = torch.cat((d2, s1), dim=1)
    d2 = self.dec2(d2, t)
    
    return self.final_conv(d2)

==========================================

3. Hooks para Visualização de Ativações

==========================================

def register_visual_hooks(model):
def hook_fn(module, input, output):
# Armazena apenas uma amostra do batch para visualização
model.activations[module].append(output.detach().cpu())

# Registra em todos os blocos residuais
for name, layer in model.named_modules():
    if isinstance(layer, ResidualBlock):
        layer.register_forward_hook(hook_fn)

==========================================

4. Configuração de Treinamento (AdamW + AMP)

==========================================

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AdvancedUNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scaler = GradScaler() # Para precisão mista (AMP)

Exemplo de loop de treinamento simplificado

def train_step(images, timesteps, targets):
images, timesteps, targets = images.to(device), timesteps.to(device), targets.to(device)

with autocast(): # Executa em float16 onde possível
    output = model(images, timesteps)
    loss = nn.MSELoss()(output, targets)

optimizer.zero_grad()
scaler.scale(loss).backward() # Escala o loss para evitar underflow
scaler.step(optimizer)
scaler.update()

return loss.item()
1 resposta

Oi, Moacir!

Você elevou bastante o nível aqui. A estrutura ficou mais próxima de uma implementação real, com bloco residual, skip connections, hooks e até configuração de treino com AdamW e AMP.

Isso mostra que você não ficou só na ideia da U-Net, mas já começou a pensar em treinamento e monitoramento do modelo, o que enriquece bastante o exercício.

Bons estudos!

Sucesso

Imagem da comunidade