import torch
import torch.nn as nn
import cv2
import numpy as np
==========================================
1. Processamento de Máscaras (OpenCV)
==========================================
def create_spatial_mask(image_path, size=(54, 54)):
"""
Extrai as bordas de uma imagem para servir como guia espacial.
"""
# Carrega em escala de cinza
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, size)
# Filtro Scharr: Excelente para bordas em baixa resolução
scharr_x = cv2.Scharr(img, cv2.CV_64F, 1, 0)
scharr_y = cv2.Scharr(img, cv2.CV_64F, 0, 1)
# Combina as bordas e normaliza
edge_mask = cv2.addWeighted(cv2.convertScaleAbs(scharr_x), 0.5,
cv2.convertScaleAbs(scharr_y), 0.5, 0)
# Converte para tensor [1, H, W] no intervalo [0, 1]
mask_tensor = torch.from_numpy(edge_mask).float() / 255.0
return mask_tensor.unsqueeze(0)
==========================================
2. Bloco Residual para Condicionamento Espacial
==========================================
class SpatialResidualBlock(nn.Module):
def init(self, in_ch, out_ch, time_dim):
super().init()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.time_mlp = nn.Linear(time_dim, out_ch)
self.relu = nn.ReLU()
self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
def forward(self, x, t_emb):
h = self.relu(self.conv1(x))
# Adiciona informação de tempo
h = h + self.time_mlp(t_emb)[:, :, None, None]
h = self.conv2(self.relu(h))
return h + self.shortcut(x)
==========================================
3. U-Net Condicional Espacial
==========================================
class SpatialUNet(nn.Module):
def init(self):
super().init()
# Entrada: 3 canais (RGB) + 1 canal (Máscara Espacial) = 4
self.input_layer = nn.Conv2d(4, 64, kernel_size=3, padding=1)
self.time_mlp = nn.Sequential(
nn.Linear(32, 32),
nn.ReLU()
)
# Estrutura simplificada para demonstração
self.down1 = SpatialResidualBlock(64, 128, 32)
self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
self.final_conv = nn.Conv2d(128, 3, 1) # 64 (up) + 64 (skip)
def forward(self, x, t, mask):
"""
x: Imagem ruidosa [B, 3, 54, 54]
t: Timestep embedding [B, 32]
mask: Máscara espacial [B, 1, 54, 54]
"""
# Concatenação Espacial: Une a imagem ao mapa de bordas
x_cond = torch.cat([x, mask], dim=1) # Resultado: [B, 4, 54, 54]
t_emb = self.time_mlp(t)
# Encoder
h1 = self.input_layer(x_cond)
h2 = self.down1(F.max_pool2d(h1, 2), t_emb)
# Decoder com Skip Connection
out = self.up1(h2)
out = torch.cat([out, h1], dim=1) # Skip connection traz a máscara original
return self.final_conv(out)
==========================================
4. Demonstração de Uso
==========================================
Simulação de batch
batch_size = 1
img_noisy = torch.randn((batch_size, 3, 54, 54))
time_input = torch.randn((batch_size, 32))
Máscara gerada previamente pelo OpenCV
mask_input = torch.randn((batch_size, 1, 54, 54))
model = SpatialUNet()
output = model(img_noisy, time_input, mask_input)
print(f"Entrada (RGB+Máscara): 4 canais")
print(f"Saída (Imagem Gerada): {output.shape}")