Implementing Gradient Boosting Tree from scratch in Python
Introduction
This blog covers the theoretical framework behind Gradient Boosting Tree and its implementation in Python without using any libraries that directly implement it.
Motivation
See the previous blog on Random Forest.
What is Gradient Boosting Tree?
Similar to Random Forest, Gradient Boosting Tree (GBT) also consists of a collection of decision trees. However, unlike Random Forest, which trains those trees independently, GBT trains its trees sequentially: the next tree uses the residuals of the previous tree as its target variable instead of the original labels, even in classification tasks. This means that GBT treats classification tasks as regression problems, since residuals are used as the target variable.
When constructing the decision tree, the loss is calculated using residuals instead of labels. Below, we provide an example of how this works for classification tasks. For a visual demonstration of the algorithm, refer to StatQuest. Note that the video omits details on how the decision tree is constructed. In the example below, MSE loss is used.
To find the optimal split point, every value of every feature is considered as a candidate threshold. The loss of each candidate split is the weighted MSE of the resulting sub-groups, and the predicted value of each sub-group is the mean of the residuals in that group.
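As a rough sketch of this search (the helper names below are illustrative, not the post's actual code), the split with the lowest weighted MSE of the residuals can be found like this:

```python
import numpy as np

def weighted_mse(left_res, right_res):
    """Weighted MSE of a split: each group's MSE is measured against the
    mean residual of that group, then weighted by the group's size."""
    n = len(left_res) + len(right_res)
    loss = 0.0
    for res in (left_res, right_res):
        if len(res) > 0:
            loss += len(res) / n * np.mean((res - res.mean()) ** 2)
    return loss

def find_best_split(X, residuals):
    """Try every value of every feature as a threshold and keep the
    (feature, threshold) pair with the smallest weighted MSE."""
    best_feature, best_threshold, best_loss = None, None, np.inf
    for feature in range(X.shape[1]):
        for threshold in np.unique(X[:, feature]):
            mask = X[:, feature] <= threshold
            loss = weighted_mse(residuals[mask], residuals[~mask])
            if loss < best_loss:
                best_feature, best_threshold, best_loss = feature, threshold, loss
    return best_feature, best_threshold, best_loss
```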
Implementation
The `gbt` Function
The core of the program is the `gbt` function:
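A minimal sketch of what this driver might look like, assuming the `build_tree` function described in the next section, a `sigmoid` helper, and `tqdm` for the progress bar (the post only says a progress bar is updated):

```python
import numpy as np
from tqdm import tqdm  # progress bar (assumed library choice)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gbt(X, y, treeLength, max_depth, learning_rate):
    """Train `treeLength` trees sequentially, each fitted to the current residuals."""
    # Initial log-odds prediction (a common choice; the post does not spell it out).
    initial_prediction = np.log(y.mean() / (1 - y.mean()))
    predictions = np.full(len(y), initial_prediction)
    residuals = y - sigmoid(predictions)  # residual = label - predicted probability
    decision_trees = []
    for _ in tqdm(range(treeLength)):
        # build_tree (next section) updates `predictions` and `residuals`
        # in place at its leaf nodes.
        tree = build_tree(X, y, residuals, predictions,
                          np.arange(len(y)), 0, max_depth, learning_rate)
        decision_trees.append(tree)
    return decision_trees, initial_prediction
```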
This function calls the `build_tree` function a specified number of times (`treeLength`) and updates a progress bar accordingly. The trees returned from `build_tree` are stored in a list named `decision_trees`.
The `build_tree` Function
The steps to build a tree are as follows (a compact sketch of these steps is given after the list):
- Find optimal split using MSE loss (lines 47–55).
- Initialize the node and fill in the decision boundary (lines 58–62).
- Construct left and right sub-trees under certain conditions (lines 66–69).
- Find the output of the leaf node and update predictions and residuals (lines 72–78).
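A compact sketch of how these steps might fit together, reusing the `find_best_split` helper sketched earlier and a plain dictionary as the node representation (both are assumptions of this sketch, not the post's actual code):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def build_tree(X, y, residuals, predictions, idx, depth, max_depth, learning_rate):
    """Recursively build one regression tree on the residuals.

    `idx` holds the indices of the samples that reached this node;
    `predictions` (log-odds) and `residuals` are updated in place at the leaves.
    """
    node = {}
    res = residuals[idx]
    current_mse = np.mean((res - res.mean()) ** 2)

    # 1. Find the optimal split using the weighted MSE of the residuals.
    feature, threshold, split_mse = find_best_split(X[idx], res)

    # 2. Initialize the node and store the decision boundary.
    node["feature"], node["threshold"] = feature, threshold

    left_idx = idx[X[idx, feature] <= threshold]
    right_idx = idx[X[idx, feature] > threshold]

    # 3. Recurse only while both sub-groups are non-empty, the split lowers
    #    the MSE, and the maximum depth has not been reached.
    if (len(left_idx) > 0 and len(right_idx) > 0
            and split_mse < current_mse and depth < max_depth):
        node["left"] = build_tree(X, y, residuals, predictions, left_idx,
                                  depth + 1, max_depth, learning_rate)
        node["right"] = build_tree(X, y, residuals, predictions, right_idx,
                                   depth + 1, max_depth, learning_rate)
    else:
        # 4. Leaf node: compute the output, then update predictions and
        #    residuals for the samples that ended up in this leaf.
        p = sigmoid(predictions[idx])
        node["output"] = res.sum() / np.sum(p * (1 - p))
        predictions[idx] += learning_rate * node["output"]
        residuals[idx] = y[idx] - sigmoid(predictions[idx])
    return node
```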
Key Differences from Random Forest
- The loss function for finding the optimal split uses MSE since GBT treats classification as a regression problem.
- Updates at leaf nodes include recalculating predictions and residuals.
Loss Function
For GBT classification tasks, the Mean Squared Error (MSE) is used to find the optimal split:
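\[\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(r_i - \hat{r}\right)^2,\]
where \(r_1, \dots, r_n\) are the residuals in a sub-group and \(\hat{r}\) is that sub-group's predicted value (the mean residual); the loss of a split is the size-weighted average of the two sub-groups' MSEs.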
An example:
- Residuals: 0.1, 0, -0.1
- Predicted value: 0
- MSE: \(\frac{(0.1 - 0)^2 + (-0.1 - 0)^2 + (0 - 0)^2}{3} = \frac{1}{150}\)
Node Initialization and Split
After finding the optimal split point, the node is initialized, and the split point is stored.
Conditions for Extending Sub-Trees:
- Both left and right sub-groups must have positive lengths.
- The weighted MSE of the split must be less than the MSE of the current node.
Output Calculation for Leaf Nodes
At leaf nodes, the output is computed as:
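In the standard derivation for binary classification with log-loss, which this post appears to follow (it is the derivation the StatQuest video walks through), the leaf output is
\[\gamma = \frac{\sum_{i \in \text{leaf}} r_i}{\sum_{i \in \text{leaf}} p_i\,(1 - p_i)},\]
where \(r_i\) are the residuals and \(p_i\) the previously predicted probabilities of the samples in that leaf.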
Predictions are updated using:
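\[F_{\text{new}}(x_i) = F_{\text{old}}(x_i) + \eta \cdot \gamma,\]
where \(F\) is the prediction in log-odds space and \(\eta\) is the learning rate (again assuming the standard derivation).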
Residuals are updated using:
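\[r_i = y_i - \sigma\big(F_{\text{new}}(x_i)\big),\]
i.e., each residual becomes the difference between the sample's label and its newly predicted probability, where \(\sigma\) is the sigmoid function.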
Making Predictions
To make a prediction for \(x\), the `predict` function is given \(x\), all trained trees, the initial prediction, and the learning rate. The trees' outputs are summed, and the sum is passed to a sigmoid function to produce a probability; the class is predicted using a threshold of 0.5.
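A minimal sketch of such a `predict` function, assuming the dictionary-based trees and the initial log-odds returned by the `gbt` sketch above:

```python
import numpy as np

def tree_output(tree, x):
    """Walk one tree until a leaf is reached and return its output."""
    while "output" not in tree:
        if x[tree["feature"]] <= tree["threshold"]:
            tree = tree["left"]
        else:
            tree = tree["right"]
    return tree["output"]

def predict(x, decision_trees, initial_prediction, learning_rate):
    """Sum the scaled tree outputs in log-odds space, squash with a sigmoid,
    and threshold the probability at 0.5."""
    log_odds = initial_prediction + learning_rate * sum(
        tree_output(tree, x) for tree in decision_trees)
    probability = 1.0 / (1.0 + np.exp(-log_odds))
    return int(probability > 0.5)
```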
Cross-Validation
The `eval` function performs n-fold cross-validation on the dataset:
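A minimal sketch of n-fold cross-validation; `train_fn` and `predict_fn` stand in for the `gbt` and `predict` functions above (the post's function is simply called `eval`; it is renamed here to avoid shadowing Python's built-in):

```python
import numpy as np

def eval_cv(X, y, train_fn, predict_fn, n_folds=10):
    """Split the data into n folds, train on n-1 folds, validate on the
    held-out fold, and return the mean validation accuracy."""
    idx = np.random.permutation(len(y))
    folds = np.array_split(idx, n_folds)
    accuracies = []
    for i in range(n_folds):
        val_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(n_folds) if j != i])
        model = train_fn(X[train_idx], y[train_idx])
        preds = np.array([predict_fn(model, x) for x in X[val_idx]])
        accuracies.append(np.mean(preds == y[val_idx]))
    return np.mean(accuracies)
```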
Results
Using 10-fold cross-validation, an accuracy of 86% was achieved on the validation set.
Parameter Tuning
GBT requires setting three hyperparameters:
- Max Depth
- Number of Trees
- Learning Rate
To find the best combination, all combinations of candidate values are tested using the sklearn library:
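A sketch of what such a search might look like with sklearn's `GradientBoostingClassifier` and `GridSearchCV`; the grid values below are illustrative, not necessarily the ones used in the post:

```python
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

param_grid = {
    "n_estimators": [50, 100, 200],      # number of trees (illustrative grid)
    "max_depth": [3, 5, 7],              # maximum tree depth (illustrative grid)
    "learning_rate": [0.01, 0.1, 0.5],   # shrinkage (illustrative grid)
}

search = GridSearchCV(GradientBoostingClassifier(), param_grid,
                      cv=10, scoring="accuracy")
search.fit(X, y)  # X, y: the same dataset used in the experiments above
print(search.best_params_, search.best_score_)
```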
Results showed that 200 trees with a max depth of 5 and a learning rate of 0.1 achieved the best validation accuracy (85.94%), slightly lower than that of the custom implementation.
Full Implementation
For full implementation details, see the GitHub repository.