datarekha
Data Engineering · Hard · Asked at Databricks, Amazon, Meta, Uber, Airbnb, LinkedIn

What is data skew in Spark and how do you fix it with salting?

The short answer

Data skew occurs when a few keys hold a disproportionate share of the rows, so a few tasks process gigabytes while the rest finish in seconds, and one slow task stalls the entire stage. Salting appends a random suffix to skewed keys before a join or aggregation, which spreads each hot key across multiple partitions; the salt is removed once the join or aggregation is done.

How to think about it

Data skew is one of the most common reasons a Spark job appears to hang. A stage with 200 tasks can sit at “199/200” for 30 minutes because one partition holds 80% of the data, and the stage cannot finish until that last task does.

Diagnosing skew

In the Spark UI, open the stage details and look at the task duration distribution. A skewed stage shows one or a few tasks with dramatically longer durations and much larger input sizes than the median.
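
One rough way to put a number on what the UI shows is to compare the largest task against the median. The sketch below is plain Python, not a Spark API: skew_ratio is a hypothetical helper, and the sizes are made-up values standing in for per-task input sizes read off the stage page.

```python
from statistics import median


def skew_ratio(task_sizes):
    """Return the largest task size divided by the median task size."""
    mid = median(task_sizes)
    if mid == 0:
        raise ValueError("median task size is zero")
    return max(task_sizes) / mid


# Hypothetical per-task input sizes in MB: 199 even tasks and one straggler.
task_sizes_mb = [64] * 199 + [8192]
ratio = skew_ratio(task_sizes_mb)
print(f"largest task is {ratio:.0f}x the median")  # prints: largest task is 128x the median
```

A ratio close to 1 means the work is evenly spread. How large a ratio is worth fixing is a judgment call that depends on the job.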

How salting fixes a skewed join

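Salting artificially spreads a hot key across multiple partitions by appending a random integer suffix on the skewed side. The other side is replicated once per salt value so that every salted row still finds its match. After the join, you drop the salt columns.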

from pyspark.sql import functions as F

SALT_FACTOR = 50  # tune based on skew severity

# 1. Salt the large (skewed) table
large_df = large_df.withColumn(
    "salt", (F.rand() * SALT_FACTOR).cast("int")
).withColumn(
    "salted_key", F.concat(F.col("join_key"), F.lit("_"), F.col("salt"))
)

# 2. Replicate every row of the small table once per salt value
small_df = small_df.withColumn(
    "salt", F.explode(F.array([F.lit(i) for i in range(SALT_FACTOR)]))
).withColumn(
    "salted_key", F.concat(F.col("join_key"), F.lit("_"), F.col("salt"))
).drop("join_key", "salt")  # avoid duplicate join_key/salt columns after the join

# 3. Join on the salted key: the hot key is now spread over up to SALT_FACTOR partitions
joined = large_df.join(small_df, "salted_key")

# 4. Drop the helper columns
result = joined.drop("salt", "salted_key")

The small table grows by a factor of SALT_FACTOR, so salting only makes sense when the expanded small side is still cheap to shuffle. If the small side is small enough to fit in executor memory, a broadcast join is simpler because it avoids the shuffle altogether.
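
To see why replicating the small side keeps the join correct, here is a minimal pure-Python sketch of the same four steps. It uses no Spark at all; the data, the variable names, and the tiny SALT_FACTOR are illustrative assumptions.

```python
import random
from collections import Counter

SALT_FACTOR = 4  # small value chosen for illustration only
rng = random.Random(0)  # seeded so the run is repeatable

# Hypothetical data: the key "hot" dominates the large side.
large = [("hot", i) for i in range(1000)] + [("cold", i) for i in range(10)]
small = {"hot": "H", "cold": "C"}

# Reference result: a plain join on the original key.
plain = sorted((k, v, small[k]) for k, v in large if k in small)

# 1. Salt the large side with a random integer in [0, SALT_FACTOR).
salted_large = [((k, rng.randrange(SALT_FACTOR)), v) for k, v in large]

# 2. Replicate each small-side row once per salt value.
salted_small = {
    (k, s): attr for k, attr in small.items() for s in range(SALT_FACTOR)
}

# 3. Join on (key, salt), then 4. drop the salt from the output rows.
salted = sorted((k, v, salted_small[(k, s)]) for (k, s), v in salted_large)

# Rows per salted key: the hot key is now split over several join keys.
rows_per_salted_key = Counter(ks for ks, _ in salted_large)

print(salted == plain)    # the salted join returns the same rows
print(len(salted_small))  # small side grew to len(small) * SALT_FACTOR entries
```

Each large-side row carries exactly one salt, and the small side holds a copy for every salt, so every row matches exactly once. The price is the SALT_FACTOR-fold growth of the small side.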

Salting for aggregations

from pyspark.sql import functions as F

# Phase 1: partial aggregation on (key, salt)
df_salted = df.withColumn("salt", (F.rand() * 50).cast("int"))

partial = df_salted.groupBy("key", "salt").agg(F.sum("value").alias("partial_sum"))

# Phase 2: final aggregation, dropping the salt and summing the partial sums.
# Grouping on the original key avoids splitting a string on "_",
# which would break for keys that contain an underscore.
result = partial.groupBy("key").agg(F.sum("partial_sum").alias("total"))
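
The two-phase pattern can be checked without a cluster. The following is a minimal pure-Python sketch with made-up rows and hypothetical names; it groups on (key, salt) pairs, so a key that contains an underscore, such as a_b, comes through intact.

```python
import random
from collections import defaultdict

SALT_FACTOR = 8
rng = random.Random(42)  # seeded so the run is repeatable

# Hypothetical rows: "hot" dominates; "a_b" contains an underscore.
rows = [("hot", 1)] * 10_000 + [("a_b", 5), ("a_b", 7), ("cold", 3)]

# Direct aggregation: all rows for a key land in one group.
direct = defaultdict(int)
for key, value in rows:
    direct[key] += value

# Phase 1: partial sums per (key, salt); the hot key is split into
# up to SALT_FACTOR smaller groups.
partial = defaultdict(int)
for key, value in rows:
    partial[(key, rng.randrange(SALT_FACTOR))] += value

# Phase 2: drop the salt and sum the partial sums per original key.
final = defaultdict(int)
for (key, _salt), partial_sum in partial.items():
    final[key] += partial_sum

print(dict(final) == dict(direct))  # both approaches agree
```

This works because sum can be combined from partial results. The same holds for count, min, and max. An average cannot be averaged again, so it needs a partial sum and a partial count that are divided only in the final phase.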

AQE skew join optimization

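In Spark 3.0 and later, Adaptive Query Execution (AQE) can detect and split skewed partitions automatically at runtime. It requires both spark.sql.adaptive.enabled = true and spark.sql.adaptive.skewJoin.enabled = true; AQE itself is on by default from Spark 3.2. It is the first thing to try. Manual salting is still needed when AQE does not detect the skew, and for aggregations, which the skew join optimization does not cover.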
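
One reason AQE can miss a skewed partition is its detection rule. As described in the Spark configuration docs, a partition counts as skewed only if it is larger than spark.sql.adaptive.skewJoin.skewedPartitionFactor times the median partition size and also larger than spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes. The sketch below restates that rule in plain Python. looks_skewed is a hypothetical helper, and the defaults of 5 and 256 MB are assumptions taken from the documentation, so check them against your Spark version.

```python
MB = 1024 * 1024


def looks_skewed(partition_bytes, median_bytes,
                 factor=5.0, threshold_bytes=256 * MB):
    """Restate the documented AQE skew test; not Spark's actual code."""
    return (partition_bytes > factor * median_bytes
            and partition_bytes > threshold_bytes)


# 2 GB partition against a 64 MB median: passes both conditions.
print(looks_skewed(2048 * MB, 64 * MB))  # True

# 200 MB partition against a 10 MB median: 20x the median, but below
# the byte threshold, so it would not be flagged.
print(looks_skewed(200 * MB, 10 * MB))   # False
```

The second case is the kind of skew that may call for lowering the threshold or for manual salting.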
