1
resposta

Kernel Morrendo ao usar o plt.imshow(dados[0])

Pessoal,

Estou tentando executar o arquivo da aula 4 - Carregamento de Dados.ipynb e estou tendo um problema para executar as células que plotam os dados que vem do dataset do torchvision.

O código funciona perfeitamente bem até o momento em que ele precisa executar os seguintes comandos:

for i in range(3):
    dado, rotulo = train_set[i]

    plt.figure()
    plt.imshow(dado[0])
    plt.title('Rotulo: '+ str(rotulo))

O problema é que quando chega na linha do plt.imshow(dados[0]) o kernel morre toda vez, sem me dizer pq ele fez isso.

Mais alguém já passou por isso ou sabe um possível motivo para isso acontecer?

Obrigado desde já!

1 resposta

Oii Rafael, como você está?

Peço desculpas pela demora em dar um retorno.

Isso pode ser gerado por problemas de performance do computador, pois por padrão, para treinamento de imagens é necessário configurações um pouco mais robustas do que as "padrões". Bom espaço de memória e uma boa placa de vídeo são diferenciais.

Deixo abaixo o código completo até o ponto que você mencionou e recomendo que tente executar novamente pelo Google Colab, que possui internamente todo o suporte para trabalhar com processamento de imagens:


import torch
from torch import nn, optim

from torchvision import datasets
from torchvision import transforms

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import time

args = {
    'batch_size': 5,
    'num_workers': 4,
    'num_classes': 10,
    'lr': 1e-4,
    'weight_decay': 5e-4,
    'num_epochs': 30
}

if torch.cuda.is_available():
    args['device'] = torch.device('cuda')
else:
    args['device'] = torch.device('cpu')

print(args['device'])


train_set = datasets.MNIST('./',
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)

test_set = datasets.MNIST('./',
                          train=False,
                          transform=transforms.ToTensor(),
                          download=False)

print('Amostras de treino: ' + str(len(train_set)) +
      '\nAmostras de Teste:' + str(len(test_set)))


print(type(train_set))
print(type(train_set[0]))


for i in range(3):
    dado, rotulo = train_set[i]

    plt.figure()
    plt.imshow(dado[0])
    plt.title('Rotulo: ' + str(rotulo))

Qualquer dúvida fico à disposição.

Grande abraço e bons estudos!