datarekha
Machine Learning · Medium · Asked at Google · Asked at Amazon

What is k-means++ and why is it better than random initialisation?

The short answer

K-means++ initialises centroids by spreading them apart probabilistically. The first centroid is chosen uniformly at random from the data. Each later centroid is drawn from the data with probability proportional to its squared distance from the nearest already-chosen centroid. This makes bad starts less likely and usually reduces the number of iterations needed to converge. It also gives an O(log k) approximation guarantee on the expected inertia.

How to think about it

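K-means only ever descends to a local minimum of inertia, so the starting centroids largely decide which minimum it reaches, and whether that minimum is good or poor. K-means++ replaces uniform seeding with a simple probabilistic rule that comes with a provable bound on the expected result.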

The problem with random initialisation

If you choose k centroids uniformly at random from the data, two of them may land in the same true cluster and leave another cluster with no centroid. Recovering from that takes many extra iterations, or k-means simply converges to the wrong partition. Running several restarts and keeping the lowest-inertia result partly mitigates this. However, each restart is a full k-means run costing O(nkd) per iteration (n points, k centroids, d features).
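Here is a quick way to quantify the risk under an idealised assumption: k well-separated clusters of equal size, with initial centroids drawn uniformly. The chance that every cluster gets exactly one centroid is then about k!/k^k. The sketch below computes that value and checks it with a small Monte Carlo simulation. The function names and the `labels` array of true cluster memberships are hypothetical, for illustration only.

```python
import math

import numpy as np


def p_all_clusters_covered(k: int) -> float:
    """Probability that k independent uniform picks hit k equal-size clusters once each."""
    return math.factorial(k) / k**k


def simulate_coverage(k: int, points_per_cluster: int, trials: int, seed: int = 0) -> float:
    """Monte Carlo estimate: draw k distinct points uniformly and check cluster coverage."""
    rng = np.random.default_rng(seed)
    # Hypothetical ground-truth labels: cluster j owns points_per_cluster consecutive points
    labels = np.repeat(np.arange(k), points_per_cluster)
    hits = 0
    for _ in range(trials):
        idx = rng.choice(labels.size, size=k, replace=False)
        # Success only if the k sampled points come from k different clusters
        hits += np.unique(labels[idx]).size == k
    return hits / trials


for k in (2, 5, 10):
    print(f"k={k:2d}  exact={p_all_clusters_covered(k):.4f}  "
          f"simulated={simulate_coverage(k, 200, 10_000):.4f}")
```

Under this assumption, the exact values are 0.5 for k = 2, 0.0384 for k = 5 and about 0.00036 for k = 10. So with k = 10, a single random start almost never puts one centroid in each cluster. The simulation draws without replacement from a finite dataset, so its estimates sit slightly above k!/k^k.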

K-means++ initialisation

  1. Choose the first centroid c_1 uniformly at random from the data.
  2. For each remaining centroid c_i (i = 2, …, k):
    • For each point x, compute D(x), the distance from x to the nearest already-chosen centroid.
    • Pick the next centroid from the data points, choosing x with probability D(x)² / Σ D(x')².
  3. Proceed with standard k-means assignment and update steps.
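The steps above can be sketched in NumPy as follows. This is a minimal illustration, not scikit-learn's implementation. It covers steps 1 and 2 only, uses Euclidean distance, and assumes X has at least k distinct points. The function name `kmeans_pp_init` and the toy data are hypothetical.

```python
import numpy as np


def kmeans_pp_init(X: np.ndarray, k: int, seed: int = 0) -> np.ndarray:
    """Return k initial centroids chosen by k-means++ D(x)^2 sampling (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    # Step 1: first centroid uniformly at random from the data
    centroids = [X[rng.integers(n)]]
    # Squared distance from each point to its nearest chosen centroid, i.e. D(x)^2
    d2 = np.sum((X - centroids[0]) ** 2, axis=1)
    for _ in range(1, k):
        # Step 2: sample the next centroid with probability D(x)^2 / sum of D(x')^2
        # (assumes d2.sum() > 0, i.e. X has at least k distinct points)
        probs = d2 / d2.sum()
        idx = rng.choice(n, p=probs)
        centroids.append(X[idx])
        # Refresh D(x)^2 so it reflects the newly added centroid
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    return np.array(centroids)


# Hypothetical data: three well-separated 2-D clusters of 100 points each
rng = np.random.default_rng(42)
true_centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
X = np.vstack([c + rng.normal(size=(100, 2)) for c in true_centers])

init = kmeans_pp_init(X, k=3, seed=1)
print(init)
```

For step 3, you can pass the returned array to scikit-learn's `KMeans` as `init=init` with `n_init=1`. `KMeans` then runs the usual assignment and update iterations starting from these seeds.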

Weighting by D(x)² makes points that are far from every current centroid much more likely to be picked next. Points already chosen have D(x) = 0, so they cannot be picked again. In practice this tends to spread the initial centroids across distinct clusters.
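A tiny worked example shows how strongly squaring favours distant points. The 1-D data and the choice of first centroid are assumptions for illustration. The sketch compares D(x)² weighting with weighting by plain D(x).

```python
import numpy as np

# Toy 1-D data (hypothetical): two tight clusters, {0, 1} and {10, 11}
x = np.array([0.0, 1.0, 10.0, 11.0])
first_centroid = 0.0  # suppose step 1 happened to pick the point at 0

D = np.abs(x - first_centroid)   # distance to the nearest chosen centroid
p_d2 = D**2 / np.sum(D**2)       # k-means++ weighting: [0, 1, 100, 121] / 222
p_d = D / np.sum(D)              # plain-distance weighting, for comparison

far = x >= 10.0
print("P(next centroid in far cluster), D^2 weighting:", p_d2[far].sum())  # 221/222, about 0.9955
print("P(next centroid in far cluster), D weighting:  ", p_d[far].sum())   # 21/22, about 0.9545
```

Squaring shrinks the chance of a second centroid inside the already-covered cluster from 1/22 to 1/222 in this example.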

Why the guarantee matters

Arthur and Vassilvitskii (2007) proved that k-means++ seeding yields an expected inertia within an O(log k) factor of the optimal inertia. Random initialisation offers no such bound and can be arbitrarily worse than optimal. In practice, k-means++ also tends to converge in fewer iterations, so the extra O(nkd) cost of seeding is quickly recovered.
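Stated precisely, this is the main theorem of Arthur and Vassilvitskii's paper. Inertia φ is the sum of squared distances from each point to its nearest centroid, and φ_OPT is the optimal inertia for k clusters:

```latex
\phi(C) = \sum_{x \in X} \min_{c \in C} \lVert x - c \rVert^2,
\qquad
\mathbb{E}\left[\phi(C_{\text{k-means++}})\right] \le 8(\ln k + 2)\,\phi_{\text{OPT}}
```

The bound applies to the seeding alone. Lloyd's assignment and update steps never increase φ, so the bound also holds for the final result. For k = 5 the constant is 8(ln 5 + 2) ≈ 28.9. It is therefore a worst-case safety net rather than a prediction of typical quality.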

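```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data for illustration: 1,000 points drawn from 5 Gaussian blobs
X, _ = make_blobs(n_samples=1000, centers=5, random_state=0)

# init="k-means++" is scikit-learn's default; n_init is set explicitly
# because its default changed to "auto" in scikit-learn 1.4
km = KMeans(n_clusters=5, init="k-means++", n_init=10, random_state=42)
km.fit(X)

# Compare to random initialisation to see how much k-means++ helps on your data
km_random = KMeans(n_clusters=5, init="random", n_init=10, random_state=42)
km_random.fit(X)

print(f"k-means++ inertia: {km.inertia_:.1f} ({km.n_iter_} iterations)")
print(f"random init inertia: {km_random.inertia_:.1f} ({km_random.n_iter_} iterations)")
```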
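scikit-learn 0.24 and later also exposes the seeding step on its own as `sklearn.cluster.kmeans_plusplus`. This is handy for inspecting which points were picked. scikit-learn implements the greedy variant of k-means++. At each step it samples several candidates (`n_local_trials`, about 2 + ln k by default) and keeps the one that reduces inertia most. Its seeds are therefore not identical to the single-draw procedure described earlier. The data below is synthetic and for illustration only.

```python
import numpy as np
from sklearn.cluster import kmeans_plusplus
from sklearn.datasets import make_blobs

# Synthetic data for illustration
X, _ = make_blobs(n_samples=500, centers=4, random_state=0)

# Returns the chosen seeds and the row indices in X they were taken from
centers, indices = kmeans_plusplus(X, n_clusters=4, random_state=0)

print(indices)
print(np.allclose(centers, X[indices]))  # the seeds are actual data points
```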
