import torch
import os
import json
from torchvision.utils import save_image, make_grid
from PIL import Image
==========================================
1. Configurações de Ambiente e Caminhos
==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = "experiments/run_20260407/checkpoints/best_model.pt"
OUTPUT_DIR = "projeto_final_entrega/resultados"
os.makedirs(OUTPUT_DIR, exist_ok=True)
==========================================
2. Carregamento do Modelo e Metadados
==========================================
def load_final_model(path):
# Inicializa a arquitetura (mesma usada no treino)
model = ConditionalUNet(num_classes=3).to(DEVICE)
# Carrega os pesos e as métricas salvas
checkpoint = torch.load(path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Modelo carregado! Época: {checkpoint['epoch']} | FID: {checkpoint['metrics']['fid']:.2f}")
return model, checkpoint['metrics']
==========================================
3. Função de Amostragem (Reverse Diffusion)
==========================================
@torch.no_grad()
def generate_final_samples(model, labels, guidance_scale=7.5, steps=250):
"""
Gera imagens a partir do ruído puro usando Classifier-Free Guidance.
"""
n = len(labels)
# 1. Começamos com ruído gaussiano puro
img = torch.randn((n, 3, 64, 64), device=DEVICE)
# 2. Loop Reverso (de T até 0)
for i in reversed(range(steps)):
t = torch.full((n,), i, device=DEVICE, dtype=torch.long)
# Predição Condicionada e Não-Condicionada para CFG
noise_pred_cond = model(img, t, labels)
noise_pred_uncond = model(img, t, torch.zeros_like(labels)) # Rótulo nulo
# Ajuste pelo Guidance Scale
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# Passo de limpeza (Simplificado para o exemplo)
img = img - (noise_pred * 0.05)
# Normaliza de volta para [0, 1]
return (img.clamp(-1, 1) + 1) / 2
==========================================
4. Execução e Documentação
==========================================
if name == "main":
# Carregar modelo
model, final_metrics = load_final_model(CHECKPOINT_PATH)
# Definir o que queremos gerar (Ex: 2 de cada classe)
target_labels = torch.tensor([0, 0, 1, 1, 2, 2]).to(DEVICE) # X, Y, Z
# Testar diferentes níveis de Guidance Scale
for gs in [1.0, 3.0, 7.5]:
print(f"Gerando grade com Guidance Scale: {gs}")
final_images = generate_final_samples(model, target_labels, guidance_scale=gs)
# Criar e salvar Grid
grid = make_grid(final_images, nrow=2)
save_image(grid, f"{OUTPUT_DIR}/amostras_gs_{gs}.png")
# Gerar arquivo de resultados finais
resumo_final = {
"projeto": "Modelos de Difusão - Pixel Art/Médico",
"data_conclusao": "2026-04-07",
"metricas_finais": final_metrics,
"status": "Aprovado para Produção"
}
with open(f"{OUTPUT_DIR}/relatorio_final.json", 'w') as f:
json.dump(resumo_final, f, indent=4)
print(f"Processo concluído.
