1
resposta

[Projeto] 04 Faça como eu fiz: condicional espacial

import torch
import torch.nn as nn
import cv2
import numpy as np

==========================================

1. Processamento de Máscaras (OpenCV)

==========================================

def create_spatial_mask(image_path, size=(54, 54)):
"""
Extrai as bordas de uma imagem para servir como guia espacial.
"""
# Carrega em escala de cinza
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, size)

# Filtro Scharr: Excelente para bordas em baixa resolução
scharr_x = cv2.Scharr(img, cv2.CV_64F, 1, 0)
scharr_y = cv2.Scharr(img, cv2.CV_64F, 0, 1)

# Combina as bordas e normaliza
edge_mask = cv2.addWeighted(cv2.convertScaleAbs(scharr_x), 0.5, 
                           cv2.convertScaleAbs(scharr_y), 0.5, 0)

# Converte para tensor [1, H, W] no intervalo [0, 1]
mask_tensor = torch.from_numpy(edge_mask).float() / 255.0
return mask_tensor.unsqueeze(0)

==========================================

2. Bloco Residual para Condicionamento Espacial

==========================================

class SpatialResidualBlock(nn.Module):
def init(self, in_ch, out_ch, time_dim):
super().init()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.time_mlp = nn.Linear(time_dim, out_ch)
self.relu = nn.ReLU()
self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

def forward(self, x, t_emb):
    h = self.relu(self.conv1(x))
    # Adiciona informação de tempo
    h = h + self.time_mlp(t_emb)[:, :, None, None]
    h = self.conv2(self.relu(h))
    return h + self.shortcut(x)

==========================================

3. U-Net Condicional Espacial

==========================================

class SpatialUNet(nn.Module):
def init(self):
super().init()
# Entrada: 3 canais (RGB) + 1 canal (Máscara Espacial) = 4
self.input_layer = nn.Conv2d(4, 64, kernel_size=3, padding=1)

    self.time_mlp = nn.Sequential(
        nn.Linear(32, 32),
        nn.ReLU()
    )
    
    # Estrutura simplificada para demonstração
    self.down1 = SpatialResidualBlock(64, 128, 32)
    self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
    self.final_conv = nn.Conv2d(128, 3, 1) # 64 (up) + 64 (skip)

def forward(self, x, t, mask):
    """
    x: Imagem ruidosa [B, 3, 54, 54]
    t: Timestep embedding [B, 32]
    mask: Máscara espacial [B, 1, 54, 54]
    """
    # Concatenação Espacial: Une a imagem ao mapa de bordas
    x_cond = torch.cat([x, mask], dim=1) # Resultado: [B, 4, 54, 54]
    
    t_emb = self.time_mlp(t)
    
    # Encoder
    h1 = self.input_layer(x_cond)
    h2 = self.down1(F.max_pool2d(h1, 2), t_emb)
    
    # Decoder com Skip Connection
    out = self.up1(h2)
    out = torch.cat([out, h1], dim=1) # Skip connection traz a máscara original
    
    return self.final_conv(out)

==========================================

4. Demonstração de Uso

==========================================

Simulação de batch

batch_size = 1
img_noisy = torch.randn((batch_size, 3, 54, 54))
time_input = torch.randn((batch_size, 32))

Máscara gerada previamente pelo OpenCV

mask_input = torch.randn((batch_size, 1, 54, 54))

model = SpatialUNet()
output = model(img_noisy, time_input, mask_input)

print(f"Entrada (RGB+Máscara): 4 canais")
print(f"Saída (Imagem Gerada): {output.shape}")

1 resposta

Olá, Moacir! Como vai?

Parabéns pela resolução da atividade!

Observei que você explorou o processamento de máscaras espaciais com OpenCV para enriquecer a entrada de imagens em PyTorch, utilizou muito bem o bloco residual condicional para integrar embeddings temporais ao fluxo da rede e ainda compreendeu a importância da arquitetura U-Net espacial para combinar informações de bordas com reconstrução de imagens.

Continue postando as suas soluções, com certeza isso ajudará outros estudantes e tem grande relevância para o fórum.

Fico à disposição! E se precisar, conte sempre com o apoio do fórum.

Abraço e bons estudos!

AluraConte com o apoio da comunidade Alura na sua jornada. Abraços e bons estudos!