Solucionado (ver solução)
Solucionado
(ver solução)
1
resposta

[Dúvida] É possível salvar os parâmetros de uma rede?

Olá, gostaria de saber se é possível salvar os dados de output ou loss durante o treinamento de uma rede para que quando for treinada novamente não comece do "zero". Caso sim, como poderia fazer?

Acretido que poderia salvar durante o loop e antes de começar o treinamento carregar essas informações da loss na primeira interação

 for epoch in range(args['num_epochs']):
        epoch_loss = []
        for batch in train_loader:
            dado, rotulo = batch
            dado = dado.to(args['device'])
            rotulo = rotulo.to(args['device'])

            # Forward
            output = my_net(dado)
            loss = criterion(output, rotulo)
            epoch_loss.append(loss.cpu().data)

            # Backward
            loss.backward()
            optimizer.step()

        epoch_loss = np.asarray(epoch_loss)
        # Salvar os dados da variável epoch_loss em banco 
        print(f'Epoca: {epoch+1} | Loss: {epoch_loss.mean():.4f} + / - {epoch_loss.std():.4f}')
1 resposta
solução!

Olá Raphael, tudo bem?

Sim, é possível salvar os parâmetros de uma rede neural, assim como os dados de loss, durante o treinamento. Essa prática de retomar o treinamento de onde parou, chamamos de checkpointing.

Você pode salvar tanto os pesos do modelo quanto o estado do otimizador com torch.save(), e depois carregá-los com torch.load().

Você pode fazer assim:

# Salvando o estado do modelo e do otimizador
torch.save({
    'epoch': epoch,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': epoch_loss.mean(),
}, 'checkpoint.pth')

Para carregar o estado salvo:

checkpoint = torch.load('checkpoint.pth', weights_only=False)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

Se quiser retomar apenas quando o checkpoint existir, você pode incluir uma verificação:

import os

if os.path.exists('checkpoint.pth'):
    checkpoint = torch.load('checkpoint.pth', weights_only=False)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f'Retomando do epoch {start_epoch} com loss anterior de {checkpoint["loss"]:.4f}')
else:
    start_epoch = 0
    print('Começando do zero.')

Com isso, você consegue treinar por etapas, interromper e continuar o treinamento com segurança.

Espero ter ajudado.

Qualquer dúvida, compartilhe no fórum.

Abraços e bons estudos!

Caso este post tenha lhe ajudado, por favor, marcar como solucionado ✓. Bons Estudos!