A função plot_boundary
descrita na aula (e mostrada abaixo) faz referência a problemas "binários" (com dois labels no target, ou classe).
O fato é que ao descomentar as linhas relativas aos problemas binários (e, obviamente, comentar a parte para multi-class) na execução da célula dá o seguinte erro: NameError: name 'Variable' is not defined
como se, a referência Variable
não existisse ou algo do tipo.
É um problema da função mesmo ou de como eu preparei meus dados?
import numpy as np
def plot_boundary(X, y, model):
x_min, x_max = X[:, 0].min()-0.1, X[:, 0].max()+0.1
y_min, y_max = X[:, 1].min()-0.1, X[:, 1].max()+0.1
spacing = min(x_max - x_min, y_max - y_min) / 100
XX, YY = np.meshgrid(np.arange(x_min, x_max, spacing),
np.arange(y_min, y_max, spacing))
data = np.hstack((XX.ravel().reshape(-1,1),
YY.ravel().reshape(-1,1)))
# For binary problems
# db_prob = model(Variable(torch.Tensor(data)).cuda() )
# clf = np.where(db_prob.cpu().data < 0.5,0,1)
# For multi-class problems
db_prob = model(torch.Tensor(data).to(device) )
clf = np.argmax(db_prob.cpu().data.numpy(), axis=-1)
Z = clf.reshape(XX.shape)
plt.contourf(XX, YY, Z, cmap=plt.cm.brg, alpha=0.5)
plt.scatter(X[:,0], X[:,1], c=y, edgecolors='k', s=25, cmap=plt.cm.brg)
Alguma ajuda?