============================================================
Função única para treino e validação
============================================================
def forward(loader, net, epoch, mode):
# --------------------------------------------------------
# 1) Define o modo da rede
# --------------------------------------------------------
# train() ativa comportamento de treino
# eval() ativa comportamento de validação/teste
if mode == "train":
net.train()
else:
net.eval()
# --------------------------------------------------------
# 2) Lista para armazenar as losses do epoch
# --------------------------------------------------------
epoch_loss = []
# --------------------------------------------------------
# 3) Percorre todos os batches do DataLoader
# --------------------------------------------------------
for batch in loader:
dado, rotulo = batch
# ----------------------------------------------------
# 4) Envia os dados para o dispositivo (CPU ou GPU)
# ----------------------------------------------------
dado = dado.to(args['device'])
rotulo = rotulo.to(args['device'])
# ----------------------------------------------------
# 5) Forward pass
# ----------------------------------------------------
pred = net(dado)
# ----------------------------------------------------
# 6) Calcula a loss
# ----------------------------------------------------
loss = criterion(pred, rotulo)
# Salva a loss do batch para calcular média no final
epoch_loss.append(loss.detach().cpu().item())
# ----------------------------------------------------
# 7) Se estiver em modo treino, faz backpropagation
# ----------------------------------------------------
if mode == "train":
# Zera gradientes acumulados da iteração anterior
optimizer.zero_grad()
# Calcula gradientes
loss.backward()
# Atualiza os pesos
optimizer.step()
# --------------------------------------------------------
# 8) Calcula estatísticas da época
# --------------------------------------------------------
epoch_loss = np.asarray(epoch_loss)
print(
"Modo: %s | Época %d | Loss: %.4f +- %.4f"
% (mode, epoch, epoch_loss.mean(), epoch_loss.std())
)