import torch
import torch.nn as nn
import torch.nn.functional as F
==========================================
1. Codificação de Tempo (Time Embedding)
==========================================
class SinusoidalPositionEmbeddings(nn.Module):
"""
Transforma o valor escalar do timestep t em um vetor de alta dimensão.
Isso ajuda a rede a entender se o ruído é leve (t pequeno) ou pesado (t grande).
"""
def init(self, dim):
super().init()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
==========================================
2. Blocos de Construção da U-Net
==========================================
class Block(nn.Module):
def init(self, in_ch, out_ch, time_emb_dim):
super().init()
# MLP para integrar a informação do tempo nas convoluções
self.time_mlp = nn.Linear(time_emb_dim, out_ch)
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.relu = nn.ReLU()
def forward(self, x, t):
# Primeira convolução
h = self.relu(self.conv1(x))
# Adiciona o contexto temporal (t)
time_emb = self.relu(self.time_mlp(t))
# Ajusta a dimensão do tempo para somar com a imagem [B, C, 1, 1]
time_emb = time_emb[(..., ) + (None, ) * 2]
h = h + time_emb
# Segunda convolução
return self.relu(self.conv2(h))
==========================================
3. Arquitetura Mini U-Net Completa
==========================================
class SimpleUNet(nn.Module):
def init(self):
super().init()
image_channels = 3
down_channels = (64, 128, 256)
up_channels = (256, 128, 64)
out_dim = 3
time_emb_dim = 32
# Codificador de Tempo
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU()
)
# Encoder (Downsampling)
self.downs = nn.ModuleList([Block(image_channels, down_channels[0], time_emb_dim)])
for i in range(len(down_channels)-1):
self.downs.append(Block(down_channels[i], down_channels[i+1], time_emb_dim))
self.pool = nn.MaxPool2d(2)
# Decoder (Upsampling)
self.ups = nn.ModuleList([])
for i in range(len(up_channels)-1):
self.ups.append(nn.ConvTranspose2d(up_channels[i], up_channels[i+1], 2, 2))
self.ups.append(Block(up_channels[i], up_channels[i+1], time_emb_dim)) # Para processar concatenação
# Camada Final (Gera o ruído previsto)
self.final_conv = nn.Conv2d(up_channels[-1], out_dim, 1)
def forward(self, x, t):
# 1. Embedding de tempo
t = self.time_mlp(t)
# 2. Caminho de Descida (Encoder) + Guardar Skip Connections
residuals = []
for down in self.downs:
x = down(x, t)
residuals.append(x)
x = self.pool(x)
# 3. Caminho de Subida (Decoder) + Skip Connections
for i in range(0, len(self.ups), 2):
x = self.ups[i](x) # Transpose Conv
residual = residuals.pop()
# Concatena o que veio do Encoder (Skip Connection)
x = torch.cat((x, residual), dim=1)
x = self.ups[i+1](x, t) # Block
return self.final_conv(x)
==========================================
4. Exemplo de Uso
==========================================
Criar modelo
model = SimpleUNet()
Simular uma entrada (Batch de 1, 3 Canais, 64x64 pixels)
test_input = torch.randn((1, 3, 64, 64))
Simular um timestep (ex: passo 150 de 300)
test_t = torch.tensor([150])
Prever o ruído
predicted_noise = model(test_input, test_t)
print(f"Formato da entrada: {test_input.shape}")
print(f"Formato da saída (ruído previsto): {predicted_noise.shape}")