4
respostas

Problema da função plot_boundary para problemas binários

A função plot_boundary descrita na aula (e mostrada abaixo) faz referência a problemas "binários" (com dois labels no target, ou classe).

O fato é que ao descomentar as linhas relativas aos problemas binários (e, obviamente, comentar a parte para multi-class) na execução da célula dá o seguinte erro: NameError: name 'Variable' is not defined como se, a referência Variable não existisse ou algo do tipo.

É um problema da função mesmo ou de como eu preparei meus dados?

import numpy as np 

def plot_boundary(X, y, model):
  x_min, x_max = X[:, 0].min()-0.1, X[:, 0].max()+0.1
  y_min, y_max = X[:, 1].min()-0.1, X[:, 1].max()+0.1

  spacing = min(x_max - x_min, y_max - y_min) / 100

  XX, YY = np.meshgrid(np.arange(x_min, x_max, spacing),
                       np.arange(y_min, y_max, spacing))

  data = np.hstack((XX.ravel().reshape(-1,1), 
                    YY.ravel().reshape(-1,1)))

  # For binary problems
  # db_prob = model(Variable(torch.Tensor(data)).cuda() )
  # clf = np.where(db_prob.cpu().data < 0.5,0,1)

  # For multi-class problems
  db_prob = model(torch.Tensor(data).to(device) )
  clf = np.argmax(db_prob.cpu().data.numpy(), axis=-1)

  Z = clf.reshape(XX.shape)

  plt.contourf(XX, YY, Z, cmap=plt.cm.brg, alpha=0.5)
  plt.scatter(X[:,0], X[:,1], c=y, edgecolors='k', s=25, cmap=plt.cm.brg)

Alguma ajuda?

4 respostas

Olá!

Isso depende da versão do PyTorch, as mais recentes não tem mais o objeto tipo Variable. Pode tentar sem ele e ver se funciona?

Bom, eu retirei a menção à Variable e não deu certo. O erro mudou: RuntimeError: mat1 dim 1 must match mat2 dim 0

Olá, joia?

Testei aqui, só removendo o Variable e funcionou, então é possível que o problema esteja em outro ponto do código. Lembre-se que para problemas binários, a saída da sua rede tem que ser de tamanho 1, então a rede ficaria dessa forma:

net = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 1), ##########
    nn.Sigmoid() ##########
)

Assumindo aí que a saída da rede é um valor entre 0 e 1 (garantido pela ativação sigmóide) e que vamos interpretar valores < 0.5 como uma classe e >0.5 como outra classe.

Confere se o restante do código está correto.

E uma coisa que ajuda muito a mapear o problema é postar o erro completo, incluindo a linha onde o problema aconteceu.

Uma boa semana pra vc!

O código é esse:

import numpy as np 

def plot_boundary(X, y, model):
  x_min, x_max = X[:, 0].min()-0.1, X[:, 0].max()+0.1
  y_min, y_max = X[:, 1].min()-0.1, X[:, 1].max()+0.1

  spacing = min(x_max - x_min, y_max - y_min) / 100

  XX, YY = np.meshgrid(np.arange(x_min, x_max, spacing),
                       np.arange(y_min, y_max, spacing))

  data = np.hstack((XX.ravel().reshape(-1,1), 
                    YY.ravel().reshape(-1,1)))

  # For binary problems
  #db_prob = model(Variable(torch.Tensor(data)).cuda() )
  #db_prob = model(torch.Tensor(data).to(device) )
  db_prob = model(torch.Tensor(data).cuda() )
  clf = np.where(db_prob.cpu().data < 0.5,0,1)

  # For multi-class problems
  #db_prob = model(torch.Tensor(data).to(device) )
  #clf = np.argmax(db_prob.cpu().data.numpy(), axis=-1)

  Z = clf.reshape(XX.shape)

  plt.contourf(XX, YY, Z, cmap=plt.cm.brg, alpha=0.5)
  plt.scatter(X[:,0], X[:,1], c=y, edgecolors='k', s=25, cmap=plt.cm.brg)

Quer mergulhar em tecnologia e aprendizagem?

Receba a newsletter que o nosso CEO escreve pessoalmente, com insights do mercado de trabalho, ciência e desenvolvimento de software