Hi everyone,
I'm having trouble using DecisionTreeClassifier.
When I run my code, fitting this model raises the following exception:
>>> classifier_model = tree.DecisionTreeClassifier()
>>> classifier_model.fit(train, train_rates)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "venv/lib/python3.4/site-packages/sklearn/tree/tree.py", line 790, in fit
    X_idx_sorted=X_idx_sorted)
  File "venv/lib/python3.4/site-packages/sklearn/tree/tree.py", line 140, in fit
    check_classification_targets(y)
  File "venv/lib/python3.4/site-packages/sklearn/utils/multiclass.py", line 172, in check_classification_targets
    raise ValueError("Unknown label type: %r" % y_type)
ValueError: Unknown label type: 'continuous'
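The check that fails can be reproduced in isolation with `type_of_target`, which lives in `sklearn.utils.multiclass`, the module named in the last frame of the traceback. This is a minimal sketch, assuming the target column holds non-integer floats; the sample values below are made up for illustration:

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical ratings: floats with fractional parts, like a score column
ratings = np.array([7.5, 6.1, 8.0, 5.4])
print(type_of_target(ratings))  # continuous

# Rounded to whole numbers, the same values form a discrete set of labels
print(type_of_target(np.round(ratings).astype(int)))  # multiclass
```

DecisionTreeClassifier rejects any `y` that this helper reports as `'continuous'`, whereas DecisionTreeRegressor is designed for exactly that kind of target.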
Has anyone run into this? I don't understand what happened and I can't fix it. Here is the code I ran:
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.metrics import accuracy_score
movies = pd.read_csv('data/movies_multilinear_regression.csv')
total_columns = len(movies.columns)
# Features: from the third column up to (but not including) the last one
movies_independent_variable = movies[movies.columns[2:total_columns - 1]]
# Target: the last column (continuous values)
movies_result_variable = movies[movies.columns[-1:]]
train, test, train_rates, test_rates = train_test_split(
    movies_independent_variable, movies_result_variable)
train_columns_count = len(train.columns)
train = np.array(train).reshape(len(train), train_columns_count)
test = np.array(test).reshape(len(test), train_columns_count)
# Flatten the single-column targets into 1-D arrays
train_rates = train_rates.values.ravel()
test_rates = test_rates.values.ravel()
regression_model = tree.DecisionTreeRegressor(max_depth=5)
regression_model.fit(train, train_rates)
print(regression_model.score(train, train_rates))
print(regression_model.score(test, test_rates))
classifier_model = tree.DecisionTreeClassifier()
classifier_model.fit(train, train_rates)  # line that raises the error
print(classifier_model.score(train, train_rates))
print(classifier_model.score(test, test_rates))
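A classifier expects discrete class labels, so if predicting a category is really the goal, one possible workaround is to bin the continuous target before fitting. If the goal is to predict the value itself, DecisionTreeRegressor (already used in the script) is the appropriate model. The sketch below uses synthetic data in place of the CSV and assumes a 0 to 10 rating scale; the bin edges and the labels `low`/`medium`/`high` are hypothetical choices, not part of the original code:

```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)

# Synthetic stand-in for the movie features and a continuous 0-10 target
features = rng.rand(200, 3)
ratings = np.clip(features.sum(axis=1) * 10 / 3 + rng.normal(0, 0.5, 200), 0, 10)

# Hypothetical bins: turn the continuous target into three discrete classes
rating_classes = pd.cut(ratings, bins=[0, 4, 7, 10],
                        labels=['low', 'medium', 'high'], include_lowest=True)
labels = np.asarray(rating_classes)

train, test, train_labels, test_labels = train_test_split(
    features, labels, random_state=0)

# With string class labels, fit() no longer raises "Unknown label type"
classifier_model = tree.DecisionTreeClassifier(max_depth=5)
classifier_model.fit(train, train_labels)
print(classifier_model.score(test, test_labels))
```

Binning discards information, so the bin edges should reflect categories that actually make sense for the data.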
I'm using the following library versions:
python 3.4.3
cycler==0.10.0
matplotlib==2.1.2
numpy==1.14.0
pandas==0.22.0
pyparsing==2.2.0
python-dateutil==2.6.1
pytz==2017.3
scikit-learn==0.19.1
scipy==1.0.0
six==1.11.0