Decision Tree Learning Algorithm – Part 1/2
Before we dive deeper into parameter tuning, the topic of Part 8 of this chapter (Decision Trees and Ensemble Learning), let’s take a step back for a moment and explore how a decision tree learns the kind of rules we saw in the last article.
In this article, we simplify the problem by using a much smaller dataset with just one feature. Additionally, we employ a very small tree, known as a “decision stump.” This approach offers a clearer understanding of how a decision tree learns from data.
So, let’s begin with a decision stump, which is essentially a decision tree with only one level. At the top of the tree sits a single condition node. Conditions have the form “feature > T”, where T represents a threshold, and such a condition is referred to as a “split”: it divides the data into the records that satisfy the condition (TRUE) and those that do not (FALSE). At the bottom of the tree are the leaves, which are called decision nodes. A decision node is where the tree doesn’t go any deeper, and this is where decisions are made. In this context, “making a decision” means choosing the most frequently occurring status label within each group and using it as the final prediction for that group.
        FEATURE > T
        /         \
    FALSE          TRUE
   DEFAULT          OK
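For example, the majority-vote decision at a leaf can be computed in one line. The snippet below is an illustrative sketch of mine (the labels are made up), not code from the article:

import pandas as pd

# Illustrative sketch (my own, not from the article): the decision at a leaf
# is simply the most frequent label among the records that reach it.
leaf_labels = pd.Series(['default', 'default', 'ok', 'default'])
print(leaf_labels.value_counts().idxmax())   # prints: default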
To illustrate how the learning algorithm works, let’s use a simple dataset. It contains 8 records with a single feature, ‘assets’, and our target variable, ‘status’.
import pandas as pd

data = [
    [8000, 'default'],
    [2000, 'default'],
    [0, 'default'],
    [5000, 'ok'],
    [5000, 'ok'],
    [4000, 'ok'],
    [9000, 'ok'],
    [3000, 'default'],
]

df_example = pd.DataFrame(data, columns=['assets', 'status'])
df_example
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 1 | 2000 | default |
| 2 | 0 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
| 7 | 3000 | default |
Finding the best split for one column
What we want to do now is use this numerical column ‘assets’ for training: we want to train a decision tree based on the condition ASSETS > T. The question is: what is the best T, that is, what is the best split? We want to split the dataset into two parts, one where the condition holds true and one where it does not.
        ASSETS > T
        /        \
    FALSE         TRUE
  DF_LEFT        DF_RIGHT
DF_RIGHT contains the records where the condition is true, i.e., all records where ‘assets’ is greater than T. DF_LEFT contains the records where the condition is false, i.e., all records where ‘assets’ is not greater than T. This is called splitting: we divide the dataset into the part where the condition is true and the part where it is false.
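To make the splitting step concrete, here is a small sketch of mine (the helper name is my own, not from the article):

# Split a DataFrame on the condition "assets > T" into the part where the
# condition is false (LEFT) and the part where it is true (RIGHT).
def split_by_threshold(df, T):
    df_left = df[df.assets <= T]     # condition is false
    df_right = df[df.assets > T]     # condition is true
    return df_left, df_right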
Let’s now sort the data by the ‘assets’ column.
df_example.sort_values('assets')
| | assets | status |
|---|---|---|
| 2 | 0 | default |
| 1 | 2000 | default |
| 7 | 3000 | default |
| 5 | 4000 | ok |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 0 | 8000 | default |
| 6 | 9000 | ok |
We are considering different threshold values “T” for the condition “ASSETS > T.” The possibilities for the threshold “T” are:
- T = 0 (ASSETS<=0 becomes LEFT, ASSETS>0 becomes RIGHT)
- T = 2000 (ASSETS<=2,000 becomes LEFT, ASSETS>2,000 becomes RIGHT)
- T = 3000 (ASSETS<=3,000 becomes LEFT, ASSETS>3,000 becomes RIGHT)
- T = 4000 (ASSETS<=4,000 becomes LEFT, ASSETS>4,000 becomes RIGHT)
- T = 5000 (ASSETS<=5,000 becomes LEFT, ASSETS>5,000 becomes RIGHT)
- T = 8000 (ASSETS<=8,000 becomes LEFT, ASSETS>8,000 becomes RIGHT)
It’s important to note that a value like T = 9000 doesn’t make sense for splitting: no record has assets greater than 9,000, so the RIGHT side would be empty.
Ts = [0, 2000, 3000, 4000, 5000, 8000]
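The same list can also be derived from the data itself; the following is a small sketch of mine rather than code from the article:

# Candidate thresholds are all unique 'assets' values except the largest one,
# since "assets > max" would leave the RIGHT side empty.
Ts = sorted(df_example.assets.unique().tolist())[:-1]
print(Ts)   # [0, 2000, 3000, 4000, 5000, 8000]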
We want to evaluate each of these threshold values “T” by splitting our dataset into the left (LEFT) and right (RIGHT) subsets and determine which split provides the best results.
from IPython.display import display

for T in Ts:
    print(T)
    df_left = df_example[df_example.assets <= T]
    df_right = df_example[df_example.assets > T]
    display(df_left)
    display(df_right)
    print()
0
| | assets | status |
|---|---|---|
| 2 | 0 | default |
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 1 | 2000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
| 7 | 3000 | default |
2000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
| 7 | 3000 | default |
3000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 7 | 3000 | default |
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
4000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 5 | 4000 | ok |
| 7 | 3000 | default |
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 6 | 9000 | ok |
5000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 7 | 3000 | default |
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 6 | 9000 | ok |
8000
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 1 | 2000 | default |
| 2 | 0 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 7 | 3000 | default |
| | assets | status |
|---|---|---|
| 6 | 9000 | ok |
Now that we have generated many splits using different threshold values, we need to determine which split is the best. To evaluate this, we can use various split evaluation criteria. Let’s take the example where T=4000. In this case, on the LEFT side, we have 3 cases labeled as DEFAULT and 1 case labeled as OK. On the RIGHT side, we have 3 cases labeled as OK and 1 case labeled as DEFAULT. As mentioned earlier, we select the most frequently occurring status within each group and use it as the final decision. In the case of the LEFT group, the most frequent status is “default,” and for the RIGHT group, it’s “ok.”
T = 4000
df_left = df_example[df_example.assets <= T]
df_right = df_example[df_example.assets > T]
display(df_left)
display(df_right)
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 5 | 4000 | ok |
| 7 | 3000 | default |
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 6 | 9000 | ok |
       ASSETS > 4000
       /          \
    FALSE          TRUE
   DEFAULT          OK
Misclassification Rate
To evaluate the accuracy of our predictions, we can calculate the misclassification rate. The misclassification rate measures the fraction of errors we make when we predict the majority class for every record in a group: here, DEFAULT for the LEFT group and OK for the RIGHT group. It gives us insight into how well the split separates the two classes.
For the LEFT group (3xDEFAULT, 1xOK), where we predict everything as DEFAULT, the misclassification rate is 1/4, which equals 25%.
For the RIGHT group (3xOK, 1xDEFAULT), where we predict everything as OK, the misclassification rate is also 1/4, or 25%.
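As a quick sanity check (my own snippet, not from the article), the LEFT rate for T = 4000 can be computed directly: the group is predicted as ‘default’, so the errors are exactly the ‘ok’ rows.

df_left = df_example[df_example.assets <= 4000]
print((df_left.status != 'default').mean())   # 0.25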
This is how we evaluate the quality of our split. For T=4000, the average misclassification rate is 25%. We can also use a weighted average, which is particularly useful when there are significantly more examples on one side than on the other; here both groups contain 4 of the 8 records, so the weighted average, (4/8)·25% + (4/8)·25%, is also 25%. You can print the normalized value counts as shown in the next snippet.
T = 4000
df_left = df_example[df_example.assets <= T]
df_right = df_example[df_example.assets > T]
display(df_left)
print(df_left.status.value_counts(normalize=True))
display(df_right)
print(df_right.status.value_counts(normalize=True))
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 5 | 4000 | ok |
| 7 | 3000 | default |
default 0.75
ok 0.25
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 6 | 9000 | ok |
ok 0.75
default 0.25
Name: status, dtype: float64
Let’s print the class distribution for each split; the share of the minority class in each group is its misclassification rate. You can use the code in the snippet below to display this information.
for T in Ts:
    print(T)
    df_left = df_example[df_example.assets <= T]
    df_right = df_example[df_example.assets > T]
    display(df_left)
    print(df_left.status.value_counts(normalize=True))
    display(df_right)
    print(df_right.status.value_counts(normalize=True))
    print()
0
| | assets | status |
|---|---|---|
| 2 | 0 | default |
default 1.0
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 1 | 2000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
| 7 | 3000 | default |
ok 0.571429
default 0.428571
Name: status, dtype: float64
2000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
default 1.0
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
| 7 | 3000 | default |
ok 0.666667
default 0.333333
Name: status, dtype: float64
3000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 7 | 3000 | default |
default 1.0
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 6 | 9000 | ok |
ok 0.8
default 0.2
Name: status, dtype: float64
4000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 5 | 4000 | ok |
| 7 | 3000 | default |
default 0.75
ok 0.25
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 6 | 9000 | ok |
ok 0.75
default 0.25
Name: status, dtype: float64
5000
| | assets | status |
|---|---|---|
| 1 | 2000 | default |
| 2 | 0 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 7 | 3000 | default |
default 0.5
ok 0.5
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 6 | 9000 | ok |
default 0.5
ok 0.5
Name: status, dtype: float64
8000
| | assets | status |
|---|---|---|
| 0 | 8000 | default |
| 1 | 2000 | default |
| 2 | 0 | default |
| 3 | 5000 | ok |
| 4 | 5000 | ok |
| 5 | 4000 | ok |
| 7 | 3000 | default |
default 0.571429
ok 0.428571
Name: status, dtype: float64
| | assets | status |
|---|---|---|
| 6 | 9000 | ok |
ok 1.0
Name: status, dtype: float64
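To aggregate these distributions into a single number per split, here is a small sketch of mine (the helper name is my own, not from the article) that computes the misclassification rate of each side and their plain average; its output matches the table in the next section.

def misclassification_rate(group):
    # fraction of rows that do not belong to the group's majority class
    return 1 - group.status.value_counts(normalize=True).max()

for T in Ts:
    df_left = df_example[df_example.assets <= T]
    df_right = df_example[df_example.assets > T]
    left = misclassification_rate(df_left)
    right = misclassification_rate(df_right)
    print(f'T={T}: left={left:.2f}, right={right:.2f}, avg={(left + right) / 2:.2f}')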
Impurity
The misclassification rate is just one way of measuring impurity. We aim for our leaves to be as pure as possible, meaning that the groups resulting from the split should ideally contain only observations of one class, making them pure. The misclassification rate informs us of how impure these groups are by quantifying the error rate in classifying observations.
There are also alternative methods to measure impurity. Scikit-Learn offers more advanced criteria for splitting, such as entropy and Gini impurity. These criteria provide different ways to evaluate the quality of splits, with the goal of creating decision trees with more homogeneous leaves.
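For reference, here is a small sketch of Gini impurity (my own addition, not part of this walkthrough): one minus the sum of the squared class proportions within a group.

def gini_impurity(group):
    # 1 - sum of squared class proportions; 0 means the group is pure
    p = group.status.value_counts(normalize=True)
    return 1 - (p ** 2).sum()

# LEFT group of the T = 4000 split (3x default, 1x ok):
# 1 - (0.75**2 + 0.25**2) = 0.375
print(gini_impurity(df_example[df_example.assets <= 4000]))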
| T | Decision LEFT | Impurity LEFT | Decision RIGHT | Impurity RIGHT | AVG |
|---|---|---|---|---|---|
| 0 | DEFAULT | 0% | OK | 43% | 21% |
| 2000 | DEFAULT | 0% | OK | 33% | 17% |
| 3000 | DEFAULT | 0% | OK | 20% | 10% |
| 4000 | DEFAULT | 25% | OK | 25% | 25% |
| 5000 | DEFAULT | 50% | OK | 50% | 50% |
| 8000 | DEFAULT | 43% | OK | 0% | 21% |
We observe that the split with T = 3000 yields the lowest average misclassification rate, an impurity of only 10%. This is how we determine the best split when dealing with just one column. To summarize our approach (a compact sketch implementing these steps follows the list):
- We sort the dataset.
- We identify all possible thresholds.
- For each of the thresholds, we split the dataset.
- For each split, we calculate the impurity on the left and the right.
- We then compute the average impurity.
- Among all the splits considered, we select the one with the lowest average impurity.
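Putting these steps together, here is the compact sketch referred to above (the function name is my own, not from the article):

def best_split(df):
    # candidate thresholds: every unique 'assets' value except the largest
    thresholds = sorted(df.assets.unique().tolist())[:-1]
    best_T, best_impurity = None, float('inf')
    for T in thresholds:
        df_left = df[df.assets <= T]
        df_right = df[df.assets > T]
        impurity_left = 1 - df_left.status.value_counts(normalize=True).max()
        impurity_right = 1 - df_right.status.value_counts(normalize=True).max()
        avg_impurity = (impurity_left + impurity_right) / 2
        if avg_impurity < best_impurity:
            best_T, best_impurity = T, avg_impurity
    return best_T, best_impurity

print(best_split(df_example))   # T = 3000, with an average impurity of roughly 0.1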
The decision tree we’ve learned based on this split is as follows:
       ASSETS > 3000
       /          \
    FALSE          TRUE
   DEFAULT          OK
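As a quick cross-check (my own addition; it assumes scikit-learn is available and is not part of the original walkthrough), we can fit a depth-one tree with scikit-learn, which uses Gini impurity by default, and inspect the split it learns.

from sklearn.tree import DecisionTreeClassifier, export_text

dt = DecisionTreeClassifier(max_depth=1)
dt.fit(df_example[['assets']], df_example.status)
print(export_text(dt, feature_names=['assets']))
# Expect a split at roughly "assets <= 3500": scikit-learn places the threshold
# at the midpoint between 3000 and 4000, i.e. the same partition as ASSETS > 3000.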