Error when calling plot3d
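import torch
import numpy as np
import matplotlib.pyplot as plt
X = torch.tensor([0., 1., 2.])  # float, to match nn.Linear's float32 weights
y = perceptron(X)
plot3d(perceptron)
plt.plot([X[0].item()], [X[1].item()], [X[2].item()], marker='^', markersize=20)  # input point on the 3D axes
plt.show()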
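Why the input should be a float tensor: assuming `perceptron` is an `nn.Linear(3, 1)` (its definition is not shown in the post), `torch.tensor([0,1, 2])` creates an int64 tensor while the layer's weights are float32, and `nn.Linear` does not promote dtypes, so the forward pass raises a `RuntimeError` before any plotting happens. A minimal sketch that reproduces this:

```python
import torch

# Assumption: the post's `perceptron` is a single layer with 3 inputs and 1 output
perceptron = torch.nn.Linear(3, 1)

X_int = torch.tensor([0, 1, 2])  # dtype torch.int64
try:
    perceptron(X_int)
    int_failed = False
except RuntimeError:
    # float32 weights cannot be multiplied by an int64 input
    int_failed = True

X = torch.tensor([0., 1., 2.])  # dtype torch.float32, matches the weights
y = perceptron(X)
print(int_failed, y.shape)
```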
And the function:
def plot3d(perceptron):
    # weight has shape (1, 3) for a layer with 3 inputs and 1 output
    w1, w2, w3 = perceptron.weight.detach().numpy()[0]
    b = perceptron.bias.detach().item()
    X1 = np.linspace(-1, 1, 10)
    X2 = np.linspace(-1, 1, 10)
    X1, X2 = np.meshgrid(X1, X2)
    # Decision plane: w1*x1 + w2*x2 + w3*x3 + b = 0  =>  x3 = -(w1*x1 + w2*x2 + b) / w3
    X3 = -(w1 * X1 + w2 * X2 + b) / w3
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(projection='3d')
    ax.view_init(azim=180)
    ax.plot_surface(X1, X2, X3, cmap='plasma')
    return ax
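A minimal self-contained sketch of the whole flow, assuming `perceptron` is `nn.Linear(3, 1)`; `decision_plane` and `plot_decision_plane` are hypothetical helper names, not from the post. Since `nn.Linear` computes w1*x1 + w2*x2 + w3*x3 + b, solving for x3 puts the bias inside the negated sum. The input point is drawn with `ax.scatter` on the same 3D axes rather than through `plt.plot`.

```python
import numpy as np
import torch
import matplotlib.pyplot as plt


def decision_plane(layer, n=10, lim=1.0):
    """Grid (x1, x2, x3) on the plane w1*x1 + w2*x2 + w3*x3 + b = 0.

    Hypothetical helper; assumes layer is nn.Linear(3, 1) with a nonzero
    third weight, so the plane can be solved for x3.
    """
    w1, w2, w3 = layer.weight.detach().numpy()[0]
    b = layer.bias.detach().item()
    x1, x2 = np.meshgrid(np.linspace(-lim, lim, n), np.linspace(-lim, lim, n))
    x3 = -(w1 * x1 + w2 * x2 + b) / w3
    return x1, x2, x3


def plot_decision_plane(layer, point):
    """Draw the decision plane and one input point on a single 3D axes."""
    x1, x2, x3 = decision_plane(layer)
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(projection='3d')
    ax.view_init(azim=180)
    ax.plot_surface(x1, x2, x3, cmap='plasma', alpha=0.7)
    p = point.detach().numpy()
    # scatter on the same Axes3D object, no reliance on the "current axes"
    ax.scatter([p[0]], [p[1]], [p[2]], marker='^', s=200, color='k')
    return fig, ax


torch.manual_seed(0)
perceptron = torch.nn.Linear(3, 1)  # assumption: 3 inputs, 1 output
X = torch.tensor([0., 1., 2.])      # float input
fig, ax = plot_decision_plane(perceptron, X)
# call plt.show() to display the figure outside a notebook
```

Evaluating `perceptron` on the grid points returned by `decision_plane` should give values close to zero, which is a quick way to confirm the sign of the bias term.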
If anyone knows how to fix this, please let me know.