April 23, 2019

Decision trees in Python

A decision tree is a supervised learning algorithm used to solve classification or regression problems. Decision trees capture non-linear relationships quite well. In this tutorial we will use a decision tree to solve a classification problem.

We will use scikit-learn to create the decision tree. You can install it with pip, but I recommend using Google Colab since it already has everything we need installed.

Here is the notebook with all the code of this tutorial. You can upload this notebook to Google Colab to run the code.

Overfitting

A common problem with decision trees is overfitting. When you create a decision tree model with the default parameters, the model will learn every single detail of the dataset. For example, if we have records of people with diabetes, the model will fit those specific people closely but will have problems when we give it records of new people. To avoid this behavior we can put constraints on the model; as a result the model will learn to identify the disease in people in general and not only in the people in the dataset.

Decision tree nodes

[Figure: decision tree]

  • Main or Root Node: This node represents the entire population. This node is a decision node as well.
  • Decision Node: This node divides the population into two or more nodes.
  • Leaf Node: These are the last nodes that don't split further.
  • Depth: The depth is the length of the longest path from the root node to a leaf node, for example the decision tree above has a depth of 4.

Decision tree parameters

These are the most common parameters that we tune in a decision tree model:

  • max_depth: The maximum depth that the decision tree can reach. A lower value can help avoid overfitting, but if it is too low it can have the opposite effect (underfitting).

  • min_samples_leaf: The minimum number of examples that a leaf node must have.

  • min_samples_split: The minimum number of examples that a decision node must have to perform a split. If the node does not have enough examples, it remains a leaf node.

  • criterion: The function that the decision tree will use to evaluate the splits. We can use the Gini index ("gini") or information gain ("entropy").

You can see the complete parameter list in the scikit-learn documentation.
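As a minimal sketch of how these parameters are passed to scikit-learn, the snippet below builds a constrained tree. The synthetic dataset (from make_classification) and the specific parameter values are assumptions chosen only for illustration, not recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data used only for illustration (not the diabetes dataset).
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=0)

# The parameter values below are arbitrary examples.
demo_tree = DecisionTreeClassifier(
    max_depth=4,           # the tree cannot grow deeper than 4 levels
    min_samples_leaf=5,    # every leaf must hold at least 5 examples
    min_samples_split=20,  # a node needs at least 20 examples to be split
    criterion="entropy",   # "gini" is the default
    random_state=0,
)
demo_tree.fit(X_demo, y_demo)

print("Depth:", demo_tree.get_depth())
print("Leaves:", demo_tree.get_n_leaves())
```

The fitted tree should respect the constraints: its depth cannot exceed max_depth and no leaf can hold fewer than min_samples_leaf examples.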

Information Gain and Gini Index

As we have seen, scikit-learn can use two criteria to perform splits.

Information gain

This criterion computes the entropy with the following formula:

-(p log2(p)) -(q log2(q))

If we have data about 100 people where 57 have diabetes and 43 don’t:

p is the proportion of positive cases (people with diabetes): 57/100 (0.57).

q is the proportion of negative cases (people without diabetes): 43/100 (0.43).

Then the formula will be:

-(0.57 log2(0.57)) -(0.43 log2(0.43))

The entropy is approximately 0.99

If the result is 1, the node has data from both classes in equal parts (50% of each class). If the result is 0, the node only has data from one class. Therefore, the decision tree looks for the split whose resulting nodes have an entropy of 0 or as close to 0 as possible.
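The entropy formula above can be sketched in a few lines of Python. The function name binary_entropy is a hypothetical helper introduced here for illustration; it is not part of scikit-learn.

```python
from math import log2

def binary_entropy(p):
    """Entropy (in bits) of a two-class node, where p is the proportion of positive cases."""
    q = 1 - p
    # By convention 0 * log2(0) counts as 0, so empty classes are skipped.
    return sum(-x * log2(x) for x in (p, q) if x > 0)

print(round(binary_entropy(0.57), 4))  # 57 of 100 people have diabetes
print(binary_entropy(0.5))             # evenly mixed node
print(binary_entropy(1.0))             # node with a single class
```

The evenly mixed node gives the maximum entropy and the single-class node gives the minimum, which matches the interpretation of the values 1 and 0 described above.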

For example if we have the following data:

Sex: Male, Female

  • Male: 35 have diabetes, 22 don’t.
  • Female: 22 have diabetes, 21 don’t.

Diabetes Pedigree Function: Positive, Negative.

  • Positive: 29 have diabetes, 28 don’t.
  • Negative: 28 have diabetes, 15 don’t.

The decision tree will compute the entropy of these features to identify the best split.

Entropy (sex)

Male:

We have 57 men, 35 have diabetes (p), 22 don’t (q)

  • 35/57 = 0.61
  • 22/57 = 0.39
-(0.61 log2(0.61)) -(0.39 log2(0.39))

The entropy is: 0.96

Female:

We have 43 women, 22 have diabetes, 21 don’t:

  • 22/43 = 0.51
  • 21/43 = 0.49
-(0.51 log2(0.51)) -(0.49 log2(0.49))

The entropy is: 1.00

Now we multiply each entropy by the proportion of men and women, and finally we sum the results.

Men:

(57/100) * 0.96 = 0.5472

Women:

(43/100) * 1.00 = 0.4300

Sex entropy: 0.5472 + 0.4300 = 0.9772 (about 0.98)

Entropy (Diabetes Pedigree Function)

Positive:

57 patients have a positive Diabetes Pedigree Function: 29 have diabetes and 28 don't.

  • 29/57 = 0.51
  • 28/57 = 0.49
-(0.51 log2(0.51)) -(0.49 log2(0.49))

The entropy is: 1.00

Negative:

43 patients have a negative Diabetes Pedigree Function: 28 have diabetes and 15 don't.

  • 28/43 = 0.65
  • 15/43 = 0.35
-(0.65 log2(0.65)) -(0.35 log2(0.35))

The entropy is: 0.93

We apply the same weighting:

Positive:

(57/100) * 1.00 = 0.5700

Negative:

(43/100) * 0.93 = 0.3999

Diabetes Pedigree Function entropy: 0.5700 + 0.3999 = 0.9699 (about 0.97)

We can see that the weighted entropy for Diabetes Pedigree Function is the lowest, which means this split gives the highest information gain. Therefore the decision tree will split the data on Diabetes Pedigree Function.
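The comparison of the two candidate splits can be reproduced with a short sketch. The helper names entropy and weighted_entropy are hypothetical, and because the code works with unrounded proportions its results differ slightly from the hand calculation above.

```python
from math import log2

def entropy(counts):
    """Entropy (in bits) of a node, given the number of examples in each class."""
    total = sum(counts)
    return sum(-(c / total) * log2(c / total) for c in counts if c > 0)

def weighted_entropy(groups):
    """Average entropy of the child nodes, weighted by the size of each child."""
    total = sum(sum(g) for g in groups)
    return sum(sum(g) / total * entropy(g) for g in groups)

# (with diabetes, without diabetes) for each child node of the split
sex_split = [(35, 22), (22, 21)]       # male, female
pedigree_split = [(29, 28), (28, 15)]  # positive, negative

parent = entropy((57, 43))  # entropy before splitting
for name, split in [("Sex", sex_split), ("Diabetes Pedigree Function", pedigree_split)]:
    child = weighted_entropy(split)
    # Information gain is the entropy removed by the split.
    print(f"{name}: weighted entropy={child:.4f}, information gain={parent - child:.4f}")
```

The split with the lower weighted entropy is the one with the higher information gain, so ranking by either quantity selects the same feature.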

Gini index

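The Gini index uses a different formula:

 p^2 + q^2

p and q are the same proportions as in the entropy formula. This sum measures how pure a node is: it is 1 when the node contains only one class and 0.5 when both classes are equally represented. The Gini impurity that scikit-learn minimizes is 1 - (p^2 + q^2).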

Gini (sex)

Male:

  • 35/57 = 0.61
  • 22/57 = 0.39
0.61^2 + 0.39^2

The result is: 0.5242

Female:

  • 22/43 = 0.51
  • 21/43 = 0.49
0.51^2 + 0.49^2

The result is 0.5002

We compute the same operations:

Men:

(57/100) * 0.5242 = 0.298794

Women:

(43/100) * 0.5002 = 0.215086

The result for sex split is: 0.51388

Gini (Diabetes Pedigree Function)

Positive:

  • 29/57 = 0.51
  • 28/57 = 0.49
0.51^2 + 0.49^2

The result is: 0.5002

Negative:

  • 28/43 = 0.65
  • 15/43 = 0.35
0.65^2 + 0.35^2

The result is: 0.545

We sum the results:

Positive:

(57/100) * 0.5002 = 0.285114

Negative:

(43/100) * 0.545 = 0.23435

The final result for Diabetes Pedigree Function split is: 0.519464

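Because p^2 + q^2 measures purity, a higher value is better. Diabetes Pedigree Function has the highest score (0.519464 versus 0.51388), which corresponds to the lowest Gini impurity (1 - 0.519464 = 0.480536 versus 1 - 0.51388 = 0.48612), therefore the decision tree will split the data on Diabetes Pedigree Function.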
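The Gini calculation can be checked with a small sketch. The helper names gini_score and weighted_gini_score are hypothetical. The code prints both the p^2 + q^2 score (higher means purer nodes) and the corresponding impurity, 1 minus the score (lower means purer nodes), using unrounded proportions, so the digits differ slightly from the hand calculation.

```python
def gini_score(counts):
    """Sum of squared class proportions (p^2 + q^2); 1.0 means a pure node."""
    total = sum(counts)
    return sum((c / total) ** 2 for c in counts)

def weighted_gini_score(groups):
    """Average score of the child nodes, weighted by the size of each child."""
    total = sum(sum(g) for g in groups)
    return sum(sum(g) / total * gini_score(g) for g in groups)

# (with diabetes, without diabetes) for each child node of the split
sex_split = [(35, 22), (22, 21)]       # male, female
pedigree_split = [(29, 28), (28, 15)]  # positive, negative

for name, split in [("Sex", sex_split), ("Diabetes Pedigree Function", pedigree_split)]:
    score = weighted_gini_score(split)
    print(f"{name}: score={score:.4f}, impurity={1 - score:.4f}")
```

The feature with the higher score is also the one with the lower impurity, so both views point to the same split.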

The dataset

We will use a dataset that contains records of people with diabetes; you can download the dataset here. We are not going to dive too much into the data exploration step. However, this is a key point when you are building a machine learning model: most of the time the dataset is more important than the model we choose.

This dataset contains information about 768 people, therefore we have 768 rows and 9 columns:

  • Pregnancies
  • Glucose
  • BloodPressure
  • SkinThickness
  • Insulin
  • BMI
  • DiabetesPedigreeFunction
  • Age
  • Outcome

The Outcome column indicates if the person has diabetes (1) or not (0). This is a supervised learning problem, which means we need labels that indicate the class of each record. We can use the Outcome column to obtain these labels; the remaining columns are the data that the decision tree model will use to learn.

Python code

I will use the pandas library to load and handle the dataset. This library is very useful when you have tabular data such as CSV files.

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import numpy as np
import pandas as pd

We load the dataset with pandas:

dataset = pd.read_csv("diabetes.csv")

We can print the dataset as if we were using Excel:

dataset.head()
dataset.shape

We split the dataset into two variables X and y:

features = dataset.drop(["Outcome"], axis=1)
X = np.array(features)
y = np.array(dataset["Outcome"])

X holds the data that the decision tree will use to learn, and y holds the labels.

Training and validation set

When we are creating a machine learning model it’s important to have multiple data sets:

  • Training set
  • Validation set
  • Test set

We use the training set to train the model. The validation set is used to check how the model works with different parameters. Finally, the test set is used to measure how the model performs; we use this last set only once, at the very end, when we know that the model is good enough. Each set must have different records so that we can see how the model performs with data it has never seen.

Sometimes we don’t have enough data to split the dataset into three different sets. In that case we use the training set to train the model, and the validation set both to test different parameters and to measure how the model performs.

We can split the dataset with the following scikit-learn function:

X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0, test_size=0.20)

We used 20% of the dataset to build the validation set.
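When there is enough data for all three sets, one possible approach is to call train_test_split twice. The sketch below uses synthetic data and hypothetical variable names with a _demo suffix so that it does not overwrite the variables of this tutorial. The 60/20/20 proportions are an assumption for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for X and y: 100 rows, 3 features, binary labels.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3))
y_demo = rng.integers(0, 2, size=100)

# First hold out 20% of the rows as the test set.
X_rest, X_test_demo, y_rest, y_test_demo = train_test_split(
    X_demo, y_demo, test_size=0.20, random_state=0
)
# Then take 25% of the remaining 80% (20% of the total) as the validation set.
X_train_demo, X_val_demo, y_train_demo, y_val_demo = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0
)

print(len(X_train_demo), len(X_val_demo), len(X_test_demo))
```

Fixing random_state makes the split reproducible, so the same rows end up in each set on every run.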

Model

We can create a decision tree model with the following code:

tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)

We train the model with the fit method.

We can see the depth of the tree:

tree.tree_.max_depth

In this case the depth is 15, which suggests a high probability of overfitting.

Accuracy

We will use the validation set to measure the model’s performance:

validation_prediction = tree.predict(X_val)
training_prediction = tree.predict(X_train)

Now the validation_prediction and training_prediction variables hold the classes predicted by the model.

training_prediction[0] has the prediction for the first person that appears in the training dataset (X_train), whereas validation_prediction[0] has the prediction for the first person that appears in the validation dataset (X_val).

print('Accuracy training set: ', accuracy_score(y_true=y_train, y_pred=training_prediction))
print('Accuracy validation set: ', accuracy_score(y_true=y_val, y_pred=validation_prediction))

Accuracy training set:  1.0
Accuracy validation set:  0.7922077922077922

The training set has an accuracy of 100% and the validation set has an accuracy of 79%. This gap means we have an overfitting problem: the model is very good at predicting the records of the training set but is not as good with the validation set.
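Accuracy is simply the fraction of predictions that match the true labels. As a small illustration, the sketch below compares a manual calculation with accuracy_score on toy arrays; the arrays are made up for this example and are not taken from the diabetes dataset.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Toy labels and predictions, invented for illustration.
y_true_demo = np.array([1, 0, 1, 1, 0])
y_pred_demo = np.array([1, 0, 0, 1, 0])

# Fraction of positions where the prediction equals the true label.
manual_accuracy = np.mean(y_true_demo == y_pred_demo)

print(manual_accuracy, accuracy_score(y_true_demo, y_pred_demo))
```

Four of the five toy predictions are correct, so both calculations should report the same fraction.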

Printing the decision tree

We can print the decision tree to obtain a visual representation.

We need to install a library called graphviz. On Linux you can use the following command:

apt-get install graphviz

You need to install the Python package as well:

!pip install graphviz

Now we can use it:

import graphviz 
from sklearn.tree import export_graphviz

feature_names = features.columns

dot_data = export_graphviz(tree, out_file=None, 
                         feature_names=feature_names,  
                         class_names=True,  
                         filled=True, rounded=True,  
                         special_characters=True)  
graph = graphviz.Source(dot_data)

graph

feature_names contains the name of each column. The function uses this variable to print the column names in each node.

You can see the tree in this link.
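If installing graphviz is not an option, scikit-learn also provides a plain-text exporter, export_text. The sketch below is an alternative to the approach above, not a replacement for it. It uses synthetic data and hypothetical feature names (f0 to f3) for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic data and made-up feature names, used only for illustration.
X_demo, y_demo = make_classification(n_samples=200, n_features=4, random_state=0)
demo_names = ["f0", "f1", "f2", "f3"]

# A shallow tree keeps the printed output short.
small_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)

# export_text returns the tree as an indented string, one line per node.
report = export_text(small_tree, feature_names=demo_names)
print(report)
```

Each indented line shows either the condition of a decision node or the class predicted by a leaf node.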

The second model

We need a better model to avoid overfitting:

tree = DecisionTreeClassifier(min_samples_leaf=10, max_depth=8, min_samples_split=50)

This time we will use different parameters to put constraints on the model. We want a tree with a maximum depth of 8, each leaf node must have at least 10 examples, and a decision node must have at least 50 examples to perform a split.

We train this new model:

tree.fit(X_train, y_train)
validation_prediction = tree.predict(X_val)
training_prediction = tree.predict(X_train)

We measure the model’s performance:

print('Accuracy training set: ', accuracy_score(y_true=y_train, y_pred=training_prediction))
print('Accuracy validation set: ', accuracy_score(y_true=y_val, y_pred=validation_prediction))

Accuracy training set:  0.7964169381107492
Accuracy validation set:  0.8116883116883117

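Now the accuracy is similar on both sets (about 80% on the training set and 81% on the validation set), so the model is no longer overfitting. We usually cannot reach an accuracy of 100% on both sets. In this particular problem we could try different parameters and explore the dataset to improve it.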
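One way to try different parameters is to train one model per candidate value and compare the training and validation accuracy. The sketch below does this for max_depth. It runs on synthetic data from make_classification, and the candidate depths are arbitrary assumptions, so the numbers it prints say nothing about the diabetes dataset.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, slightly noisy data used only for illustration.
X_demo, y_demo = make_classification(
    n_samples=500, n_features=8, flip_y=0.1, random_state=0
)
X_tr, X_va, y_tr, y_va = train_test_split(
    X_demo, y_demo, test_size=0.20, random_state=0
)

results = []
for depth in (2, 4, 8, None):  # None lets the tree grow without a depth limit
    model = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_tr, y_tr)
    train_acc = accuracy_score(y_tr, model.predict(X_tr))
    val_acc = accuracy_score(y_va, model.predict(X_va))
    results.append((depth, train_acc, val_acc))
    print(f"max_depth={depth}: training={train_acc:.3f}, validation={val_acc:.3f}")
```

A large gap between the two accuracies for a given depth is the same overfitting signal we saw with the first model.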

Important things about decision trees

  • When the classes are not well separated, decision trees are susceptible to overfitting.
  • Decision trees work poorly with high dimensional datasets (datasets with a lot of features).
  • You can use decision trees when you have a small number of features.
  • Decision trees are relatively insensitive to outliers.
  • Changes in the input data can affect the information gain calculation and cause changes in the tree.
  • Decision trees are often a poor fit for regression problems: continuous variables can take infinitely many values, while a decision tree has only a limited number of leaf nodes to output the result.
  • The top nodes are usually the most important features in the dataset. Knowing this, you can select the most relevant features of the dataset.
  • Some nodes can be duplicated, leading to complex trees.
  • Every feature in the tree must interact with every feature of the previous nodes. This works poorly when we have features that have weak or no interactions.
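To inspect which features a fitted tree relies on, scikit-learn exposes the feature_importances_ attribute. Note that it is computed from the impurity reduction of every node that uses a feature, not only from the top nodes, so it is a related but not identical notion. The sketch below uses synthetic data and hypothetical feature names.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data: 5 features, of which only 2 are informative by construction.
X_demo, y_demo = make_classification(
    n_samples=300, n_features=5, n_informative=2, n_redundant=0, random_state=0
)
demo_names = [f"feature_{i}" for i in range(5)]  # made-up names

model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)

# Pair each name with its importance and sort from most to least important.
ranked = sorted(
    zip(demo_names, model.feature_importances_),
    key=lambda pair: pair[1],
    reverse=True,
)
for name, importance in ranked:
    print(f"{name}: {importance:.3f}")
```

Features that the tree never uses in a split get an importance of 0, which makes them candidates for removal.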
