Decision Tree Classifier

Anh-Thi Dinh

What's the idea of Decision Tree Classifier?

The basic intuition behind a decision tree is to map out all possible decision paths in the form of a tree. It can be used for both classification and regression (Decision Tree Regression). In this post, let's try to understand the classifier.
Suppose that we have a dataset like the one in the figure below (ref).
An example dataset.
A decision tree we want.
Many algorithms can build a tree like the one above. In Machine Learning, we usually use:
  • ID3 (Iterative Dichotomiser 3): uses information gain / entropy.
  • CART (Classification And Regression Tree): uses Gini impurity.
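To make the two criteria concrete, here is a minimal sketch of entropy, Gini impurity and information gain for a list of class labels. The function names and the toy labels are my own, chosen only for illustration; they are not taken from any library.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of its children."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted


# Hypothetical toy labels: two classes, perfectly mixed.
labels = ["yes", "yes", "no", "no"]
print(entropy(labels))  # 1.0
print(gini(labels))     # 0.5
# A split that separates the classes perfectly recovers all of the entropy.
print(information_gain(labels, [["yes", "yes"], ["no", "no"]]))  # 1.0
```

Both measures are 0 for a pure node and largest when the classes are evenly mixed, which is why a split is judged by how much it lowers them.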

Some basic concepts

  • Splitting: the process of dividing a node into two or more sub-nodes.
  • Pruning: the process of removing the sub-nodes of a decision node.
  • Parent node and child node: a node that is divided into sub-nodes is the parent of those sub-nodes, and the sub-nodes are its children.
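Splitting can be sketched for a single numeric feature as follows. This is a simplified illustration, assuming binary splits of the form `value <= threshold` scored by weighted Gini impurity; `best_split` and the toy data are hypothetical, and real implementations are considerably more optimized.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def best_split(values, labels):
    """Return (threshold, weighted_gini) of the best split `value <= threshold`."""
    best = (None, float("inf"))
    n = len(values)
    uniq = sorted(set(values))
    # Candidate thresholds: midpoints between consecutive distinct values.
    for lo, hi in zip(uniq, uniq[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        # Impurity of the two child nodes, weighted by their sizes.
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if score < best[1]:
            best = (t, score)
    return best


# Hypothetical toy data: one feature (age) and a binary label.
ages = [22, 25, 30, 35, 40, 52]
buys = ["no", "no", "no", "yes", "yes", "yes"]
threshold, impurity = best_split(ages, buys)
print(threshold, impurity)  # 32.5 0.0
```

The parent node here holds all six samples; the split at 32.5 produces two child nodes that are both pure, so no further splitting is needed.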

ID3 algorithm

CART algorithm

Gini Impurity or Entropy?

Some points: (ref)
  • Most of the time, they lead to similar trees. (ref)
  • Gini impurity is slightly faster to compute, because it does not need a logarithm. (ref)
  • Gini impurity tends to isolate the most frequent class in its own branch of the tree, while entropy tends to produce slightly more balanced trees.
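One way to check the first point yourself is to train the same classifier with both criteria and compare the resulting trees. The sketch below assumes scikit-learn is installed and uses its bundled Iris dataset purely as an example; the dataset and the `random_state` values are my choices, not part of the original discussion.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

results = {}
for criterion in ("gini", "entropy"):
    clf = DecisionTreeClassifier(criterion=criterion, random_state=0)
    clf.fit(X_train, y_train)
    # Record tree depth, number of leaves and accuracy on the held-out data.
    results[criterion] = (clf.get_depth(), clf.get_n_leaves(), clf.score(X_test, y_test))
    print(criterion, results[criterion])
```

On a small, clean dataset like this one the two trees usually differ only slightly, but the exact numbers depend on the data and on the scikit-learn version.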

Pros and cons of Decision Tree

Some notable advantages of the Decision Tree Classifier: (ref)
  1. Can be used for regression or classification.
  2. Can be displayed graphically.
  3. Highly interpretable.
  4. Can be specified as a series of rules, and more closely approximates human decision-making than other models.
  5. Prediction is fast.
  6. Features don't need scaling.
  7. Automatically learns feature interactions.
  8. Tends to ignore irrelevant features.
  9. Non-parametric (can outperform linear models when the relationship between features and response is highly non-linear).
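The claim that a tree tends to ignore irrelevant features can be illustrated with a small experiment. This is only a sketch on synthetic data that I made up for the purpose: one feature fully determines the label and the other is pure noise. It assumes NumPy and scikit-learn are installed.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
informative = rng.uniform(size=200)  # this feature determines the label
noise = rng.uniform(size=200)        # this feature is unrelated to the label
X = np.column_stack([informative, noise])
y = (informative > 0.5).astype(int)

clf = DecisionTreeClassifier(random_state=0).fit(X, y)
# Importance of each feature, in column order: [informative, noise].
print(clf.feature_importances_)
```

The importance of the noise column should be far smaller than that of the informative column, because no split on it reduces the impurity as much.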
Its disadvantages:
  1. Performance is (generally) not competitive with the best supervised learning methods.
  2. Can easily overfit the training data (tuning is required).
  3. Small variations in the data can result in a completely different tree (high variance).
  4. Recursive binary splitting makes "locally optimal" decisions that may not result in a globally optimal tree.
  5. Doesn't work well with unbalanced or small datasets.

When to stop?

If the number of features is too large, the tree can grow very large, and it then easily overfits (check Underfitting & Overfitting). How can we avoid this?
  1. Pruning: remove the branches that rely on features of low importance.
  2. Set a minimum number of training samples for each leaf; a split that would leave a leaf with fewer samples is not made. In scikit-learn, use min_samples_leaf (min_samples_split, by contrast, sets the minimum number of samples a node needs before it may be split).
  3. Set the maximum depth of the tree. In scikit-learn, use max_depth.
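The stopping rules above can be tried out in scikit-learn as in the sketch below. It assumes the Iris dataset as an example, and the specific values (`max_depth=3`, `min_samples_leaf=5`, `ccp_alpha=0.02`) are arbitrary illustrations, not recommendations. `ccp_alpha` enables scikit-learn's cost-complexity pruning, which is one particular pruning method and may differ from the pruning described above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# No limits: the tree grows until every leaf is pure.
full = DecisionTreeClassifier(random_state=0).fit(X, y)
# Stop early: cap the depth and require at least 5 samples in each leaf.
limited = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, random_state=0).fit(X, y)
# Grow fully, then prune back with cost-complexity pruning.
pruned = DecisionTreeClassifier(ccp_alpha=0.02, random_state=0).fit(X, y)

for name, clf in [("full", full), ("limited", limited), ("pruned", pruned)]:
    print(name, clf.get_depth(), clf.get_n_leaves())
```

In practice these values are usually chosen by cross-validation rather than fixed by hand.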

When do we need to use a Decision Tree?

  • When explainability of the relationship between variables is prioritised over accuracy. Otherwise, we tend to use Random Forest.
  • When the data is more non-parametric in nature.
  • When we want a simple model.
  • When the entire dataset and all features can be used.
  • When we have limited computational power.
  • When we are not worried about accuracy on future datasets.

Using Decision Tree Classifier with Scikit-learn

Load and create

Load the library:
from sklearn.tree import DecisionTreeClassifier
Create a decision tree (other parameters):
# Gini impurity (the default)
clf = DecisionTreeClassifier()  # criterion='gini'
# Entropy / information gain (the criterion used by ID3)
clf = DecisionTreeClassifier(criterion='entropy')
An example:
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
# predict the class of a new sample
clf.predict([[2., 2.]])
# probability of each class
clf.predict_proba([[2., 2.]])
# output
array([1])
array([[0., 1.]])

Plot and Save plots

Plot the tree (you may need to install Graphviz first; don't forget to add its installation folder to your PATH):
from IPython.display import Image
import pydotplus
# `tree` and `clf` come from the example above
dot_data = tree.export_graphviz(clf, out_file=None,
                                rounded=True,
                                filled=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())  # display the tree inline in a notebook
Save the tree (continuing from the code in "Plot the tree"):
graph.write_pdf("tree.pdf")  # to PDF
graph.write_png("thi.png")   # to PNG
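If Graphviz is not available, the tree can also be printed as plain-text rules. The sketch below uses `export_text` from `sklearn.tree` (available in recent scikit-learn versions); the Iris dataset and `max_depth=2` are my own choices to keep the output short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# Each indented line is a test on one feature; leaf lines show the predicted class.
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

This text form is handy for logs or for pasting the learned rules into a report.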

