Vicente Rodríguez

April 23, 2019

Decision trees in Python

A decision tree is a supervised learning algorithm used to solve classification or regression problems. Decision trees model non-linear relationships quite well. In this tutorial we will use a decision tree to solve a classification problem.

We will use scikit-learn to create the decision tree. You can install it with pip, but I recommend using Google Colab since it already has everything we need installed.

Here is the notebook with all the code of this tutorial; you can upload this notebook to Google Colab to run the code.


A common problem with decision trees is overfitting. When you create a decision tree model with the default parameters, the model memorizes every detail of the dataset. For example, if we have records of people with diabetes, the model will learn these particular people very well, but it will have problems when we use records of new people. To avoid this behavior we can put constraints on the model; as a result, the model will learn to identify the disease in people in general and not only in the people in the dataset.

Decision tree nodes

decision tree

Decision tree parameters

These are the most common parameters that we tune in a decision tree model:

You can see the complete parameter list in the scikit-learn documentation.

Information Gain and Gini Index

As we have seen, scikit-learn can use two criteria to perform splits.

Information gain

This criterion computes the entropy with the following formula:

-(p log2(p)) -(q log2(q))

If we have data about 100 people where 57 have diabetes and 43 don’t:

p is the probability of success or the number of positive cases: 57/100 (0.57).

q is the probability of failure or the number of negative cases: 43/100 (0.43).

Then the formula will be:

-(0.57 log2(0.57)) -(0.43 log2(0.43))

The entropy is approximately 0.99

An entropy of 1 means that the node holds data from both classes in equal parts, 50% of each class. An entropy of 0 means that the node only holds data from one class. Therefore, the decision tree prefers the split whose resulting nodes have an entropy of 0 or close to 0.

For example, suppose we have the following features:
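The formula can be checked with a few lines of Python. This is only a minimal sketch: the helper name binary_entropy is made up for this illustration and is not part of scikit-learn.

```python
import math


def binary_entropy(p):
    """Entropy of a two-class node, where p is the share of the positive class."""
    q = 1.0 - p
    # By convention 0 * log2(0) is treated as 0, so a pure node has entropy 0.
    if p == 0.0 or q == 0.0:
        return 0.0
    return -(p * math.log2(p)) - (q * math.log2(q))


print(binary_entropy(0.57))  # 57 of 100 people have diabetes
print(binary_entropy(0.5))   # evenly mixed node
print(binary_entropy(1.0))   # pure node
```

The evenly mixed node gives the maximum entropy and the pure node gives the minimum, which matches the interpretation of 1 and 0 described above.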

Sex: Male, Female

Diabetes Pedigree Function: Positive, Negative.

The decision tree will compute the entropy of these features to identify the best split.

Entropy (sex)


We have 57 men: 35 have diabetes (p = 35/57 ≈ 0.61) and 22 don't (q = 22/57 ≈ 0.39):

-(0.61 log2(0.61)) -(0.39 log2(0.39))

The entropy is: 0.96


We have 43 women: 22 have diabetes (p = 22/43 ≈ 0.51) and 21 don't (q = 21/43 ≈ 0.49):

-(0.51 log2(0.51)) -(0.49 log2(0.49))

The entropy is: 0.99

Now we multiply each entropy by the proportion of men and women, and finally we sum the results:


(57/100) * 0.96 = 0.5472


(43/100) * 0.99 = 0.4257

Sex entropy: 0.5472 + 0.4257 = 0.9729

Entropy (Diabetes Pedigree Function)


We have 57 patients with a positive Diabetes Pedigree Function: 29 have diabetes and 28 don't.

-(0.51 log2(0.51)) -(0.49 log2(0.49))

The entropy is: 0.99


We have 43 patients with a negative Diabetes Pedigree Function: 28 have diabetes and 15 don't.

-(0.65 log2(0.65)) -(0.35 log2(0.35))

The entropy is: 0.93

We compute the same operations:


(57/100) * 0.99 = 0.5643


(43/100) * 0.93 = 0.3999

Diabetes Pedigree Function entropy: 0.9642

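We can see that the entropy for Diabetes Pedigree Function is the lowest, therefore the decision tree will split the data on Diabetes Pedigree Function. The information gain of a split is the entropy of the parent node minus this weighted entropy, so the split with the lowest weighted entropy is also the one with the highest information gain.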
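The comparison between the two features can be reproduced from the raw counts. The sketch below is an illustration only: the function names entropy_from_counts and weighted_entropy are hypothetical helpers, and the counts are the ones used in the worked example above.

```python
import math


def entropy_from_counts(positive, negative):
    """Entropy of a node given the number of positive and negative cases."""
    total = positive + negative
    result = 0.0
    for count in (positive, negative):
        if count > 0:  # an empty class contributes 0 to the entropy
            share = count / total
            result -= share * math.log2(share)
    return result


def weighted_entropy(groups):
    """Average entropy of the child nodes, weighted by the size of each node."""
    total = sum(pos + neg for pos, neg in groups)
    return sum((pos + neg) / total * entropy_from_counts(pos, neg)
               for pos, neg in groups)


# (with diabetes, without diabetes) for each value of the feature
sex_split = [(35, 22), (22, 21)]        # men, women
pedigree_split = [(29, 28), (28, 15)]   # positive, negative

parent = entropy_from_counts(57, 43)    # entropy before splitting
for name, split in [("sex", sex_split), ("pedigree", pedigree_split)]:
    child = weighted_entropy(split)
    print(name, "entropy:", round(child, 4), "gain:", round(parent - child, 4))
```

The code does not round the intermediate probabilities, so its results differ slightly from the hand-computed figures, but the ordering of the two features is the same.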

Gini index

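The Gini index uses a different formula. In the form shown here it measures purity, so higher values mean purer nodes; the Gini impurity that scikit-learn minimizes is 1 minus this value:

p^2 + q^2

p and q are the same probabilities as in the entropy formula.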

Gini (sex)


0.61^2 + 0.39^2

The result is: 0.5242


0.51^2 + 0.49^2

The result is 0.5002

We compute the same operations:


(57/100) * 0.5242 = 0.298794


(43/100) * 0.5002 = 0.215086

The result for sex split is: 0.51388

Gini (Diabetes Pedigree Function)


0.51^2 + 0.49^2

The result is: 0.5002


0.65^2 + 0.35^2

The result is: 0.545

We sum the results:


(57/100) * 0.5002 = 0.285114


(43/100) * 0.545 = 0.23435

The final result for Diabetes Pedigree Function split is: 0.519464

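The quantity p^2 + q^2 measures purity, so the best split is the one with the highest result, which is the same as the lowest Gini impurity, 1 - (p^2 + q^2). Diabetes Pedigree Function has the higher result (0.519464 against 0.51388), therefore the decision tree will split the data on Diabetes Pedigree Function, the same feature that was chosen with entropy.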
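scikit-learn's criterion="gini" works with the Gini impurity, 1 - (p^2 + q^2), and prefers the split with the lowest weighted impurity. Subtracting from 1 reverses the ordering of the p^2 + q^2 values, so the feature with the higher p^2 + q^2 ends up with the lower impurity. The following is a minimal sketch of that computation using the counts from the example; gini_impurity and weighted_gini are hypothetical helper names.

```python
def gini_impurity(positive, negative):
    """Gini impurity of a two-class node: 0 for a pure node, 0.5 for a 50/50 node."""
    total = positive + negative
    p = positive / total
    q = negative / total
    return 1.0 - (p ** 2 + q ** 2)


def weighted_gini(groups):
    """Average impurity of the child nodes, weighted by the size of each node."""
    total = sum(pos + neg for pos, neg in groups)
    return sum((pos + neg) / total * gini_impurity(pos, neg)
               for pos, neg in groups)


# (with diabetes, without diabetes) for each value of the feature
sex_split = [(35, 22), (22, 21)]        # men, women
pedigree_split = [(29, 28), (28, 15)]   # positive, negative

print("sex:", round(weighted_gini(sex_split), 4))
print("pedigree:", round(weighted_gini(pedigree_split), 4))
```

With these counts the Diabetes Pedigree Function split has the lower weighted impurity.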

The dataset

We will use a dataset that contains records of people with diabetes; you can download the dataset here. We are not going to dive too much into the data exploration step. However, this is a key point when you are building a machine learning model: most of the time the dataset is more important than the model we choose.

This dataset contains information about 768 people, therefore we have 768 rows and 9 columns:

The outcome column indicates if the person has diabetes (1) or not (0). This is a supervised learning problem, which means we need labels that indicate the class of each record. We can use the outcome column to obtain these labels; the remaining columns are the data that the decision tree model will use to learn.

Python code

I will use the pandas library to load and handle the dataset. This library is very useful when you have tabular data such as a CSV file.

from sklearn.model_selection import train_test_split

from sklearn.tree import DecisionTreeClassifier

from sklearn.metrics import accuracy_score

import numpy as np

import pandas as pd

We load the dataset with pandas:

dataset = pd.read_csv("diabetes.csv")

We can display the dataset as a table, much like in Excel:



We split the dataset into two variables X and y:

features = dataset.drop(["Outcome"], axis=1)

X = np.array(features)

y = np.array(dataset["Outcome"])

X keeps the data that the decision tree will use to learn and y the labels.

Training and validation set

When we are creating a machine learning model it’s important to have multiple data sets:

We use the training set to train the model. The validation set is used to check how the model works with different parameters. Finally, the test set is used to measure how the model performs; we use this last set only once, at the very end, when we know that the model is good enough. Each set must have different records so we can see how the model performs with data it has never seen.

Sometimes we don’t have enough data to split the dataset into three different sets. In that case we use the training set to train the model and the validation set both to test different parameters and to measure how the model performs.

We can split the dataset with the following scikit-learn function:

X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0, test_size=0.20)

We used 20% of the dataset to build the validation set.
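When there is enough data for a separate test set, one possible approach is to call train_test_split twice. The sketch below uses made-up placeholder arrays with the same number of rows as the diabetes dataset; the variable names and the 60/20/20 proportions are assumptions for illustration, not part of the tutorial's notebook.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder data with 768 rows and 8 feature columns (not the real dataset)
X_demo = np.arange(768 * 8).reshape(768, 8)
y_demo = np.arange(768) % 2

# First call: hold out 20% of all rows as the test set
X_rest, X_test, y_rest, y_test = train_test_split(
    X_demo, y_demo, test_size=0.20, random_state=0)

# Second call: take 25% of the remaining rows (20% of the total) for validation
X_tr, X_va, y_tr, y_va = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0)

print(len(X_tr), len(X_va), len(X_test))
```

Every row ends up in exactly one of the three sets, so the test set stays unseen until the final evaluation.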


We can create a decision tree model with the following code:

tree = DecisionTreeClassifier()

tree.fit(X_train, y_train)

We train the model with the method fit.

We can see the depth of the tree:


In this case the depth is 15, so we have a high probability of overfitting.


We will use the validation set to measure the model’s performance:
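One way to inspect the size of a fitted tree is with the get_depth and get_n_leaves methods of DecisionTreeClassifier. The sketch below is only an illustration on synthetic data generated with make_classification, so the numbers it prints will not match the depth reported for the diabetes dataset.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data, not the diabetes dataset
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=0)

# No constraints: the tree keeps splitting until its leaves are pure
demo_tree = DecisionTreeClassifier(random_state=0)
demo_tree.fit(X_demo, y_demo)

print("Depth:", demo_tree.get_depth())
print("Leaves:", demo_tree.get_n_leaves())
```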

validation_prediction = tree.predict(X_val)

training_prediction = tree.predict(X_train)

Now the validation_prediction and training_prediction variables hold the classes predicted by the model.

training_prediction[0] has the prediction for the first person that appears in the training dataset (X_train), whereas validation_prediction[0] has the prediction for the first person that appears in the validation dataset (X_val).

print('Accuracy training set: ', accuracy_score(y_true=y_train, y_pred=training_prediction))

print('Accuracy validation set: ', accuracy_score(y_true=y_val, y_pred=validation_prediction))

Accuracy training set:  1.0

Accuracy validation set:  0.7922077922077922

The training set has an accuracy of 100% and the validation set has an accuracy of 79%. This means we have an overfitting problem: the model is very good at predicting the records of the training set but not that good with the validation set.
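Accuracy is simply the share of predictions that match the true labels. As a rough sketch, it can be computed by hand as follows; the accuracy helper and the small label lists are invented for this illustration.

```python
def accuracy(y_true, y_pred):
    """Share of predictions that match the true labels."""
    correct = sum(1 for truth, guess in zip(y_true, y_pred) if truth == guess)
    return correct / len(y_true)


# Tiny made-up example: 4 of the 5 predictions are correct
y_true_demo = [1, 0, 1, 1, 0]
y_pred_demo = [1, 0, 0, 1, 0]
print(accuracy(y_true_demo, y_pred_demo))

# A validation set of 154 people with 122 correct predictions
print(122 / 154)
```

A 20% split of 768 rows gives 154 validation records, so the validation accuracy above corresponds to 122 correct predictions out of 154.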

Printing the decision tree

We can print the decision tree to obtain a visual representation.

We need to install a library called graphviz. On Linux you can use the following command:

apt-get install graphviz

You need to install the Python package as well:

!pip install graphviz

Now we can use it:

import graphviz 

from sklearn.tree import export_graphviz

feature_names = list(features.columns)

dot_data = export_graphviz(tree, out_file=None,

                         feature_names=feature_names,

                         filled=True, rounded=True)

graph = graphviz.Source(dot_data)


feature_names contains the name of each column; the function uses this variable to label each node with the column names.

You can see the tree in this link.
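If graphviz is not available, scikit-learn also offers export_text, which prints the tree as indented plain text. This is an alternative to the graphviz approach used in this tutorial, shown here as a small sketch on synthetic data; the feature names are placeholders invented for the example.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in data with four placeholder feature names
X_demo, y_demo = make_classification(n_samples=200, n_features=4, random_state=0)
demo_names = ["feature_a", "feature_b", "feature_c", "feature_d"]

# A shallow tree keeps the printed output short
small_tree = DecisionTreeClassifier(max_depth=2, random_state=0)
small_tree.fit(X_demo, y_demo)

# Each line shows a split condition or the class predicted at a leaf
text_view = export_text(small_tree, feature_names=demo_names)
print(text_view)
```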

The second model

We need a better model to avoid overfitting:

tree = DecisionTreeClassifier(min_samples_leaf=10, max_depth=8, min_samples_split=50)

This time we will use different parameters to put constraints on the model. We want a tree with a maximum depth of 8, each leaf node must have at least 10 examples, and a decision node must have at least 50 examples to perform a split.

We train this new model:

tree.fit(X_train, y_train)

validation_prediction = tree.predict(X_val)

training_prediction = tree.predict(X_train)

We measure the model’s performance:

print('Accuracy training set: ', accuracy_score(y_true=y_train, y_pred=training_prediction))

print('Accuracy validation set: ', accuracy_score(y_true=y_val, y_pred=validation_prediction))

Accuracy training set:  0.7964169381107492

Accuracy validation set:  0.8116883116883117

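Now the two accuracies are similar, about 80% on the training set and 81% on the validation set, so the model no longer overfits. We cannot always reach an accuracy of 100% in both sets. In this particular problem we could try different parameters and explore the dataset to improve the result.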
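Trying different parameters can be done with a simple loop that trains one tree per candidate value and keeps the one with the best validation accuracy. The sketch below is one possible way to do it, on synthetic data rather than the diabetes dataset; the candidate values for max_depth are arbitrary choices for the illustration.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data, not the diabetes dataset
X_demo, y_demo = make_classification(n_samples=500, n_features=8, random_state=0)
X_tr, X_va, y_tr, y_va = train_test_split(
    X_demo, y_demo, test_size=0.20, random_state=0)

# Train one tree per candidate depth and record its validation accuracy
results = {}
for depth in [2, 4, 6, 8, None]:  # None means no depth limit
    candidate = DecisionTreeClassifier(
        max_depth=depth, min_samples_leaf=10, random_state=0)
    candidate.fit(X_tr, y_tr)
    results[depth] = accuracy_score(y_va, candidate.predict(X_va))

best_depth = max(results, key=results.get)
print(results)
print("Best max_depth:", best_depth)
```

Because the validation set is used to pick the parameter, a separate test set is still the fairest way to report the final performance.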

Important things about decision trees