import torch
import torch.nn as nn
import torch.nn.functional as F
# ==========================================
# 1. Sinusoidal Time Embedding
# ==========================================
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding for diffusion timesteps.

    Maps a batch of scalar timesteps ``[B]`` to embeddings ``[B, dim]``:
    the first half of the channels are sines, the second half cosines,
    over ``dim // 2`` geometrically spaced frequencies.

    Note: ``dim`` should be even and > 2 (``half_dim - 1`` divides below).
    """

    def __init__(self, dim):
        # BUG FIX: was `def init` / `super().init()`, which never defines a
        # real constructor — nn.Module.__init__ takes no extra args, so
        # SinusoidalPositionEmbeddings(dim) raised TypeError.
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # Frequency decay: exp(-log(10000) * k / (half_dim - 1)) for k in [0, half_dim).
        scale = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        # [B, 1] * [1, half_dim] -> [B, half_dim]
        args = time[:, None] * freqs[None, :]
        # Concatenate sin/cos halves -> [B, dim]
        return torch.cat((args.sin(), args.cos()), dim=-1)
# ==========================================
# 2. Conditional Residual Block
# ==========================================
class ConditionalResidualBlock(nn.Module):
    """Residual conv block conditioned on a context vector.

    The context (time embedding + label embedding, shape ``[B, context_dim]``)
    is projected to ``out_ch`` channels and added to the feature map between
    the two convolutions (an additive, FiLM-like bias broadcast over H, W).
    """

    def __init__(self, in_ch, out_ch, context_dim):
        # BUG FIX: was `def init` / `super().init()` — the block could not be
        # constructed with arguments at all.
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # MLP projecting the context (time + class) into the channel dimension.
        self.context_mlp = nn.Linear(context_dim, out_ch)
        self.relu = nn.ReLU()
        # BUG FIX: the original reused one BatchNorm2d after BOTH convs,
        # mixing the running statistics of two different activations.
        # Each conv gets its own normalization layer.
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # 1x1 projection so the residual sum is shape-compatible.
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, context):
        h = self.relu(self.bn1(self.conv1(x)))
        # Inject context: [B, context_dim] -> [B, out_ch, 1, 1], broadcast over H, W.
        c = self.relu(self.context_mlp(context))
        h = h + c[:, :, None, None]
        h = self.bn2(self.conv2(h))
        return self.relu(h + self.shortcut(x))
# ==========================================
# 3. Conditional U-Net Architecture
# ==========================================
class ConditionalUNet(nn.Module):
    """Minimal class-conditional U-Net for diffusion noise prediction.

    One encoder level, a bottleneck, and one decoder level with a skip
    connection. Conditioning: a sinusoidal time embedding concatenated with
    a learned label embedding is fed to every residual block.
    """

    def __init__(self, num_classes=3, img_ch=3):
        # BUG FIX: was `def init` / `super().init()` — ConditionalUNet()
        # raised TypeError and no submodules were ever registered.
        super().__init__()
        time_dim = 32
        label_dim = 32
        context_dim = time_dim + label_dim  # 64

        # 1. Embeddings
        self.time_mlp = SinusoidalPositionEmbeddings(time_dim)
        self.label_emb = nn.Embedding(num_classes, label_dim)

        # 2. Encoder
        self.enc1 = ConditionalResidualBlock(img_ch, 64, context_dim)
        self.pool = nn.MaxPool2d(2)

        # 3. Bottleneck
        self.bottleneck = ConditionalResidualBlock(64, 128, context_dim)

        # 4. Decoder
        self.up = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = ConditionalResidualBlock(128, 64, context_dim)  # 64 (up) + 64 (skip)
        self.final_conv = nn.Conv2d(64, img_ch, 1)

    def forward(self, x, t, y):
        """
        x: noisy image [B, img_ch, H, W] (H and W must be divisible by 2)
        t: timestep [B]
        y: class label [B] (e.g. 0, 1 or 2)
        Returns the predicted noise with the same shape as x.
        """
        # Build the context by combining time and label embeddings.
        t_emb = self.time_mlp(t)
        y_emb = self.label_emb(y)
        context = torch.cat([t_emb, y_emb], dim=-1)  # [B, 64]

        # Encoder
        s1 = self.enc1(x, context)
        p1 = self.pool(s1)

        # Bottleneck
        b = self.bottleneck(p1, context)

        # Decoder with skip connection
        d1 = self.up(b)
        d1 = torch.cat([d1, s1], dim=1)
        d1 = self.dec1(d1, context)
        return self.final_conv(d1)
# ==========================================
# 4. Usage Example
# ==========================================
# Demo: run one conditional noise-prediction pass.
# BUG FIX: the two descriptive lines below were bare prose (syntax errors);
# restored as comments.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConditionalUNet().to(device)

# Simulation: batch of 2 images, 64x64
x = torch.randn((2, 3, 64, 64)).to(device)
t = torch.tensor([100, 250]).to(device)  # different timesteps
y = torch.tensor([0, 2]).to(device)      # requesting class 0 and class 2

# Noise prediction conditioned on the context (time + class)
noise_pred = model(x, t, y)
print(f"Predição do ruído realizada: {noise_pred.shape}")