The basic intuition behind a decision tree is to map out all possible decision paths in the form of a tree. It can be used for both classification and regression (Decision Tree Regression). In this post, let's try to understand the classifier.
Suppose that we have a dataset like the one in the figure below (ref).
There are many algorithms that can help us build a tree like the one above. In Machine Learning, we usually use:
- ID3 (Iterative Dichotomiser 3): uses information gain / entropy.
- CART (Classification And Regression Tree): uses Gini impurity.
- Splitting: the process of dividing a node into two or more sub-nodes.
- Pruning: the process of removing sub-nodes of a decision node.
- Parent node and child node: a node that is divided into sub-nodes is called the parent of those sub-nodes, whereas the sub-nodes are its children.
Some points: (ref)
- Most of the time, Gini impurity and entropy lead to similar trees. (ref)
- Gini impurity is slightly faster to compute. (ref)
- Gini impurity tends to isolate the most frequent class in its own branch of the tree, while entropy tends to produce slightly more balanced trees.
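To make the two criteria concrete, here is a minimal sketch (plain NumPy; the function names are my own) of how entropy and Gini impurity can be computed from the class counts at a node:
import numpy as np

def entropy(counts):
    # Entropy of a node, given the class counts at that node
    p = np.asarray(counts, dtype=float)
    p = p / p.sum()
    p = p[p > 0]  # convention: 0 * log(0) = 0
    return -np.sum(p * np.log2(p))

def gini(counts):
    # Gini impurity of a node, given the class counts at that node
    p = np.asarray(counts, dtype=float)
    p = p / p.sum()
    return 1.0 - np.sum(p ** 2)

print(entropy([5, 5]), gini([5, 5]))  # 1.0 0.5 (maximally impure node)
print(entropy([8, 2]), gini([8, 2]))  # ~0.72 ~0.32 (purer node, lower impurity)
Both measures are zero for a pure node and largest when the classes are evenly mixed; a split is chosen so that it reduces them the most.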
Some highlighted advantages of the Decision Tree Classifier: (ref)
- Can be used for regression or classification.
- Can be displayed graphically.
- Highly interpretable.
- Can be specified as a series of rules, and approximates human decision-making more closely than other models.
- Prediction is fast.
- Features don't need scaling.
- Automatically learns feature interactions.
- Tends to ignore irrelevant features.
- Non-parametric (will outperform linear models if the relationship between features and response is highly non-linear).
Its disadvantages:
- Performance is (generally) not competitive with the best supervised learning methods.
- Can easily overfit the training data (tuning is required).
- Small variations in the data can result in a completely different tree (high variance).
- Recursive binary splitting makes "locally optimal" decisions that may not result in a globally optimal tree.
- Doesn't work well with unbalanced or small datasets.
If the number of features is too large, we'll end up with a very large tree, and it easily leads to an overfitting problem (check Underfitting & Overfitting). How can we avoid that?
- Pruning: removing the branches that make use of features with low importance.
- Set a minimum number of training samples that a node must contain before it can be split; otherwise, the node stays a leaf. In scikit-learn, use min_samples_split (or min_samples_leaf to constrain the leaves directly).
- Set the maximum depth of the tree. In scikit-learn, use max_depth. (A sketch combining these parameters follows this list.)
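As an illustration, here is a minimal sketch of constraining tree growth in scikit-learn; the parameter values are arbitrary placeholders and should be tuned, e.g. by cross-validation:
from sklearn.tree import DecisionTreeClassifier

# Illustrative values only: limit depth and require a minimum
# number of samples per split and per leaf.
clf = DecisionTreeClassifier(
    max_depth=3,           # the tree is at most 3 levels deep
    min_samples_split=10,  # a node needs at least 10 samples to be split
    min_samples_leaf=5,    # every leaf keeps at least 5 samples
)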
We tend to choose a decision tree:
- When explainability is prioritised over accuracy; otherwise, we tend to use Random Forest.
- When the data is more non-parametric in nature.
- When we want a simple model.
- When the entire dataset and all features can be used.
- When we have limited computational power.
- When we are not worried about accuracy on future datasets.
Load the library,
from sklearn.tree import DecisionTreeClassifier
Create a decision tree (other parameters):
# The Gini impurity (default)
clf = DecisionTreeClassifier()  # criterion='gini'
# The information gain (ID3)
clf = DecisionTreeClassifier(criterion='entropy')
An example,
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
# predict
clf.predict([[2., 2.]])
# probability of each class
clf.predict_proba([[2., 2.]])
# output
array([1])
array([[0., 1.]])
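Note that the columns of predict_proba follow the order of clf.classes_, so it's worth checking that attribute when reading the probabilities:
clf.classes_  # array([0, 1]) -> [0., 1.] means P(class 0) = 0, P(class 1) = 1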
Plot the tree (you may need to install Graphviz first; don't forget to add its installation folder to $PATH),
from IPython.display import Image
import pydotplus
dot_data = tree.export_graphviz(clf, out_file=None,
                                rounded=True,
                                filled=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
Save the tree (following the code in "Plot the tree"),
graph.write_pdf("tree.pdf")  # to pdf
graph.write_png("thi.png")  # to png
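If you prefer to skip the Graphviz dependency, scikit-learn also ships a matplotlib-based helper (assuming a reasonably recent version, 0.21 or later):
import matplotlib.pyplot as plt
from sklearn import tree

tree.plot_tree(clf, filled=True, rounded=True)  # draw the fitted tree
plt.savefig("tree.png")  # or plt.show()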
- Scikit-learn. Decision Tree Classifier official doc.
- Saed Sayad. Decision Tree - Classification.
- Brian Ambielli. Information Entropy and Information Gain.
- Brian Ambielli. Gini Impurity (With Examples).
- Aurélien Géron. Hands-on Machine Learning with Scikit-Learn and TensorFlow, chapter 6.