1
resposta

[Projeto] 06 Faça como eu fiz: rede neural condicional

import torch
import torch.nn as nn
import torch.nn.functional as F

==========================================

1. Sinusoidal Time Embedding

==========================================

class SinusoidalPositionEmbeddings(nn.Module):
def init(self, dim):
super().init()
self.dim = dim

def forward(self, time):
    device = time.device
    half_dim = self.dim // 2
    embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
    embeddings = time[:, None] * embeddings[None, :]
    embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
    return embeddings

==========================================

2. Bloco Residual Condicional

==========================================

class ConditionalResidualBlock(nn.Module):
"""
Este bloco recebe a imagem e um vetor 'context' que é a
soma do embedding de tempo + embedding de rótulo.
"""
def init(self, in_ch, out_ch, context_dim):
super().init()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

    # MLP para processar o contexto (Tempo + Classe)
    self.context_mlp = nn.Linear(context_dim, out_ch)
    
    self.relu = nn.ReLU()
    self.bn = nn.BatchNorm2d(out_ch)
    self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

def forward(self, x, context):
    h = self.relu(self.bn(self.conv1(x)))
    
    # Injeção do Contexto (Time + Label)
    # Transformamos [B, context_dim] -> [B, out_ch, 1, 1]
    c = self.relu(self.context_mlp(context))
    h = h + c[(...,) + (None,) * 2]
    
    h = self.bn(self.conv2(h))
    return self.relu(h + self.shortcut(x))

==========================================

3. Arquitetura Conditional U-Net

==========================================

class ConditionalUNet(nn.Module):
def init(self, num_classes=3, img_ch=3):
super().init()

    time_dim = 32
    label_dim = 32
    context_dim = time_dim + label_dim # 64
    
    # 1. Embeddings
    self.time_mlp = SinusoidalPositionEmbeddings(time_dim)
    self.label_emb = nn.Embedding(num_classes, label_dim)
    
    # 2. Encoder
    self.enc1 = ConditionalResidualBlock(img_ch, 64, context_dim)
    self.pool = nn.MaxPool2d(2)
    
    # 3. Bottleneck
    self.bottleneck = ConditionalResidualBlock(64, 128, context_dim)
    
    # 4. Decoder
    self.up = nn.ConvTranspose2d(128, 64, 2, 2)
    self.dec1 = ConditionalResidualBlock(128, 64, context_dim) # 64(up) + 64(skip)
    
    self.final_conv = nn.Conv2d(64, img_ch, 1)

def forward(self, x, t, y):
    """
    x: Imagem ruidosa [B, 3, H, W]
    t: Timestep [B]
    y: Classe/Rótulo [B] (ex: 0, 1 ou 2)
    """
    # Criar contexto combinando tempo e rótulo
    t_emb = self.time_mlp(t)
    y_emb = self.label_emb(y)
    context = torch.cat([t_emb, y_emb], dim=-1) # [B, 64]
    
    # Encoder
    s1 = self.enc1(x, context)
    p1 = self.pool(s1)
    
    # Bottleneck
    b = self.bottleneck(p1, context)
    
    # Decoder com Skip Connection
    d1 = self.up(b)
    d1 = torch.cat([d1, s1], dim=1)
    d1 = self.dec1(d1, context)
    
    return self.final_conv(d1)

==========================================

4. Exemplo de Execução

==========================================

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConditionalUNet().to(device)

Simulação: Batch de 2 imagens 64x64

x = torch.randn((2, 3, 64, 64)).to(device)
t = torch.tensor([100, 250]).to(device) # Timesteps diferentes
y = torch.tensor([0, 2]).to(device) # Pedindo Classe 0 e Classe 2

Predição do ruído baseada no contexto

noise_pred = model(x, t, y)

print(f"Predição do ruído realizada: {noise_pred.shape}")

1 resposta

Olá, Moacir! Como vai?

Parabéns pela resolução da atividade!

Observei que você explorou o embedding senoidal de tempo para representar timesteps em PyTorch, utilizou muito bem o bloco residual condicional para integrar contexto de tempo e rótulo ao processamento e ainda compreendeu a importância da arquitetura U-Net condicional para realizar predições de ruído de forma estruturada e eficiente.

Continue postando as suas soluções, com certeza isso ajudará outros estudantes e tem grande relevância para o fórum.

Fico à disposição! E se precisar, conte sempre com o apoio do fórum.

Abraço e bons estudos!

AluraConte com o apoio da comunidade Alura na sua jornada. Abraços e bons estudos!