Understanding Spark's joins
May 01, 2020
The reason for this topic is the following code:
df = df.repartition(20000)
another_df = another_df.repartition(20000)
df = (
    df
    .join(another_df, "key", "inner")
)
It caused an out-of-memory exception and encouraged me to dig deeper into what happened.
To make my points clear, here are the details of my setup.
Spark configuration relevant to my case:
spark.sql.autoBroadcastJoinThreshold -1 // Prevent automatic broadcast
spark.sql.shuffle.partitions 200 // Default partitions during shuffle.
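For completeness, a minimal sketch of how such a session could be configured (the application name is a placeholder, not part of my original setup):
from pyspark.sql import SparkSession

# Hypothetical session setup mirroring the configuration above.
spark = (
    SparkSession.builder
    .appName("spark-joins-exploration")  # placeholder name
    .config("spark.sql.autoBroadcastJoinThreshold", -1)  # prevent automatic broadcast
    .config("spark.sql.shuffle.partitions", 200)  # default partitions during shuffle
    .getOrCreate()
)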
Three datasets:
large_df.columns
Out: ['category', 'no', 'value']
large_df = large_df.repartition(10, "no")
large_df.rdd.getNumPartitions()
Out: 10
another_large_df.columns
Out: ['category_2', 'no', 'value_2']
another_large_df = another_large_df.repartition(10, "category_2")
another_large_df.rdd.getNumPartitions()
Out: 10
mapping.columns
Out: ['category', 'category_name']
mapping = mapping.coalesce(1)
mapping.rdd.getNumPartitions()
Out: 1
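The snippets above assume the dataframes already exist. A minimal sketch of toy data with the same schemas, purely for illustration (the values are made up):
# Hypothetical toy data matching the columns shown above.
large_df = spark.createDataFrame(
    [("a", i % 100, float(i)) for i in range(10000)],
    ["category", "no", "value"],
)
another_large_df = spark.createDataFrame(
    [("b", i % 100, float(i) * 2) for i in range(10000)],
    ["category_2", "no", "value_2"],
)
mapping = spark.createDataFrame(
    [("a", "Category A"), ("b", "Category B")],
    ["category", "category_name"],
)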
Initial observation
join_different_partitions_keys = large_df.join(another_large_df, "no", "inner")
join_different_partitions_keys.rdd.getNumPartitions()
Out: 200
The original dataframes large_df and another_large_df had 10 partitions each, but the output has 200. The reason is shuffling. Both dataframes have 10 partitions, but Spark does not know what data sits inside each partition of another_large_df in terms of no; more importantly, it does not know which no values are in which partition. To avoid comparing every pair of partitions, Spark performs a shuffle, which moves rows between partitions so that all rows with the same no end up in the same, known partition on both sides. It shuffles the dataframes to 200 partitions instead of 10; the 200 comes from spark.sql.shuffle.partitions.
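The shuffle can also be seen in the physical plan: Exchange nodes with hashpartitioning(no, 200) feed the join. The exact operator names vary between Spark versions, so the expected output below is only indicative:
# Inspect the physical plan; the Exchange nodes mark the shuffle.
join_different_partitions_keys.explain()
# Roughly expected (details differ per Spark version):
# SortMergeJoin [no#...], [no#...], Inner
# +- Exchange hashpartitioning(no#..., 200)
# ...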
In my original situation, trying to fit 20000 partitions into 200 resulted in an out-of-memory error on an executor. I realized it when I noticed 200 tasks in the Spark UI while the join was executing.
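In hindsight, one possible workaround (a sketch, not what I actually ran back then) would have been to skip the upfront repartition and raise spark.sql.shuffle.partitions instead, so the join's own shuffle produces the desired parallelism:
# Let the join's shuffle create 20000 partitions directly instead of
# repartitioning both inputs first and then squeezing them into 200.
spark.conf.set("spark.sql.shuffle.partitions", 20000)
df = df.join(another_df, "key", "inner")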
More observations
The same observation holds when we have a longer partition key:
another_large_df = another_large_df.repartition(10, "no", "category_2")
join_longer_partitions_keys = large_df.join(another_large_df, "no", "inner")
join_longer_partitions_keys.rdd.getNumPartitions()
Out: 200
A similar observation holds in case of a mismatch between the partition key and the join key:
another_large_df = another_large_df.withColumnRenamed("category_2", "category").repartition(10, "no")
join_longer_join_key = large_df.join(another_large_df, ["no", "category"], "inner")
join_longer_join_key.rdd.getNumPartitions()
Out: 200
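A sketch of how this case could avoid the shuffle: partition both sides on the full join key, so the partitioning and the join key match exactly (the variable names below are mine):
# Hypothetical variant: hash-partition both sides on both join columns.
large_by_both = large_df.repartition(10, "no", "category")
another_by_both = another_large_df.repartition(10, "no", "category")
matched_keys_join = large_by_both.join(another_by_both, ["no", "category"], "inner")
matched_keys_join.rdd.getNumPartitions()  # should stay at 10, no extra shuffle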
In case of matching keys we receive the expected result:
another_large_df = another_large_df.withColumnRenamed("category", "category_2").repartition(10, "no")
optimal_join = large_df.join(another_large_df, "no", "inner")
optimal_join.rdd.getNumPartitions()
Out: 10
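One way to convince yourself why this works is to check where each no lands after repartition(10, "no"): every no value should live in exactly one partition, so the join can proceed partition by partition. A sketch using spark_partition_id:
from pyspark.sql import functions as F

# Count in how many distinct partitions each "no" appears.
(
    large_df
    .withColumn("partition_id", F.spark_partition_id())
    .groupBy("no")
    .agg(F.countDistinct("partition_id").alias("distinct_partitions"))
    .agg(F.max("distinct_partitions"))
    .show()  # expected maximum: 1
)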
Broadcast observations
We experience a similar issue with small data:
not_broadcast_join = large_df.join(mapping, "category", "inner")
not_broadcast_join.rdd.getNumPartitions()
Out: 200
Broadcast solves the issue:
broadcast_join = large_df.join(F.broadcast(mapping), "category", "inner")
broadcast_join.rdd.getNumPartitions()
Out: 10
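The plan confirms what happened: with the hint the join becomes a broadcast hash join, only the small mapping dataframe is shipped to the executors, and the large side keeps its 10 partitions (operator names again depend on the Spark version):
# The broadcast hint should turn the plan into a BroadcastHashJoin.
broadcast_join.explain()
# Roughly expected:
# BroadcastHashJoin [category#...], [category#...], Inner, BuildRight
# +- BroadcastExchange ...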
Conclusion
Here are some things to remember:
- use broadcast when joining with small dataframes; Spark does it automatically if the dataframe is below spark.sql.autoBroadcastJoinThreshold (see the sketch after this list)
- know and consciously set spark.sql.shuffle.partitions
- keep the same partition and join keys
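A sketch of the automatic broadcast mentioned in the first point: re-enable the threshold with a positive byte size and Spark should broadcast sufficiently small tables without an explicit hint (10 MB here is just an example value):
# Re-enable automatic broadcast for tables smaller than ~10 MB.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024)
auto_broadcast_join = large_df.join(mapping, "category", "inner")  # no explicit hint needed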
The entire code can be found here