Cost complexity pruning for decision trees (CCP)
A decision tree grown with no pruning tends to overfit the training data. To reduce this we can "prune" the decision tree: we look at how useful each of the branches is, then remove the ones that don't add enough value. How aggressively we prune is controlled by a variable alpha >= 0.
Suppose we are in a modelling framework where T is a decision tree, R(T) is its evaluation cost (for example the total Entropy or Gini impurity over its leaves) and |leaves(T)| is its number of leaf vertices. Define the cost complexity

    R_alpha(T) = R(T) + alpha * |leaves(T)|

to be the value we want to minimise on our sub-trees.
Optimally we would iterate over all potential subtrees of T and keep the one minimising R_alpha, but the number of subtrees grows exponentially with the size of the tree, so instead we prune greedily.
One method to do this is to calculate the effective cost complexity of each non-terminal node. To give you the intuition behind this, let T_v denote the subtree rooted at a non-terminal vertex v.
To define the effective cost complexity we will find the value of alpha at which we are indifferent between keeping the subtree T_v and collapsing v into a single leaf:

    alpha_eff(v) = (R(v) - sum of R(l) over leaves l of T_v) / (|leaves(T_v)| - 1)

(It can be derived by equating the contribution of v as a single leaf, R(v) + alpha, with the contribution of the whole subtree, sum of R(l) + alpha * |leaves(T_v)|.)
The process then repeatedly prunes the branch with the lowest alpha_eff, stopping once every remaining alpha_eff is at least alpha.
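As a quick sanity check on the formula, here is a small sketch with made-up impurity numbers (the helper name `effective_alpha` is ours, introduced only for this illustration):

```python
# Toy illustration of the effective cost complexity formula above.
# The impurity values are invented numbers, not from a real tree.

def effective_alpha(r_v, leaf_impurities):
    """alpha at which collapsing v to a leaf matches keeping its subtree."""
    return (r_v - sum(leaf_impurities)) / (len(leaf_impurities) - 1)

# Suppose node v has impurity 0.30 and its subtree has three leaves
# with impurities 0.05, 0.10 and 0.05 (total 0.20).
alpha_v = effective_alpha(0.30, [0.05, 0.10, 0.05])
print(alpha_v)  # approximately 0.05
```

For any alpha below this value, R_alpha prefers keeping the subtree; above it, R_alpha prefers the collapsed leaf.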
Pseudocode
CCP(decision_tree, alpha, R):
Input:
decision_tree with a set of vertices V
alpha positive constant to determine pruning
R the evaluation function, such as Entropy or Gini; this will have to
be defined on vertices of V (computed on the training data that gets
classified to v)
Output:
a pruned decision tree
1. Set best_tree = decision_tree
2. For each non-leaf v set alpha_eff(v) = calculate_effective_alpha(decision_tree, v, R)
3. Set min_eff_alpha = min(alpha_eff(v) for v in V non-leaf)
4. While min_eff_alpha < alpha
4.1. Find v such that alpha_eff(v) = min_eff_alpha
4.2. Prune best_tree at v (replace the subtree rooted at v with a single leaf)
4.3. Let P be the set of non-leaf vertices with v downstream of them (the ancestors of v).
4.4. For each p in P set alpha_eff(p) = calculate_effective_alpha(best_tree, p, R)
4.5. Break if best_tree only has 1 vertex.
4.6. Set min_eff_alpha = min(alpha_eff(v) for v in best_tree non-leaf)
5. return best_tree.
calculate_effective_alpha(tree, v, R):
Input
tree is a decision tree
v is a vertex in V that is not a leaf node
R the evaluation function, such as Entropy or Gini; this will have to
be defined on vertices of V (computed on the training data that gets
classified to v)
Output
alpha_eff the effective alpha for that vertex
1. Let L be the set of leaf vertices with v upstream of them in tree (the leaves of the subtree rooted at v).
2. Set total_leaf_weight = sum(R(x) for x in L)
3. Return (R(v) - total_leaf_weight)/(|L| - 1)
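A minimal runnable sketch of the pseudocode above. The tree representation (nested dicts with an impurity value "R" and a list of "children") is our own choice for the example, not something the pseudocode prescribes; for clarity it also recomputes every alpha_eff on each iteration rather than only updating the ancestors of the pruned vertex as in steps 4.3-4.4:

```python
# A runnable sketch of CCP / calculate_effective_alpha on a toy tree.

def leaves(tree):
    """All leaf nodes in the subtree rooted at `tree`."""
    if not tree["children"]:
        return [tree]
    return [l for c in tree["children"] for l in leaves(c)]

def internal_nodes(tree):
    """All non-leaf nodes in the subtree rooted at `tree`."""
    if not tree["children"]:
        return []
    return [tree] + [n for c in tree["children"] for n in internal_nodes(c)]

def calculate_effective_alpha(v):
    """alpha_eff(v) = (R(v) - sum of leaf impurities) / (|L| - 1)."""
    L = leaves(v)
    return (v["R"] - sum(l["R"] for l in L)) / (len(L) - 1)

def ccp(tree, alpha):
    """Weakest-link pruning: repeatedly collapse the internal node with
    the smallest effective alpha while that value stays below `alpha`."""
    while True:
        candidates = internal_nodes(tree)
        if not candidates:            # tree is a single vertex
            break
        weakest = min(candidates, key=calculate_effective_alpha)
        if calculate_effective_alpha(weakest) >= alpha:
            break
        weakest["children"] = []      # prune: turn the node into a leaf
    return tree

# A tiny hand-made tree: the root splits into one fairly pure leaf and one
# noisy branch whose split barely reduces impurity.
tree = {"R": 0.5, "children": [
    {"R": 0.05, "children": []},
    {"R": 0.30, "children": [         # alpha_eff = (0.30 - 0.28) / 1 = 0.02
        {"R": 0.14, "children": []},
        {"R": 0.14, "children": []},
    ]},
]}
pruned = ccp(tree, alpha=0.1)
print(len(leaves(pruned)))  # 2: the noisy split has been collapsed
```

With alpha = 0.1, the noisy split (alpha_eff = 0.02) is pruned first; pruning the root would cost alpha_eff = 0.15 >= 0.1, so the loop stops there.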
Run time
Optimally we would iterate over all potential subtrees of T, but their number is exponential in |V|, so exhaustive search is infeasible. The greedy procedure above performs at most |V| prunings; each pruning recomputes alpha_eff only for the ancestors of the pruned vertex, and each recomputation sums over the leaves below that ancestor, so a straightforward implementation runs in polynomial time (at worst O(|V|^3), and caching the leaf counts and impurity sums at each vertex reduces this further).
Correctness
This can reduce overfitting, however the parameter alpha still has to be chosen: too small and the tree remains overfit, too large and the tree is pruned back so far that it underfits. In practice alpha is typically selected by cross-validation.
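scikit-learn ships this algorithm under the name "minimal cost-complexity pruning": `DecisionTreeClassifier` takes a `ccp_alpha` parameter, and `cost_complexity_pruning_path` returns the sequence of effective alphas at which successive prunings occur. A sketch of selecting alpha by cross-validation, assuming scikit-learn is installed (the iris dataset is just a stand-in):

```python
# Selecting ccp_alpha by cross-validation with scikit-learn.
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# The candidate alphas are exactly the effective alphas at which
# the weakest-link pruning sequence changes.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)

# Score each candidate alpha with 5-fold cross-validation.
scores = {
    a: cross_val_score(
        DecisionTreeClassifier(random_state=0, ccp_alpha=a), X, y, cv=5
    ).mean()
    for a in path.ccp_alphas
}
best_alpha = max(scores, key=scores.get)
clf = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(X, y)
```

Only the alphas in `path.ccp_alphas` need to be tried: between consecutive values the pruned tree does not change, so scanning a finer grid gains nothing.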