Spark Caching and Persistence
Caching tells Spark to store a DataFrame or RDD in memory after computing it for the first time. The next time you use that data, Spark reads from memory instead of recomputing everything from scratch. This dramatically speeds up workloads that use the same dataset multiple times.
Why Caching Matters — The Repeated Computation Problem
```
WITHOUT caching:
  Action 1: df.count()   Read CSV --> filter --> join --> count   (takes 60 seconds)
  Action 2: df.show()    Read CSV --> filter --> join --> show    (takes 60 seconds again!)

WITH caching:
  df.cache()
  Action 1: df.count()   Read CSV --> filter --> join --> [STORE IN MEMORY] --> count   (60 sec)
  Action 2: df.show()    [READ FROM MEMORY] --> show   (1 second!)
```
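To make the recomputation cost concrete without a cluster, here is a toy pure-Python model. It is not how Spark implements caching, and `ToyFrame` and `expensive_pipeline` are hypothetical names. Each "action" reruns the whole pipeline unless `cache()` was called; in that case the first action stores the result and later actions reuse it.

```python
class ToyFrame:
    """Toy stand-in for a lazily evaluated DataFrame (hypothetical, not Spark's API)."""

    def __init__(self, compute):
        self._compute = compute          # rebuilds the data from scratch
        self._cache_requested = False    # set by cache(); nothing is stored yet
        self._cached = None              # filled by the first action after cache()

    def cache(self):
        self._cache_requested = True
        return self

    def _materialize(self):
        if self._cached is not None:     # cache hit: skip the pipeline
            return self._cached
        data = self._compute()           # cache miss: run the whole pipeline
        if self._cache_requested:
            self._cached = data
        return data

    # "Actions" that force evaluation
    def count(self):
        return len(self._materialize())

    def collect(self):
        return list(self._materialize())


runs = {"pipeline": 0}


def expensive_pipeline():
    """Stands in for read CSV --> filter --> join; counts how often it runs."""
    runs["pipeline"] += 1
    return [x * 2 for x in range(1000) if x % 3 == 0]


uncached = ToyFrame(expensive_pipeline)
uncached.count()
uncached.collect()
print("pipeline runs without cache:", runs["pipeline"])  # 2

runs["pipeline"] = 0
cached = ToyFrame(expensive_pipeline).cache()
print("pipeline runs after cache():", runs["pipeline"])  # 0: cache() is lazy
cached.count()
cached.collect()
print("pipeline runs with cache:", runs["pipeline"])     # 1
```

The cached frame runs the pipeline once no matter how many actions follow; the uncached one pays the full cost on every action.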
cache() vs persist()
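For a DataFrame, cache() is shorthand for persist() with the default storage level, which keeps data in memory and spills to disk when memory runs short. Spark's Scala API calls this level MEMORY_AND_DISK; in PySpark 3.x the same level is named MEMORY_AND_DISK_DESER. For an RDD, cache() uses MEMORY_ONLY instead. persist() gives you control over exactly where Spark stores the data.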
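```python
from pyspark import StorageLevel

# The calls below are alternatives: pick ONE level per DataFrame.
# persist() on an already-cached DataFrame does not change its level;
# unpersist() it first.

# Default cache (memory, spills to disk if memory is full)
df.cache()

# Explicit equivalent of df.cache() in PySpark 3.x
df.persist(StorageLevel.MEMORY_AND_DISK_DESER)

# Memory only: partitions that don't fit are recomputed when needed, not spilled
df.persist(StorageLevel.MEMORY_ONLY)

# Disk only (slow, but works even with tiny memory)
df.persist(StorageLevel.DISK_ONLY)

# Serialized in memory (smaller footprint, slower to read): PySpark has no
# MEMORY_ONLY_SER. Its MEMORY_ONLY and MEMORY_AND_DISK levels are already
# serialized, matching Scala's MEMORY_ONLY_SER and MEMORY_AND_DISK_SER.
```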
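Storage Level Comparison
The level names and flags below follow Spark's Scala/Java API. PySpark has no _SER levels: its MEMORY_ONLY and MEMORY_AND_DISK already store data serialized, and MEMORY_AND_DISK_DESER is the deserialized variant that DataFrame.cache() uses.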
| Storage Level | Memory | Disk | Serialized | Best For |
|---|---|---|---|---|
| MEMORY_ONLY | Yes | No | No | Data fits comfortably in RAM |
| MEMORY_AND_DISK | Yes | Spill | No | Most general use (default cache) |
| DISK_ONLY | No | Yes | Yes | Very large datasets, low RAM |
| MEMORY_ONLY_SER | Yes | No | Yes | Reduce memory footprint |
| MEMORY_AND_DISK_SER | Yes | Spill | Yes | Balanced memory and speed |
Important: Caching Is Lazy
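Like transformations, cache() does not store data immediately. Spark caches the data the first time an action computes that DataFrame, and later actions read from the cache. An action such as count() touches every partition, so it fills the whole cache; show() or take() may compute only a few partitions, so only those are cached at first.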
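```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.read.parquet("large_dataset.parquet")
df = df.filter("amount > 1000").cache()
# At this point, nothing is cached yet

df.count()             # <-- FIRST action: computes every partition AND caches it
df.show()              # <-- SECOND action: reads from cache (fast!)
df.describe().show()   # <-- describe() is lazy; show() runs it against the cache
```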
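The partition-level behaviour described above can be modelled with another toy pure-Python sketch. `ToyPartitionedCache` is hypothetical and is not Spark code: its `take(n)` computes partitions only until it has n rows, so only those partitions land in the cache, while `count()` computes and caches the rest. Real Spark may scan more than one partition for `show()`, but the principle is the same.

```python
class ToyPartitionedCache:
    """Toy model of a cached DataFrame whose partitions are cached on first use."""

    def __init__(self, partitions, transform):
        self.partitions = partitions   # list of lists: raw input partitions
        self.transform = transform     # per-row work done before caching
        self.cached = {}               # partition index -> computed rows
        self.computations = 0          # how many partitions were computed

    def _partition(self, i):
        if i not in self.cached:       # compute (and cache) on first access only
            self.computations += 1
            self.cached[i] = [self.transform(row) for row in self.partitions[i]]
        return self.cached[i]

    def take(self, n):
        """Like show()/take(): stop once n rows have been produced."""
        rows = []
        for i in range(len(self.partitions)):
            if len(rows) >= n:
                break
            rows.extend(self._partition(i))
        return rows[:n]

    def count(self):
        """Like count(): needs every partition."""
        return sum(len(self._partition(i)) for i in range(len(self.partitions)))

    def fraction_cached(self):
        return len(self.cached) / len(self.partitions)


toy = ToyPartitionedCache([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], lambda x: x * 10)

print(toy.take(2), toy.fraction_cached())   # [10, 20] 0.25
print(toy.count(), toy.fraction_cached())   # 12 1.0
print(toy.count(), toy.computations)        # 12 4 (second count served from cache)
```

This is why the cache-then-count() pattern is common: count() guarantees every partition is cached before the expensive work starts.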
Verifying What Is Cached
The Spark UI Storage tab shows all cached DataFrames and RDDs, including how much memory each one uses and whether any partitions spilled to disk.
```python
# Check cache status in code
print(df.is_cached)      # True once cache()/persist() was called, even before an action
print(df.storageLevel)   # e.g. "Disk Memory Deserialized 1x Replicated" after df.cache()

# Spark UI: http://localhost:4040 --> Storage tab
# Shows (illustrative):
# +-------------------+---------+-------------------+-----------+
# | RDD Name          | Storage | Partitions Cached | Size      |
# +-------------------+---------+-------------------+-----------+
# | filtered_sales    | Memory  | 20 / 20           | 2 GB      |
# +-------------------+---------+-------------------+-----------+
```
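The same information is available from Spark's monitoring REST API, which is handy in scripts. This sketch assumes the driver UI is at the default `http://localhost:4040` and uses the `/api/v1/applications` and `/api/v1/applications/<app-id>/storage/rdd` endpoints; `fetch_json`, `summarize_storage`, and `print_storage` are hypothetical helper names.

```python
import json
from urllib.error import URLError
from urllib.request import urlopen

UI = "http://localhost:4040"  # assumption: default driver UI address


def fetch_json(url, timeout=5):
    """GET a URL and decode its JSON body."""
    with urlopen(url, timeout=timeout) as resp:
        return json.load(resp)


def summarize_storage(rdds):
    """Turn /storage/rdd entries into compact rows (sizes in MiB)."""
    return [
        {
            "name": rdd["name"],
            "cached": f'{rdd["numCachedPartitions"]}/{rdd["numPartitions"]}',
            "memory_mib": round(rdd["memoryUsed"] / 1024**2, 1),
            "disk_mib": round(rdd["diskUsed"] / 1024**2, 1),
        }
        for rdd in rdds
    ]


def print_storage(ui=UI):
    """Print every cached RDD/DataFrame of the first application served by the UI."""
    try:
        app_id = fetch_json(f"{ui}/api/v1/applications")[0]["id"]
        rdds = fetch_json(f"{ui}/api/v1/applications/{app_id}/storage/rdd")
    except URLError as exc:
        print(f"Spark UI not reachable at {ui}: {exc.reason}")
        return
    for row in summarize_storage(rdds):
        print(row)
```

Call `print_storage()` from a machine that can reach the driver's UI port after the first action has filled the cache. A `cached` value such as `5/20` means only some partitions have been computed so far, and a non-zero `disk_mib` means partitions spilled to disk.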
Releasing Cache — unpersist()
Always release cached data when you no longer need it. Leaving unused data cached wastes memory and can slow other jobs running on the same cluster.
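```python
# Release cache (returns immediately; pass blocking=True to wait until
# the cached blocks are actually removed)
df.unpersist()

# Verify it was released
print(df.is_cached)  # False
```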
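One way to make sure `unpersist()` always runs, even when a step fails, is to wrap the cached section in a context manager. `cached` is a hypothetical helper, not part of PySpark; it only calls `cache()`/`persist()` and `unpersist()`, so it works with any DataFrame.

```python
from contextlib import contextmanager


@contextmanager
def cached(df, storage_level=None):
    """Persist df for the duration of a with-block, then always release it.

    Hypothetical helper: with storage_level=None it calls df.cache(),
    otherwise df.persist(storage_level).
    """
    if storage_level is None:
        df.cache()
    else:
        df.persist(storage_level)
    try:
        yield df
    finally:
        df.unpersist()  # runs even if the with-block raises
```

For example, `with cached(df.filter("amount > 1000")) as big: big.count(); big.show()` runs both actions against one cached copy and frees it when the block exits. The trade-off is that the cache cannot outlive the block, so keep the `with` around all the work that reuses the data.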
When to Cache and When Not To
Cache these:
- DataFrames you use in multiple actions in the same job
- Training datasets in machine learning loops (same data read many times)
- Expensive joins or aggregations used more than once
- Lookup tables used repeatedly across many queries
Do not cache these:
- DataFrames used only once — caching adds overhead with no benefit
- Datasets too large to fit in cluster memory — they spill to disk and can slow things down
- Continuously streaming data — streaming has its own state management
Machine Learning Caching Example
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Load training data (used in every training iteration)
train_data = spark.read.parquet("training_data.parquet")
train_data.cache()
train_data.count()  # trigger the cache load

# Now train for 50 iterations; each one reads from the cache.
# train_model() and model.loss are placeholders for your own training code.
for iteration in range(50):
    model = train_model(train_data, iteration)
    print(f"Iteration {iteration}: loss = {model.loss}")

# Clean up
train_data.unpersist()
```
