datarekha
Machine Learning | Medium | Asked at Amazon, Microsoft, Apple

Walk me through exactly how a decision tree chooses a split at each node.

The short answer

At each node the algorithm iterates over every feature and every candidate threshold, scores each candidate split by the weighted impurity of the two child nodes, and selects the pair that gives the largest impurity reduction. It then recurses on each child until a stopping criterion is met.

How to think about it

Step-by-step split selection

  1. Candidate thresholds: for a continuous feature with N distinct values, sort them and consider the N-1 midpoints between adjacent values as candidate thresholds. For a categorical feature that has been one-hot encoded, each category yields one binary split ("is this category" vs. "is not").

  2. Score each candidate: compute the weighted average impurity of the two child nodes and subtract it from the parent's impurity to get the gain. Here n_L and n_R are the sample counts in the left and right child, and n = n_L + n_R:

Gain(split) = Impurity(parent) - [n_L/n * Impurity(L) + n_R/n * Impurity(R)]

  3. Select the best: keep the (feature, threshold) pair with the maximum gain.

  4. Recurse: apply the same procedure to the left and right children until a stopping rule fires (max depth reached, node too small, or gain below a minimum).
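The steps above can be sketched in plain Python. This is a minimal illustration, not how any particular library implements it: the helper names `gini` and `best_split` are hypothetical, Gini is assumed as the impurity measure, and every threshold is scored from scratch for clarity rather than speed.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_k^2)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y):
    """Return (feature_index, threshold, gain) with the largest impurity reduction.

    X is a list of rows (lists of numbers), y a list of class labels.
    Returns (None, None, 0.0) if no split improves on the parent.
    """
    n = len(y)
    parent = gini(y)
    best = (None, None, 0.0)
    for j in range(len(X[0])):                      # step 1: every feature
        values = sorted(set(row[j] for row in X))   # distinct values, sorted
        for lo, hi in zip(values, values[1:]):      # the N-1 midpoints
            t = (lo + hi) / 2
            left = [y[i] for i in range(n) if X[i][j] <= t]
            right = [y[i] for i in range(n) if X[i][j] > t]
            # step 2: weighted child impurity, then gain
            weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
            gain = parent - weighted
            if gain > best[2]:                      # step 3: keep the best pair
                best = (j, t, gain)
    return best


X = [[2.5], [1.0], [3.5], [0.5], [4.0]]
y = [0, 0, 1, 0, 1]
feature, threshold, gain = best_split(X, y)
print(feature, threshold, round(gain, 4))  # prints: 0 3.0 0.48
```

On this toy data the parent impurity is 1 - (0.6² + 0.4²) = 0.48, and the midpoint 3.0 separates the classes perfectly, so the whole 0.48 is recovered as gain. Step 4 (recursion) would call `best_split` again on the rows that fall on each side.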

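from sklearn.tree import DecisionTreeClassifier, export_text
import numpy as np

# Toy data: one feature; only the two largest values carry label 1.
X = np.array([[2.5], [1.0], [3.5], [0.5], [4.0]])
y = np.array([0, 0, 1, 0, 1])

# Fit a shallow tree that scores candidate splits with Gini impurity.
tree = DecisionTreeClassifier(max_depth=2, criterion="gini")
tree.fit(X, y)

# Print the learned split rules as indented text.
print(export_text(tree, feature_names=["x"]))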
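To check which split the fitted tree actually chose, one option is to read the low-level `tree_` arrays that scikit-learn exposes on a fitted estimator. The sketch below refits the same toy data (with `random_state=0` added as an assumption, for reproducibility) and recomputes the root gain from the gain formula above; the variable names are illustrative.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[2.5], [1.0], [3.5], [0.5], [4.0]])
y = np.array([0, 0, 1, 0, 1])

tree = DecisionTreeClassifier(max_depth=2, criterion="gini", random_state=0)
tree.fit(X, y)

t = tree.tree_                    # array-based representation of the fitted tree
root_feature = t.feature[0]       # feature index tested at the root (node 0)
root_threshold = t.threshold[0]   # samples with x <= threshold go to the left child
root_impurity = t.impurity[0]     # Gini impurity of the root before splitting
left, right = t.children_left[0], t.children_right[0]

# Weighted child impurity, then the gain from the formula above.
n = t.n_node_samples
weighted = (n[left] * t.impurity[left] + n[right] * t.impurity[right]) / n[0]
root_gain = root_impurity - weighted

print(root_feature, root_threshold, root_gain)
```

For this data the root should split on feature 0 at the midpoint between 2.5 and 3.5, leaving two pure children, so the tree stops at depth 1 even though `max_depth=2` allows more.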

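Computational cost: scoring every threshold from scratch costs O(n² · d) per level for n samples and d features, because each of the up to n-1 thresholds per feature needs an O(n) impurity computation. Sorting each feature's values and scanning the thresholds in order while updating the class counts incrementally reduces this to O(n · d · log n) per level, which is why sklearn's exact splitter works on sorted feature values.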
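The sort-then-scan idea can be illustrated for a single feature as follows. This is a hedged sketch, not scikit-learn's actual code: `gini_from_counts` and `best_threshold_sorted_scan` are hypothetical names, and Gini impurity is assumed. The key point is that moving one sample across the boundary only updates two counters instead of rebuilding both children.

```python
from collections import Counter


def gini_from_counts(counts, total):
    """Gini impurity from a class -> count mapping."""
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts.values())


def best_threshold_sorted_scan(x, y):
    """Best (threshold, gain) for one feature: one sort plus one linear pass."""
    n = len(y)
    order = sorted(range(n), key=lambda i: x[i])   # O(n log n) sort
    left = Counter()                               # class counts left of the boundary
    right = Counter(y)                             # class counts right of the boundary
    parent = gini_from_counts(right, n)
    best_t, best_gain = None, 0.0
    for pos in range(n - 1):
        i, nxt = order[pos], order[pos + 1]
        left[y[i]] += 1                            # move one sample from right to left
        right[y[i]] -= 1
        if x[i] == x[nxt]:                         # no threshold between equal values
            continue
        n_left, n_right = pos + 1, n - pos - 1
        weighted = (n_left / n) * gini_from_counts(left, n_left) \
            + (n_right / n) * gini_from_counts(right, n_right)
        gain = parent - weighted
        if gain > best_gain:
            best_t, best_gain = (x[i] + x[nxt]) / 2, gain
    return best_t, best_gain


best_t, best_gain = best_threshold_sorted_scan(
    [2.5, 1.0, 3.5, 0.5, 4.0], [0, 0, 1, 0, 1]
)
print(best_t, round(best_gain, 4))  # prints: 3.0 0.48
```

Each impurity evaluation here costs time proportional to the number of classes rather than the number of samples, so the sort dominates the per-feature cost.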

Key insight: the split is always axis-aligned (one feature at a time). This makes trees easy to interpret, but a diagonal decision boundary can only be approximated by a staircase of many axis-aligned splits, which takes many levels.
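A small experiment can make the axis-aligned limitation concrete. The sketch below assumes synthetic data whose true boundary is the diagonal x0 = x1, and compares a tree grown on the raw features with one grown on a hypothetical engineered feature x0 - x1; the seed and sample size are arbitrary choices.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(200, 2))
y = (X[:, 0] > X[:, 1]).astype(int)        # diagonal boundary: x0 = x1

# Raw features: each split can test x0 or x1 alone, never their difference.
raw = DecisionTreeClassifier(random_state=0).fit(X, y)

# Engineered feature x0 - x1 turns the diagonal into a single threshold at 0.
diff = (X[:, 0] - X[:, 1]).reshape(-1, 1)
engineered = DecisionTreeClassifier(random_state=0).fit(diff, y)

depth_raw = raw.get_depth()
depth_engineered = engineered.get_depth()
print(depth_raw, depth_engineered)
```

The tree on the engineered feature needs a single split, while the tree on the raw features has to grow deeper to trace the diagonal with a staircase of thresholds.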
