import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
==========================================
1. Agendamento de Ruído (Beta Schedule)
==========================================
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
"""
Define uma progressão linear para os valores de beta (ruído).
"""
return torch.linspace(start, end, timesteps)
Configurações globais
T = 300 # Total de passos de difusão
betas = linear_beta_schedule(timesteps=T)
==========================================
2. Cálculos dos Parâmetros de Difusão
==========================================
Alfa é a "fidelidade" que sobra após o ruído (1 - beta)
alphas = 1. - betas
Alfa barra (cumprod) é o produto acumulado de todos os alfas até o tempo t
alphas_cumprod = torch.cumprod(alphas, axis=0)
Parâmetros necessários para a fórmula: x_t = sqrt(alpha_barra)*x_0 + sqrt(1-alpha_barra)*ruído
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
def get_index_from_list(vals, t, x_shape):
"""
Extrai o valor correto do parâmetro para o lote (batch) atual de imagens.
"""
batch_size = t.shape[0]
out = vals.gather(-1, t.cpu())
# Reformata para permitir multiplicação direta com a imagem (batch, canal, h, w)
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
==========================================
3. O Forward Process (A Adição de Ruído)
==========================================
def forward_diffusion_sample(x_0, t, device="cpu"):
"""
Aplica o ruído à imagem original x_0 no tempo t de forma direta.
"""
noise = torch.randn_like(x_0) # Ruído gaussiano puro
# Busca os coeficientes para o tempo t
sqrt_alpha_bar_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
sqrt_one_minus_alpha_bar_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape)
# Fórmula mestre da difusão:
x_t = sqrt_alpha_bar_t.to(device) * x_0.to(device) \
+ sqrt_one_minus_alpha_bar_t.to(device) * noise.to(device)
return x_t, noise.to(device)
==========================================
4. Demonstração Visual
==========================================
def show_diffusion_steps(img_path):
# Carregamento e pré-processamento (Pixel Art ou Foto)
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Lambda(lambda t: (t * 2) - 1) # Normaliza para [-1, 1]
])
# Se não tiver uma imagem, criamos um tensor aleatório simulando uma imagem colorida
try:
img = Image.open(img_path).convert("RGB")
img_tensor = transform(img).unsqueeze(0) # Adiciona dimensão de batch
except:
print("Imagem não encontrada, gerando dados sintéticos...")
img_tensor = torch.randn((1, 3, 64, 64))
# Definir timestamps para visualizar (Início, Meio, Quase Ruído, Ruído Total)
steps = [0, 50, 150, 299]
plt.figure(figsize=(15, 4))
for i, t_val in enumerate(steps):
t = torch.tensor([t_val])
img_noisy, _ = forward_diffusion_sample(img_tensor, t)
# Converte de volta para [0, 1] para plotar
img_plot = (img_noisy.squeeze(0).permute(1, 2, 0) + 1) / 2
img_plot = img_plot.clamp(0, 1).numpy()
plt.subplot(1, len(steps), i + 1)
plt.imshow(img_plot)
plt.title(f"Passo t = {t_val}")
plt.axis('off')
plt.show()