0
respostas

[Projeto] implementamos duas funções: train() para o fluxo de treinamento e test() para o fluxo de validação

============================================================

Função única para treino e validação

============================================================

def forward(loader, net, epoch, mode):

# --------------------------------------------------------
# 1) Define o modo da rede
# --------------------------------------------------------
# train() ativa comportamento de treino
# eval() ativa comportamento de validação/teste
if mode == "train":
    net.train()
else:
    net.eval()

# --------------------------------------------------------
# 2) Lista para armazenar as losses do epoch
# --------------------------------------------------------
epoch_loss = []

# --------------------------------------------------------
# 3) Percorre todos os batches do DataLoader
# --------------------------------------------------------
for batch in loader:
    dado, rotulo = batch

    # ----------------------------------------------------
    # 4) Envia os dados para o dispositivo (CPU ou GPU)
    # ----------------------------------------------------
    dado = dado.to(args['device'])
    rotulo = rotulo.to(args['device'])

    # ----------------------------------------------------
    # 5) Forward pass
    # ----------------------------------------------------
    pred = net(dado)

    # ----------------------------------------------------
    # 6) Calcula a loss
    # ----------------------------------------------------
    loss = criterion(pred, rotulo)

    # Salva a loss do batch para calcular média no final
    epoch_loss.append(loss.detach().cpu().item())

    # ----------------------------------------------------
    # 7) Se estiver em modo treino, faz backpropagation
    # ----------------------------------------------------
    if mode == "train":
        # Zera gradientes acumulados da iteração anterior
        optimizer.zero_grad()

        # Calcula gradientes
        loss.backward()

        # Atualiza os pesos
        optimizer.step()

# --------------------------------------------------------
# 8) Calcula estatísticas da época
# --------------------------------------------------------
epoch_loss = np.asarray(epoch_loss)

print(
    "Modo: %s | Época %d | Loss: %.4f +- %.4f"
    % (mode, epoch, epoch_loss.mean(), epoch_loss.std())
)