In part 1 of “Decision Tree Learning Algorithm,” we worked through a simple example to understand how a decision tree learns rules. We used the misclassification rate to evaluate the accuracy of our predictions and noted that it is just one way to measure impurity. In this second part, we will explore how to find the best split when we have two feature columns. This will lead us to a more general algorithm for finding the best split when working with multiple features. Finally, we will conclude with the complete Decision Tree Learning Algorithm.
Finding the best split for two columns
In the last article, we began with a simple example that involved just one feature. Now, let’s explore what happens when we introduce another feature, ‘debt.’ This feature represents the amount of debt that clients have.
We’ve previously attempted splitting the data based on the ‘assets’ feature. Now, let’s investigate the results of attempting a split based on the ‘debt’ feature. To do this, we should first sort the data by the ‘debt’ column.
df_example.sort_values('debt')
   assets  debt   status
6    9000   500       ok
1    2000  1000  default
2       0  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
7    3000  2000  default
0    8000  3000  default

Dataframe sorted by debt column
Possible thresholds:
T = 500 (DEBT <= 500 becomes LEFT, DEBT > 500 becomes RIGHT)
T = 1000 (DEBT <= 1000 becomes LEFT, DEBT > 1000 becomes RIGHT)
T = 2000 (DEBT <= 2000 becomes LEFT, DEBT > 2000 becomes RIGHT)
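In code, one way to obtain these candidates is to take the distinct values of the column and leave out the largest one (splitting on the largest value would leave the RIGHT group empty). A minimal sketch, using the df_example dataframe from above:

# Candidate thresholds for 'debt': its distinct values, except the largest one
debt_thresholds = sorted(df_example['debt'].unique())[:-1]
print(debt_thresholds)  # [500, 1000, 2000]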
So we’ve seen how it works for a single feature. Let’s generalize a bit and put the thresholds for both features in a dictionary; if there were more features, we could add them here as well.
from IPython.display import display  # available by default in Jupyter notebooks

thresholds = {
    'assets': [0, 2000, 3000, 4000, 5000, 8000],
    'debt': [500, 1000, 2000]
}

for feature, Ts in thresholds.items():
    print('#####################')
    print(feature)
    for T in Ts:
        print(T)
        # Split the rows into LEFT (<= T) and RIGHT (> T) groups
        df_left = df_example[df_example[feature] <= T]
        df_right = df_example[df_example[feature] > T]
        display(df_left)
        print(df_left.status.value_counts(normalize=True))
        display(df_right)
        print(df_right.status.value_counts(normalize=True))
        print()
    print('#####################')
#####################
assets
0
   assets  debt   status
2       0  1000  default
LEFT for T=0
default    1.0
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
1    2000  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
6    9000   500       ok
7    3000  2000  default
RIGHT for T=0
ok         0.571429
default    0.428571
Name: status, dtype: float64
2000
   assets  debt   status
1    2000  1000  default
2       0  1000  default
LEFT for T=2000
default    1.0
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
6    9000   500       ok
7    3000  2000  default
RIGHT for T=2000
ok         0.666667
default    0.333333
Name: status, dtype: float64
3000
   assets  debt   status
1    2000  1000  default
2       0  1000  default
7    3000  2000  default
LEFT for T=3000
default    1.0
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
6    9000   500       ok
RIGHT for T=3000
ok         0.8
default    0.2
Name: status, dtype: float64
4000
   assets  debt   status
1    2000  1000  default
2       0  1000  default
5    4000  1000       ok
7    3000  2000  default
LEFT for T=4000
default    0.75
ok         0.25
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
3    5000  1000       ok
4    5000  1000       ok
6    9000   500       ok
RIGHT for T=4000
ok         0.75
default    0.25
Name: status, dtype: float64
5000
   assets  debt   status
1    2000  1000  default
2       0  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
7    3000  2000  default
LEFT for T=5000
default    0.5
ok         0.5
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
6    9000   500       ok
RIGHT for T=5000
default    0.5
ok         0.5
Name: status, dtype: float64
8000
   assets  debt   status
0    8000  3000  default
1    2000  1000  default
2       0  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
7    3000  2000  default
LEFT for T=8000
default    0.571429
ok         0.428571
Name: status, dtype: float64
   assets  debt   status
6    9000   500       ok
RIGHT for T=8000
ok    1.0
Name: status, dtype: float64

#####################
#####################
debt
500
   assets  debt   status
6    9000   500       ok
LEFT for T=500
ok    1.0
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
1    2000  1000  default
2       0  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
7    3000  2000  default
RIGHT for T=500
default    0.571429
ok         0.428571
Name: status, dtype: float64
1000
   assets  debt   status
1    2000  1000  default
2       0  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
6    9000   500       ok
LEFT for T=1000
ok         0.666667
default    0.333333
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
7    3000  2000  default
RIGHT for T=1000
default    1.0
Name: status, dtype: float64
2000
   assets  debt   status
1    2000  1000  default
2       0  1000  default
3    5000  1000       ok
4    5000  1000       ok
5    4000  1000       ok
6    9000   500       ok
7    3000  2000  default
LEFT for T=2000
ok         0.571429
default    0.428571
Name: status, dtype: float64
   assets  debt   status
0    8000  3000  default
RIGHT for T=2000
default    1.0
Name: status, dtype: float64
#####################
We have already calculated this table for the “assets” feature; let’s now do the same for “debt.”
T     Decision LEFT   Impurity LEFT   Decision RIGHT   Impurity RIGHT   AVG
0     DEFAULT         0%              OK               43%              21%
2000  DEFAULT         0%              OK               33%              16%
3000  DEFAULT         0%              OK               20%              10%
4000  DEFAULT         25%             OK               25%              25%
5000  DEFAULT         50%             OK               50%              50%
8000  DEFAULT         43%             OK               0%               21%

Average Impurity for Assets column
T     Decision LEFT   Impurity LEFT   Decision RIGHT   Impurity RIGHT   AVG
500   OK              0%              DEFAULT          43%              21%
1000  OK              33%             DEFAULT          0%               16%
2000  OK              43%             DEFAULT          0%               21%

Average Impurity for Debt column
We can observe that the “debt” feature is not as useful for making the split as the “assets” feature: its best threshold (T = 1000) yields an average impurity of 16%, while “assets” achieves 10% at T = 3000. The best split overall therefore remains the same as before: “ASSETS > 3000.” If we were working with three features, the table would simply gain another group of rows, one for each threshold of the third feature.
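The numbers in both tables can be reproduced with a short helper. Below is a minimal sketch that uses the misclassification rate as the impurity measure and, as in the tables above, the simple (unweighted) average of the two sides; the helper names are only illustrative:

def misclassification_rate(labels):
    # Share of rows that do not belong to the majority class of the group
    if len(labels) == 0:
        return 0.0
    return 1.0 - labels.value_counts(normalize=True).max()

def split_impurity(df, feature, T):
    # Average impurity of the LEFT (<= T) and RIGHT (> T) groups
    left = df[df[feature] <= T]['status']
    right = df[df[feature] > T]['status']
    return (misclassification_rate(left) + misclassification_rate(right)) / 2

print(split_impurity(df_example, 'assets', 3000))  # 0.1
print(split_impurity(df_example, 'debt', 1000))    # ~0.17 (shown as 16% in the table)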
Finding the best split for multiple features
FOR each feature in FEATURES:
    FIND all thresholds for the feature
    FOR each threshold in thresholds:
        SPLIT the dataset using the "feature > threshold" condition
        COMPUTE the impurity of this split
SELECT the condition with the LOWEST IMPURITY
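Translated into Python, a minimal sketch of this search could look as follows. It reuses the misclassification_rate helper sketched above; the function name and structure are illustrative and not how scikit-learn implements it internally:

def find_best_split(df, features, target='status'):
    # Returns (feature, threshold, impurity) of the split with the lowest average impurity
    best = None
    for feature in features:
        # Candidate thresholds: distinct values of the feature, except the largest one
        for T in sorted(df[feature].unique())[:-1]:
            left = df[df[feature] <= T][target]
            right = df[df[feature] > T][target]
            impurity = (misclassification_rate(left) + misclassification_rate(right)) / 2
            if best is None or impurity < best[2]:
                best = (feature, T, impurity)
    return best

print(find_best_split(df_example, ['assets', 'debt']))  # ('assets', 3000, 0.1)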
We also need to establish some stopping criteria: the splitting is applied recursively, so how do we determine when to stop?
Stopping Criteria
Stopping criteria tell the algorithm when to halt the recursive splitting process. They are essential for preventing overfitting and ensure that tree building terminates once certain conditions are met. The right choice depends on the dataset and the problem we are addressing, and it is how we balance model complexity against generalization. In the next article, on parameter tuning, we will look at common stopping criteria for decision trees in more detail.
The most common stopping criteria are:
The group is already pure: If a node is pure, meaning all samples in that node belong to the same class, further splitting is unnecessary as it won’t provide any additional information.
The tree has reached its depth limit (controlled by max_depth): Limiting the depth of the tree helps control overfitting and ensures that the tree doesn’t become too complex.
The group is too small for further splitting (controlled by min_samples_leaf): Restricting splitting when the number of samples in a node falls below a certain threshold helps prevent overfitting and results in smaller, more interpretable trees.
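In scikit-learn, the last two criteria map directly to constructor parameters of DecisionTreeClassifier (the first one, stopping at pure nodes, is built in). A minimal sketch with arbitrary example values:

from sklearn.tree import DecisionTreeClassifier

# max_depth: stop splitting once the tree reaches this depth
# min_samples_leaf: never create a leaf with fewer samples than this
model = DecisionTreeClassifier(max_depth=3, min_samples_leaf=2)
model.fit(df_example[['assets', 'debt']], df_example['status'])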
When we use these criteria, we enforce simplicity in our model, thereby preventing overfitting. Let’s summarize the decision tree learning algorithm.
Decision Tree Learning Algorithm
Find the Best Split: For each feature, evaluate all possible thresholds and select the split with the lowest impurity.
Stop if Max Depth is Reached: Stop the splitting process if the maximum allowable depth of the tree has been reached.
If LEFT is Sufficiently Large and Not Pure, Repeat for LEFT: If the left subset of the data is both sufficiently large and not pure (it contains more than one class), repeat the splitting process for the left subset.
If RIGHT is Sufficiently Large and Not Pure, Repeat for RIGHT: Similarly, if the right subset of the data is both sufficiently large and not pure, repeat the splitting process for the right subset.
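Putting these steps together, the whole procedure can be written as a recursive function. The sketch below reuses the hypothetical find_best_split helper from earlier and represents the tree as nested dictionaries; it is only meant to illustrate the recursion, not to be a production implementation:

def build_tree(df, features, target='status', depth=0, max_depth=3, min_samples_leaf=1):
    majority_class = df[target].mode()[0]

    # Stopping criteria: the group is pure, the depth limit is reached, or the group is too small
    if df[target].nunique() == 1 or depth >= max_depth or len(df) <= min_samples_leaf:
        return majority_class

    best = find_best_split(df, features, target)      # step 1: find the best split
    if best is None:                                  # no usable threshold left
        return majority_class
    feature, T, _ = best

    left = df[df[feature] <= T]
    right = df[df[feature] > T]
    if len(left) < min_samples_leaf or len(right) < min_samples_leaf:
        return majority_class                         # a side would be too small: make a leaf

    return {
        'condition': f'{feature} > {T}',
        'left': build_tree(left, features, target, depth + 1, max_depth, min_samples_leaf),
        'right': build_tree(right, features, target, depth + 1, max_depth, min_samples_leaf),
    }

print(build_tree(df_example, ['assets', 'debt']))
# {'condition': 'assets > 3000', 'left': 'default',
#  'right': {'condition': 'debt > 1000', 'left': 'ok', 'right': 'default'}}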
For more information on decision trees, you can visit the scikit-learn website. In the ‘Classification criteria’ section, you’ll find various methods for measuring impurity, including ‘Misclassification,’ as well as others like ‘Entropy’ and ‘Gini.’ It’s important to note that decision trees can also be employed to address regression problems.
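As a quick illustration of those measures, here is a small sketch that computes all three for the class proportions of a single node (these are the standard textbook formulas):

import numpy as np

def impurities(class_proportions):
    # Misclassification rate, Gini impurity and entropy of one node
    p = np.asarray(class_proportions, dtype=float)
    misclassification = 1 - p.max()
    gini = 1 - np.sum(p ** 2)
    nonzero = p[p > 0]
    entropy = -np.sum(nonzero * np.log2(nonzero))
    return misclassification, gini, entropy

# Example: the RIGHT group of the "ASSETS > 3000" split (80% ok, 20% default)
print(impurities([0.8, 0.2]))  # (0.2, 0.32, ~0.72)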