Oii Rafael, como você está?
Peço desculpas pela demora em dar um retorno.
Isso pode ser gerado por problemas de performance do computador, pois por padrão, para treinamento de imagens é necessário configurações um pouco mais robustas do que as "padrões". Bom espaço de memória e uma boa placa de vídeo são diferenciais.
Deixo abaixo o código completo até o ponto que você mencionou e recomendo que tente executar novamente pelo Google Colab, que possui internamente todo o suporte para trabalhar com processamento de imagens:
import torch
from torch import nn, optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time
args = {
'batch_size': 5,
'num_workers': 4,
'num_classes': 10,
'lr': 1e-4,
'weight_decay': 5e-4,
'num_epochs': 30
}
if torch.cuda.is_available():
args['device'] = torch.device('cuda')
else:
args['device'] = torch.device('cpu')
print(args['device'])
train_set = datasets.MNIST('./',
train=True,
transform=transforms.ToTensor(),
download=True)
test_set = datasets.MNIST('./',
train=False,
transform=transforms.ToTensor(),
download=False)
print('Amostras de treino: ' + str(len(train_set)) +
'\nAmostras de Teste:' + str(len(test_set)))
print(type(train_set))
print(type(train_set[0]))
for i in range(3):
dado, rotulo = train_set[i]
plt.figure()
plt.imshow(dado[0])
plt.title('Rotulo: ' + str(rotulo))
Qualquer dúvida fico à disposição.
Grande abraço e bons estudos!