import torch
import torch.nn as nn
import os
import json
from datetime import datetime
from torch.cuda.amp import GradScaler, autocast
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
# ==========================================
# 1. Directory Setup (MLOps Layout)
# ==========================================
# Run identifier with one-second resolution, e.g. "20240131_235959".
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
# Root folder that holds every artifact produced by this run.
BASE_DIR = "experiments/run_" + TIMESTAMP

# Create the per-run subfolders up front; exist_ok makes a repeat harmless.
for _subdir in ("checkpoints", "samples", "logs"):
    os.makedirs(f"{BASE_DIR}/{_subdir}", exist_ok=True)
# ==========================================
# 2. Metrics and Performance Setup
# ==========================================
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# FID needs 299x299 images internally (TorchMetrics takes care of that).
fid_metric = FrechetInceptionDistance(feature=2048).to(device)
# Gradient scaler for mixed precision (float16). Enabled only on CUDA: on a
# CPU-only host a default GradScaler() just warns and disables itself, so
# passing the flag explicitly keeps the same behavior without the warning.
scaler = GradScaler(enabled=(device == "cuda"))
# ==========================================
# 3. Artifact-Saving Function
# ==========================================
def _csv_scalar(value):
    """Return *value* as a plain Python number if it is a one-element tensor.

    Anything else is returned unchanged, so plain ints/floats format exactly
    as before.
    """
    if torch.is_tensor(value) and value.numel() == 1:
        return value.item()
    return value


def save_artifacts(model, optimizer, epoch, metrics, config):
    """Persist the checkpoint, hyperparameters and metric row for one epoch.

    Writes three artifacts under ``BASE_DIR``:

    * ``checkpoints/model_epoch_{epoch}.pt`` -- a ``torch.save`` dict holding
      the epoch, the model and optimizer state dicts, and ``metrics``.
    * ``config.json`` -- ``config`` dumped as indented JSON (overwritten on
      every call).
    * ``logs/metrics.csv`` -- one appended ``epoch,loss,fid`` row.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Epoch identifier used in the checkpoint name and the CSV row.
        metrics: Mapping with at least the ``'loss'`` and ``'fid'`` keys.
            One-element tensors are written to the CSV as plain numbers
            rather than as their ``tensor(...)`` repr.
        config: JSON-serialisable hyperparameter mapping.

    Raises:
        KeyError: If ``metrics`` lacks ``'loss'`` or ``'fid'``.
        TypeError: If ``config`` is not JSON-serialisable.
    """
    checkpoint_dir = f"{BASE_DIR}/checkpoints"
    log_dir = f"{BASE_DIR}/logs"
    # Recreate the output folders if needed so a cleaned-up run directory
    # does not abort the save; a no-op when they already exist.
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Save the model checkpoint.
    checkpoint_path = f"{checkpoint_dir}/model_epoch_{epoch}.pt"
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
        },
        checkpoint_path,
    )

    # Save the hyperparameters (JSON).
    with open(f"{BASE_DIR}/config.json", 'w', encoding="utf-8") as f:
        json.dump(config, f, indent=4)

    # Append this epoch's metrics to the log.
    loss_value = _csv_scalar(metrics['loss'])
    fid_value = _csv_scalar(metrics['fid'])
    with open(f"{log_dir}/metrics.csv", 'a', encoding="utf-8") as f:
        f.write(f"{epoch},{loss_value},{fid_value}\n")
# ==========================================
# 4. Training Loop with Evaluation
# ==========================================
def train_diffusion_mlops(model, dataloader, config):
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
best_fid = float('inf')
for epoch in range(config['epochs']):
model.train()
epoch_loss = 0
pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
for batch in pbar:
images = batch[0].to(device) # Imagens reais
t = torch.randint(