1
resposta

[Projeto] 08 Faça como eu fiz: treine modelo Diffusão com todos os artefatos para MLOps

import torch
import torch.nn as nn
import os
import json
from datetime import datetime
from torch.cuda.amp import GradScaler, autocast
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

==========================================

1. Configuração de Diretórios (Estrutura MLOps)

==========================================

TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
BASE_DIR = f"experiments/run_{TIMESTAMP}"
os.makedirs(f"{BASE_DIR}/checkpoints", exist_ok=True)
os.makedirs(f"{BASE_DIR}/samples", exist_ok=True)
os.makedirs(f"{BASE_DIR}/logs", exist_ok=True)

==========================================

2. Setup de Métricas e Performance

==========================================

device = "cuda" if torch.cuda.is_available() else "cpu"

FID requer imagens em 299x299 internamente (o TorchMetrics cuida disso)

fid_metric = FrechetInceptionDistance(feature=2048).to(device)
scaler = GradScaler() # Para Mixed Precision (Float16)

==========================================

3. Função de Salvamento de Artefatos

==========================================

def save_artifacts(model, optimizer, epoch, metrics, config):
# Salvar Checkpoint do Modelo
checkpoint_path = f"{BASE_DIR}/checkpoints/model_epoch_{epoch}.pt"
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'metrics': metrics
}, checkpoint_path)

# Salvar Hiperparâmetros (JSON)
with open(f"{BASE_DIR}/config.json", 'w') as f:
    json.dump(config, f, indent=4)
    
# Salvar Logs de Métricas
with open(f"{BASE_DIR}/logs/metrics.csv", 'a') as f:
    f.write(f"{epoch},{metrics['loss']},{metrics['fid']}\n")

==========================================

4. Loop de Treinamento com Avaliação

==========================================

def train_diffusion_mlops(model, dataloader, config):
optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
best_fid = float('inf')

for epoch in range(config['epochs']):
    model.train()
    epoch_loss = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    
    for batch in pbar:
        images = batch[0].to(device) # Imagens reais
        t = torch.randint(
1 resposta

Olá, Moacir! Como vai?

Parabéns pela resolução da atividade!

Observei que você explorou o setup de diretórios para organizar experimentos em PyTorch, utilizou muito bem o registro de métricas com FID e Mixed Precision para otimizar desempenho e ainda compreendeu a importância da função de salvamento de artefatos para garantir reprodutibilidade e rastreabilidade em treinamentos de modelos.

Continue postando as suas soluções, com certeza isso ajudará outros estudantes e tem grande relevância para o fórum.

Fico à disposição! E se precisar, conte sempre com o apoio do fórum.

Abraço e bons estudos!

AluraConte com o apoio da comunidade Alura na sua jornada. Abraços e bons estudos!