Decision Tree Regression

Last modified 3 years ago
This post was last modified around 3 years ago, so some information may be outdated!

What's the idea of Decision Tree Regression?

The basic intuition behind a decision tree is to map out all possible decision paths in the form of a tree. It can be used for both classification (note) and regression. In this post, we focus on regression.

DT Regression is similar to DT Classification; however, to determine splits we use Mean Squared Error (MSE, the default) or Mean Absolute Error (MAE) instead of cross-entropy or Gini impurity.

$$
\begin{aligned}
\text{MSE} &= \frac{1}{n} \sum_{i=1}^{n} (y_i - \bar{y})^2, \\
\text{MAE} &= \frac{1}{n} \sum_{i=1}^{n} \vert y_i - \bar{y} \vert.
\end{aligned}
$$
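To make the two criteria concrete, here is a small pure-Python sketch that evaluates both formulas for the target values of a single node. The function names `mse` and `mae` and the numbers are made up for illustration. The sketch follows the formulas above, which measure deviations from the mean; note that scikit-learn's absolute-error criterion measures deviations from the node's median instead.

```python
def mse(values):
    """Mean squared deviation of the node's target values from their mean."""
    n = len(values)
    mean = sum(values) / n
    return sum((y - mean) ** 2 for y in values) / n


def mae(values):
    """Mean absolute deviation of the node's target values from their mean."""
    n = len(values)
    mean = sum(values) / n
    return sum(abs(y - mean) for y in values) / n


# Hypothetical target values of one node (illustration only).
node_targets = [25.0, 30.0, 46.0, 45.0]
print(mse(node_targets))  # 84.25
print(mae(node_targets))  # 9.0
```

A node whose targets are all equal scores 0 under both criteria, which is why splits that produce more homogeneous children are preferred.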

Suppose that we have a dataset $S$ like the one in the figure below,

[Figure: An example of dataset $S$.]

[Figure: The decision tree we want.]

Some basic concepts

[Figure: Concepts with a tree.]

  • Splitting: the process of dividing a node into two or more sub-nodes.
  • Pruning: the process of removing the sub-nodes of a decision node.
  • Parent node and child node: a node that is divided into sub-nodes is called the parent node of those sub-nodes, whereas the sub-nodes are the children of the parent node.
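The parent/child vocabulary maps directly onto a fitted scikit-learn tree. The sketch below fits a tree on a made-up four-sample dataset and walks its low-level `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`) to print which nodes are parents and which are leaves. In these arrays a leaf has `-1` for both of its children.

```python
from sklearn.tree import DecisionTreeRegressor

# Tiny hypothetical dataset: one feature, four samples.
X = [[1.0], [2.0], [3.0], [4.0]]
y = [1.0, 1.0, 5.0, 5.0]

reg = DecisionTreeRegressor(random_state=0).fit(X, y)
t = reg.tree_  # array-based representation of the fitted tree

for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == right:  # both are -1 only for a leaf (no children)
        print(f"node {node}: leaf, prediction = {t.value[node][0][0]}")
    else:
        print(f"node {node}: parent of nodes {left} and {right}, "
              f"split on feature {t.feature[node]} <= {t.threshold[node]}")
```

Here the root is the only parent node; its two children are pure, so they become leaves and the splitting stops.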

For other aspects of the decision tree algorithm, check this note.


Looking for a worked example like the one in the decision tree classifier post? Check this! Below is a short algorithm,

  1. Calculate the Standard Deviation ($SD$) of the current node (say $S$, the parent node) by using MSE or MAE (strictly speaking, the MSE form is the variance, i.e. the squared standard deviation),

    $$
    \begin{aligned}
    SD(S) &= \frac{1}{n} \sum_{i=1}^{n} (y_i - \bar{y})^2, \\
    \text{or } SD(S) &= \frac{1}{n} \sum_{i=1}^{n} \vert y_i - \bar{y} \vert,
    \end{aligned}
    $$

    where $y_i$ ranges over the target values (Hours Played in the above example), $\bar{y} = \frac{1}{n}\sum_{i=1}^{n} y_i$ is their mean, and $n$ is the number of examples in this node.

  2. Check the stopping conditions. If one of them is met, we don't split this node and it becomes a leaf node. Otherwise, go to step 3. Typical stopping conditions:

    • The node has fewer samples than the minimum number required to split an internal node (min_samples_split in scikit-learn).
    • The node is at the maximum depth of the tree (max_depth in scikit-learn).
    • No split decreases the impurity enough: a node is split only if the split induces an impurity decrease greater than or equal to a given value (min_impurity_decrease in scikit-learn).
    • The node's coefficient of variation ($\frac{SD(S)}{\bar{y}}$) is less than a certain threshold.
  3. Calculate the Standard Deviation Reduction (SDR) after splitting node $S$ on each attribute (for example, attribute $O$). The attribute with the largest SDR is chosen!

    $$
    \underbrace{SDR(S,O)}_{\text{Standard Deviation Reduction}} = \underbrace{SD(S)}_{\text{SD before split}} - \underbrace{\sum_j P(O_j \mid S) \times SD(S,O_j)}_{\text{weighted SD after split}}
    $$

    where $j$ runs over the distinct values (properties) of $O$, and $P(O_j \mid S)$ is the probability of value $O_j$ in $S$, i.e. the fraction of examples in $S$ that have value $O_j$. Note that $SD(S,O_j)$ is the SD of node $O_j$, which is a child of node $S$.

  4. After splitting, we have new child nodes. Each of them becomes a new parent node in the next step. Go back to step 1.
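Steps 1 to 4 can be sketched in plain Python as below. This is a minimal ID3-style sketch under my own assumptions: categorical attributes with one branch per value (as in the Hours Played example), $SD$ computed with the MSE form, only two of the stopping conditions, and made-up data. The helper names `sd`, `sdr`, `build` and `predict` are hypothetical. It is not how scikit-learn works internally: scikit-learn grows binary trees with threshold splits on numeric features.

```python
from collections import defaultdict


def sd(values):
    """Step 1: SD(S) as defined above, using the MSE form."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def sdr(rows, targets, attr):
    """Step 3: SDR(S, O) = SD(S) - sum_j P(O_j | S) * SD(S, O_j)."""
    groups = defaultdict(list)  # value O_j -> targets of that child node
    for row, y in zip(rows, targets):
        groups[row[attr]].append(y)
    n = len(targets)
    weighted = sum(len(g) / n * sd(g) for g in groups.values())
    return sd(targets) - weighted


def build(rows, targets, attrs, depth=0, max_depth=3, min_samples_split=2):
    """Return a leaf (mean target) or {"attribute": ..., "children": {...}}."""
    # Step 2: stopping conditions (a coefficient-of-variation test could be added here).
    if (depth >= max_depth or len(targets) < min_samples_split
            or not attrs or sd(targets) == 0):
        return sum(targets) / len(targets)
    best = max(attrs, key=lambda a: sdr(rows, targets, a))  # Step 3: largest SDR wins
    children = {}
    for value in {row[best] for row in rows}:
        idx = [i for i, row in enumerate(rows) if row[best] == value]
        # Step 4: each child becomes a new parent node; recurse back to step 1.
        children[value] = build([rows[i] for i in idx], [targets[i] for i in idx],
                                [a for a in attrs if a != best],
                                depth + 1, max_depth, min_samples_split)
    return {"attribute": best, "children": children}


def predict(node, row):
    """Follow the branches for `row` until a leaf value is reached."""
    while isinstance(node, dict):
        node = node["children"][row[node["attribute"]]]
    return node


# Made-up data in the spirit of the Hours Played example (not the post's dataset).
rows = [
    {"outlook": "sunny", "windy": False}, {"outlook": "sunny", "windy": True},
    {"outlook": "rainy", "windy": False}, {"outlook": "rainy", "windy": True},
    {"outlook": "overcast", "windy": False}, {"outlook": "overcast", "windy": True},
]
targets = [10.0, 12.0, 30.0, 32.0, 50.0, 52.0]

tree = build(rows, targets, ["outlook", "windy"])
print(tree["attribute"])  # outlook
print(predict(tree, {"outlook": "rainy", "windy": True}))  # 32.0
```

With this data, splitting on `outlook` leaves almost no spread inside each child, so its SDR is far larger than that of `windy`. The sketch raises a `KeyError` for attribute values never seen during training, which a fuller implementation would need to handle.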

Using Decision Tree Regression with Scikit-learn

Load and create

Load the library,

from sklearn.tree import DecisionTreeRegressor

Create a decision tree (other parameters):

# mean squared error (default)
reg = DecisionTreeRegressor()  # criterion='squared_error' ('mse' before scikit-learn 1.0)
# mean absolute error
reg = DecisionTreeRegressor(criterion='absolute_error')  # 'mae' before scikit-learn 1.0

An example,

from sklearn import tree
X = [[0, 0], [2, 2]]
y = [0.5, 2.5]
reg = tree.DecisionTreeRegressor()
reg = reg.fit(X, y)  # train
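Once trained, the regressor predicts the mean target value of the leaf a sample falls into. Here is a short usage sketch for the example above, refitted so that it runs on its own:

```python
from sklearn import tree

X = [[0, 0], [2, 2]]
y = [0.5, 2.5]
reg = tree.DecisionTreeRegressor().fit(X, y)

# Each prediction is the mean target of the leaf the sample lands in.
print(reg.predict([[1, 1]]))  # [0.5]
print(reg.predict([[3, 3]]))  # [2.5]
```

With only two training points, the tree has a single split halfway between them, so every input is mapped to one of the two training targets.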

Plot and save plots

Plot the tree (you may need to install Graphviz first; don't forget to add its installation folder to your PATH),

from IPython.display import Image
import pydotplus
dot_data = tree.export_graphviz(reg, out_file=None)  # DOT source of the fitted tree
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())  # display the tree inline in Jupyter

[Figure: An example.]
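If installing Graphviz is a hassle, scikit-learn can also render the tree as plain text with `tree.export_text`. A minimal sketch on the same toy data; the feature names `x0` and `x1` are made up for illustration.

```python
from sklearn import tree

X = [[0, 0], [2, 2]]
y = [0.5, 2.5]
reg = tree.DecisionTreeRegressor(random_state=0).fit(X, y)

# Text rendering of the fitted tree; needs neither Graphviz nor pydotplus.
report = tree.export_text(reg, feature_names=["x0", "x1"])
print(report)
```

Each line shows one split condition or a leaf value, indented by depth, which is often enough to inspect a small tree.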

Save the tree (continuing from the code in "Plot the tree"),

graph.write_pdf("tree.pdf")   # to pdf
graph.write_png("thi.png") # to png

