Spark RDD Transformations
A transformation takes one RDD and produces a new RDD. Transformations do not run immediately — Spark records them in a plan and executes that plan only when you call an action. This design is called lazy evaluation and is covered fully in the next topic.
The Assembly Line Mental Model
Raw Material (original RDD)
|
v
[ Station 1: filter ] removes bad parts
|
v
[ Station 2: map ] reshapes each part
|
v
[ Station 3: flatMap ] explodes one part into many
|
v
Final Product (new RDD)
Each station is a transformation. No station runs until the factory manager calls "start" — that call is an action.
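The analogy can be made concrete with a plain-Python sketch (not Spark code) in which each station is a generator. Generators, like Spark transformations, do no work until something pulls values through them. Building the pipeline runs nothing, and calling list() plays the role of the action. The station functions and the rule that negative numbers are "bad parts" are hypothetical.

```python
# Plain-Python model of the assembly line: each station is a generator,
# so no station runs until something consumes the last one.
log = []

def station_filter(parts):
    for p in parts:
        log.append(f"filter saw {p}")
        if p >= 0:  # hypothetical rule: negative numbers are "bad parts"
            yield p

def station_map(parts):
    for p in parts:
        log.append(f"map saw {p}")
        yield p * 10  # reshape each part

def station_flat_map(parts):
    for p in parts:
        log.append(f"flatMap saw {p}")
        yield from (p, p + 1)  # explode one part into two

raw = [1, -2, 3]
pipeline = station_flat_map(station_map(station_filter(raw)))  # a plan only
print(log)  # [] : no station has run yet

final = list(pipeline)  # consuming the pipeline acts like an action
print(final)  # [10, 11, 30, 31]
```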
map — Transform Each Element
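map() applies a function to every record and returns a new RDD with exactly one output record per input record. The examples on this page assume a SparkContext named sc, which the PySpark shell creates for you.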
prices = sc.parallelize([10, 20, 30, 40])
# Double every price
doubled = prices.map(lambda x: x * 2)
print(doubled.collect())
# Output: [20, 40, 60, 80]
# Add a label to each price
labeled = prices.map(lambda x: (x, "USD"))
print(labeled.collect())
# Output: [(10, 'USD'), (20, 'USD'), (30, 'USD'), (40, 'USD')]
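Spark splits an RDD into partitions and runs the map() function on each record independently. As a result, map() changes neither the number of records nor the number of partitions. The plain-Python sketch below models an RDD as a list of partitions. The helper toy_map and the two-partition split are hypothetical, because sc.parallelize() chooses its own split.

```python
# Toy RDD: a list of partitions (hypothetical split of the prices).
prices_partitions = [[10, 20], [30, 40]]

def toy_map(partitions, f):
    """Apply f to every record, partition by partition: one output per input."""
    return [[f(x) for x in part] for part in partitions]

doubled = toy_map(prices_partitions, lambda x: x * 2)
labeled = toy_map(prices_partitions, lambda x: (x, "USD"))

print(doubled)  # [[20, 40], [60, 80]]
print(labeled)  # [[(10, 'USD'), (20, 'USD')], [(30, 'USD'), (40, 'USD')]]
```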
filter — Keep Only Matching Records
filter() keeps records where a condition returns True and drops the rest.
scores = sc.parallelize([45, 72, 88, 33, 95, 61])
# Keep only passing scores (above 60)
passing = scores.filter(lambda x: x > 60)
print(passing.collect())
# Output: [72, 88, 95, 61]
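filter() also works one partition at a time and never moves records between partitions. A selective condition can therefore leave some partitions small or even empty. Below is a plain-Python sketch that uses a hypothetical three-partition split of the same scores. The helper toy_filter is illustrative and is not a Spark API.

```python
# Toy RDD: a list of partitions (hypothetical split of the scores).
score_partitions = [[45, 33], [72, 88], [95, 61]]

def toy_filter(partitions, predicate):
    """Keep records where predicate is True; the partitions themselves stay."""
    return [[x for x in part if predicate(x)] for part in partitions]

passing = toy_filter(score_partitions, lambda x: x > 60)
print(passing)  # [[], [72, 88], [95, 61]]

flattened = [x for part in passing for x in part]
print(flattened)  # [72, 88, 95, 61]
```

The flattened result matches the collect() output of the Spark example. Only the partition layout behind it is visible here.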
flatMap — One Input, Many Outputs
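flatMap() applies a function that returns an iterable (such as a list) for each record, then flattens all of those results into a single RDD. One input record can therefore produce zero, one, or many output records. This makes flatMap() useful for tasks such as splitting sentences into words.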
sentences = sc.parallelize(["hello world", "spark is fast"])
# Split each sentence into words
words = sentences.flatMap(lambda line: line.split(" "))
print(words.collect())
# Output: ['hello', 'world', 'spark', 'is', 'fast']
# Compare with map (no flattening):
not_flat = sentences.map(lambda line: line.split(" "))
print(not_flat.collect())
# Output: [['hello', 'world'], ['spark', 'is', 'fast']]
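The difference between map and flatMap can be reproduced in plain Python with itertools.chain.from_iterable, which flattens one level of nesting. This is a sketch of the semantics, not Spark code. The empty string is an added input that shows one record producing zero outputs. The sketch uses str.split() with no argument because it returns [] for an empty string, whereas split(" ") would return [''].

```python
from itertools import chain

sentences = ["hello world", "spark is fast", ""]

# map-style: one output per input, so each output is a whole list
mapped = [line.split() for line in sentences]

# flatMap-style: same function, then flatten the lists into one sequence
flat = list(chain.from_iterable(line.split() for line in sentences))

print(mapped)  # [['hello', 'world'], ['spark', 'is', 'fast'], []]
print(flat)    # ['hello', 'world', 'spark', 'is', 'fast']
```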
distinct — Remove Duplicates
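distinct() returns a new RDD with duplicate elements removed. Copies of the same element can sit in different partitions, so Spark has to shuffle data to find them.
items = sc.parallelize(["apple", "banana", "apple", "cherry", "banana"])
unique = items.distinct()
print(unique.collect())
# Output: ['apple', 'banana', 'cherry'] (order may vary)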
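One way to de-duplicate across partitions is to turn every element into a key, reduce by key, and keep only the keys. PySpark's own distinct() follows this pattern: it maps each element to (x, None), calls reduceByKey, and then takes the key. The plain-Python sketch below walks through those three steps, with a dict standing in for the shuffle. The sketch's output order is fixed only because Python dicts keep insertion order. In Spark, hash partitioning is why the order may vary.

```python
items = ["apple", "banana", "apple", "cherry", "banana"]

# Step 1: turn each element into a (key, placeholder) pair
pairs = [(x, None) for x in items]

# Step 2: reduce by key, keeping one value per key (the dict plays the shuffle)
reduced = {}
for key, value in pairs:
    if key not in reduced:
        reduced[key] = value

# Step 3: drop the placeholder and keep only the keys
unique = list(reduced)
print(unique)  # ['apple', 'banana', 'cherry']
```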
union — Combine Two RDDs
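union() returns an RDD containing every element of both RDDs. It keeps duplicates, so call distinct() on the result if you need each value only once.
rdd1 = sc.parallelize([1, 2, 3])
rdd2 = sc.parallelize([4, 5, 6])
combined = rdd1.union(rdd2)
print(combined.collect())
# Output: [1, 2, 3, 4, 5, 6]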
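Unlike distinct(), union() needs no shuffle. The result typically just places the partitions of both inputs one after the other. Below is a plain-Python sketch with a hypothetical partition split. Here rdd2 includes an extra 3, which is not in the Spark example, to show that duplicates survive.

```python
# Toy RDDs as lists of partitions (hypothetical split; rdd2 adds a duplicate 3).
rdd1_parts = [[1, 2], [3]]
rdd2_parts = [[3, 4], [5, 6]]

def toy_union(a, b):
    """Concatenate the partition lists: no shuffle, no de-duplication."""
    return a + b

combined = toy_union(rdd1_parts, rdd2_parts)
flat = [x for part in combined for x in part]

print(len(combined))  # 4 partitions
print(flat)           # [1, 2, 3, 3, 4, 5, 6]
```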
reduceByKey — Aggregate by Key
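This transformation works on key-value pair RDDs. It groups records by key and combines the values within each group using a two-argument function. Spark first combines values inside each partition and then merges those partial results. The function should therefore be associative and commutative: addition qualifies, but subtraction does not.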
sales = sc.parallelize([
("North", 100), ("South", 200),
("North", 150), ("South", 50), ("East", 300)
])
# Sum sales per region
totals = sales.reduceByKey(lambda a, b: a + b)
print(totals.collect())
# Output: [('North', 250), ('South', 250), ('East', 300)] (order may vary)
Diagram: how reduceByKey works
Input pairs:   North:100  South:200  North:150  South:50  East:300
Group by key:
North --> [100, 150] --> sum --> North: 250
South --> [200, 50]  --> sum --> South: 250
East  --> [300]      --> sum --> East: 300
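The diagram shows the logical result. Physically, Spark combines values in two stages: first inside each partition, then across partitions after the shuffle. The plain-Python sketch below walks through both stages using a hypothetical two-partition split of the sales data. The names combine_locally and merge_partials are illustrative and are not Spark APIs.

```python
from operator import add

# Hypothetical split of the sales pairs across two partitions.
partitions = [
    [("North", 100), ("South", 200), ("North", 150)],
    [("South", 50), ("East", 300)],
]

def combine_locally(part, f):
    """Stage 1 (before the shuffle): reduce values per key inside one partition."""
    acc = {}
    for key, value in part:
        acc[key] = f(acc[key], value) if key in acc else value
    return acc

def merge_partials(partials, f):
    """Stage 2 (after the shuffle): merge the per-partition results per key."""
    out = {}
    for partial in partials:
        for key, value in partial.items():
            out[key] = f(out[key], value) if key in out else value
    return out

partials = [combine_locally(p, add) for p in partitions]
print(partials)  # [{'North': 250, 'South': 200}, {'South': 50, 'East': 300}]

totals = merge_partials(partials, add)
print(totals)    # {'North': 250, 'South': 250, 'East': 300}
```

The same function runs in both stages, and records can arrive in any order. For the totals to be stable, the function must therefore be associative and commutative.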
groupByKey — Group Values by Key
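groupByKey() collects all values for each key into a single iterable. PySpark returns a ResultIterable, which is why the example below converts it with mapValues(list). Prefer reduceByKey() when you only need an aggregate such as a sum. It combines values within each partition before the shuffle, so less data crosses the network.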
data = sc.parallelize([("A", 1), ("B", 2), ("A", 3), ("B", 4)])
grouped = data.groupByKey().mapValues(list)
print(grouped.collect())
# Output: [('A', [1, 3]), ('B', [2, 4])] (order may vary)
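You can check the performance advice by counting how many pairs each approach would send through the shuffle. The plain-Python sketch below uses a hypothetical two-partition layout. It also adds ("A", 5) and ("A", 6) to the example data so the difference is easier to see.

```python
# Hypothetical layout: the example pairs plus ("A", 5) and ("A", 6).
partitions = [
    [("A", 1), ("B", 2), ("A", 3)],
    [("B", 4), ("A", 5), ("A", 6)],
]

# groupByKey: every single pair is sent across the network.
group_shuffled = sum(len(part) for part in partitions)

# reduceByKey: each partition first collapses to one pair per key.
def local_sums(part):
    acc = {}
    for key, value in part:
        acc[key] = acc.get(key, 0) + value
    return acc

reduce_shuffled = sum(len(local_sums(part)) for part in partitions)

print(group_shuffled)   # 6 pairs shuffled
print(reduce_shuffled)  # 4 pairs shuffled
```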
sortBy and sortByKey
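sortBy(f) sorts records by the value that f returns, and sortByKey() sorts a pair RDD by its keys. Both sort in ascending order unless you pass ascending=False.
numbers = sc.parallelize([5, 1, 4, 2, 3])
sorted_rdd = numbers.sortBy(lambda x: x)
print(sorted_rdd.collect())
# Output: [1, 2, 3, 4, 5]
# Descending order
desc = numbers.sortBy(lambda x: x, ascending=False)
print(desc.collect())
# Output: [5, 4, 3, 2, 1]
# Sort a pair RDD by its keys
pairs = sc.parallelize([("b", 2), ("c", 3), ("a", 1)])
print(pairs.sortByKey().collect())
# Output: [('a', 1), ('b', 2), ('c', 3)]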
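Spark sorts distributed data with range partitioning. It samples the data to choose boundaries, shuffles each record to the partition that owns its range, and then sorts each partition separately. In PySpark, sortBy() is built on top of sortByKey(), which follows this approach. The plain-Python sketch below hard-codes the boundaries instead of sampling them, so treat the boundary values and the extra input numbers as hypothetical.

```python
import bisect

numbers = [5, 1, 4, 2, 3, 9, 7, 8, 6]

# Spark would pick boundaries by sampling; these are hard-coded (hypothetical).
boundaries = [3, 6]  # ranges: <= 3, 4..6, > 6

# Step 1: route each record to the partition that owns its range.
partitions = [[] for _ in range(len(boundaries) + 1)]
for x in numbers:
    partitions[bisect.bisect_left(boundaries, x)].append(x)

# Step 2: sort each partition independently (in Spark, in parallel).
partitions = [sorted(part) for part in partitions]

# Step 3: reading the partitions in order gives a globally sorted result.
result = [x for part in partitions for x in part]
print(partitions)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(result)      # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```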
Common Transformations at a Glance
| Transformation | Input | Output | Use Case |
|---|---|---|---|
| map(f) | 1 record | 1 record | Transform each element |
| flatMap(f) | 1 record | 0 or many records | Explode lists, split text |
| filter(f) | 1 record | 0 or 1 record | Remove unwanted rows |
| distinct() | RDD | RDD | Remove duplicates |
| union(rdd) | 2 RDDs | 1 RDD | Merge datasets |
| reduceByKey(f) | Pair RDD | Pair RDD | Aggregate per key |
| groupByKey() | Pair RDD | Pair RDD | Group values per key |
| sortBy(f) | RDD | RDD | Sort records |
| sortByKey() | Pair RDD | Pair RDD | Sort records by key |
