Intro
When developing, we often need to iterate over a collection and pick out only the elements of specific types. This seems like a simple task, but what's the optimal way to write the code? In this post, I'll share my journey, starting from inefficient code and progressing to a final solution that is type-safe, efficient, and reusable, using Scala 3 macros.
The Start of the Problem: Inefficient Type Filtering
First, let's define a simple shape hierarchy that we'll use for our examples.
sealed trait Shape
object Shape {
  case class Circle(radius: Double) extends Shape
  case class Square(side: Double) extends Shape
  sealed trait Polygon extends Shape
  object Polygon {
    case class Triangle(a: Double, b: Double, c: Double) extends Polygon
  }
}
// A type unrelated to the Shape hierarchy
sealed trait NotShape
Imagine we have an Iterable[Shape] and we want to extract all the elements of type Shape.Circle, Shape.Square, and Shape.Polygon.Triangle. The most intuitive code I first came up with was this:
val shapes: Iterable[Shape] = getShapes() // A function that returns a list of shapes
val circles = shapes.collect { case c: Shape.Circle => c }
val squares = shapes.collect { case s: Shape.Square => s }
val triangles = shapes.collect { case t: Shape.Polygon.Triangle => t }
The biggest problem with this code is clear. Because the shape types are mutually exclusive, a single pass would be enough, yet this code traverses the shapes collection three separate times. As the amount of data grows, this redundant work turns directly into slower code.
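To make the cost concrete, here is a small self-contained sketch that counts traversals. `CountingIterable` is a hypothetical helper written only for this illustration: it increments a counter each time an iterator is requested, which for a plain `Iterable` happens once per `collect` call.

```scala
sealed trait Shape
object Shape {
  case class Circle(radius: Double) extends Shape
  case class Square(side: Double) extends Shape
  sealed trait Polygon extends Shape
  object Polygon {
    case class Triangle(a: Double, b: Double, c: Double) extends Polygon
  }
}

// Hypothetical helper (illustration only): counts how many iterators are requested,
// i.e. how many traversals of the underlying data are started
final class CountingIterable[A](underlying: Seq[A]) extends Iterable[A] {
  var traversals: Int = 0
  def iterator: Iterator[A] = {
    traversals += 1
    underlying.iterator
  }
  // Constant toString so that printing the collection (e.g. in a REPL) is not counted
  override def toString: String = "CountingIterable"
}

val shapes = new CountingIterable[Shape](
  Seq(Shape.Circle(1.0), Shape.Square(2.0), Shape.Polygon.Triangle(3.0, 4.0, 5.0), Shape.Circle(0.5))
)

// The three-collect approach from above
val circles = shapes.collect { case c: Shape.Circle => c }
val squares = shapes.collect { case s: Shape.Square => s }
val triangles = shapes.collect { case t: Shape.Polygon.Triangle => t }

// shapes.traversals is now 3: one full pass per collect call
```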
First Improvement: One-Pass Iteration with foldLeft
To improve performance, I wrote code that uses foldLeft to perform the collection in a single pass.
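val (circlesRev, squaresRev, trianglesRev) =
  shapes.foldLeft(
    (List.empty[Shape.Circle], List.empty[Shape.Square], List.empty[Shape.Polygon.Triangle])
  ) { case ((cs, ss, ts), shape) =>
    shape match {
      case c: Shape.Circle           => (c :: cs, ss, ts)
      case s: Shape.Square           => (cs, s :: ss, ts)
      case t: Shape.Polygon.Triangle => (cs, ss, t :: ts)
      case _                         => (cs, ss, ts) // other shapes are ignored
    }
  }
// Prepending with :: builds each list in reverse order, so reverse once at the end
val (circles, squares, triangles) = (circlesRev.reverse, squaresRev.reverse, trianglesRev.reverse)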
This code certainly has the advantage of getting the desired result in a single pass. However, I wasn't very happy with it for two clear reasons:
- Verbose Code: As more types are added, the initial tuple for foldLeft and the case clauses get longer and longer.
- Not Intuitive: foldLeft is typically used for aggregating values, but here it's used simply to 'collect' by type, which obscures its intent.
I wanted to create a method that was both intuitive and efficient, something like .collectTypes[Circle, Square, Triangle].
Second Improvement: Extension Methods with implicit class
To add new functionality to Iterable[A], I decided to use an implicit class.
What is an implicit class?
An implicit class is a powerful Scala feature that lets you add new methods to existing classes without modifying their source code. When the compiler sees a call to a method that the object's type does not define, it looks for an implicit class in scope that can "wrap" the object and provide that method. This is also known as the 'extension method' pattern.
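As a minimal illustration of the pattern, the sketch below adds a method to Int with an implicit class and, for comparison, a method to String with Scala 3's dedicated `extension` syntax, which is the idiomatic way to write extension methods in Scala 3. The names `IntSyntax`, `isEven`, and `shout` are made up for this example.

```scala
// Scala 2 style (still accepted by Scala 3): the compiler wraps an Int in IntSyntax
// whenever a method that Int itself lacks, such as isEven, is called on it
implicit class IntSyntax(n: Int) {
  def isEven: Boolean = n % 2 == 0
}

// Scala 3 style: the same idea as an extension method, with no wrapper class to declare
extension (s: String)
  def shout: String = s.toUpperCase + "!"

val fourIsEven = 4.isEven  // true
val loud = "hello".shout   // "HELLO!"
```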
Using an implicit class, I created a collectTypes method that could handle two types.
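import scala.reflect.ClassTag

implicit class IterableOps[A](iterable: Iterable[A]) extends AnyVal {
  // The ClassTag context bounds let the pattern match check T1 and T2 at runtime despite type erasure
  def collectTypes[T1 <: A: ClassTag, T2 <: A: ClassTag]: (List[T1], List[T2]) = {
    val b1 = List.newBuilder[T1]
    val b2 = List.newBuilder[T2]
    // A single pass: each element goes to the first builder whose type it matches
    iterable.foreach {
      case t1: T1 => b1 += t1
      case t2: T2 => b2 += t2
      case _      => // ignore elements of other types
    }
    (b1.result(), b2.result())
  }
}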
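A self-contained sketch of how the two-type method is called. It repeats the shape hierarchy so it can run on its own, and it omits the `extends AnyVal` optimization to keep the sketch minimal; the sample shapes are made up for the example.

```scala
import scala.reflect.ClassTag

sealed trait Shape
object Shape {
  case class Circle(radius: Double) extends Shape
  case class Square(side: Double) extends Shape
  sealed trait Polygon extends Shape
  object Polygon {
    case class Triangle(a: Double, b: Double, c: Double) extends Polygon
  }
}

// Same logic as the implicit class above, minus extends AnyVal
implicit class IterableOps[A](iterable: Iterable[A]) {
  def collectTypes[T1 <: A: ClassTag, T2 <: A: ClassTag]: (List[T1], List[T2]) = {
    val b1 = List.newBuilder[T1]
    val b2 = List.newBuilder[T2]
    iterable.foreach {
      case t1: T1 => b1 += t1
      case t2: T2 => b2 += t2
      case _      => ()
    }
    (b1.result(), b2.result())
  }
}

val shapes: Iterable[Shape] =
  List(Shape.Circle(1.0), Shape.Square(2.0), Shape.Polygon.Triangle(3.0, 4.0, 5.0), Shape.Circle(0.5))

// One pass fills both lists; the triangle matches neither type and is skipped
val (circles, squares) = shapes.collectTypes[Shape.Circle, Shape.Square]
```

Because the match stops at the first case that fits, each element lands in at most one list: with overlapping types such as `Shape.Polygon` and `Shape.Polygon.Triangle`, every triangle would go to whichever type is listed first.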
This code worked wonderfully. However, it had a clear limitation: the number of types it could filter for was fixed at two. To handle three types, I'd have to write another overload with three type parameters, and what if I needed 4, 5, or even 20 types? At this point I decided I needed a technique that generates code at compile time: macros.
Third Improvement: First Steps with Scala 3 Macros
What is a Scala Macro?
A macro is "code that writes code." It runs at compile time to manipulate a program's Abstract Syntax Tree (AST) and generate new code. This allows you to eliminate repetitive boilerplate, achieve high levels of abstraction with no runtime penalty, and implement powerful, type-based logic at compile time.
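Scala 3 offers two layers of compile-time metaprogramming: `inline` methods combined with the `scala.compiletime` utilities, and full quoted macros (`scala.quoted`), which is what this post builds. As a small taste of the first layer only (this is not the post's macro), the sketch below walks a tuple type element by element at compile time, the same kind of traversal a `collectTypes[Ts]` implementation needs. `runtimeClassesOf` is a hypothetical name chosen for the illustration.

```scala
import scala.compiletime.{erasedValue, summonInline}
import scala.reflect.ClassTag

// Hypothetical helper (not from the post): walks the tuple type Ts at compile time.
// Each recursive call is inlined, so runtimeClassesOf[(A, B, C)] expands into
// three ClassTag lookups joined with :: rather than a runtime loop over types.
inline def runtimeClassesOf[Ts <: Tuple]: List[Class[?]] =
  inline erasedValue[Ts] match {
    case _: EmptyTuple => Nil
    case _: (h *: t)   => summonInline[ClassTag[h]].runtimeClass :: runtimeClassesOf[t]
  }

val classes = runtimeClassesOf[(String, Int, Boolean)]
// classes == List(classOf[String], classOf[Int], classOf[Boolean])
```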
My goal was to create the following interface:
implicit class IterableOps[A](iterable: Iterable[A]) extends AnyVal {
  inline def collectTypes[Ts <: Tuple]: Tuple.Map[Ts, List] =
    macros.CollectTypesMacro.collectTypes[Ts](iterable)
}
// Example usage
val (circles, squares, triangles) = shapes.collectTypes[(Shape.Circle, Shape.Square, Shape.Polygon.Triangle)]
Passing the type parameters as a Tuple (e.g., collectTypes[(...)]) is a minor inconvenience, but in exchange a single method handles any number N of types, at the cost of just one extra pair of parentheses.
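The return type `Tuple.Map[Ts, List]` is a standard-library type in Scala 3 that applies a type constructor, here `List`, to every element type of a tuple type. A minimal sketch of what it reduces to, using `Int` and `String` instead of shapes for brevity:

```scala
// Tuple.Map[(Int, String), List] reduces to List[Int] *: List[String] *: EmptyTuple,
// the same shape as (List[Int], List[String]), so an ordinary pair of lists conforms
val lists: Tuple.Map[(Int, String), List] = (List(1, 2), List("a"))

// Element types that do not line up are rejected at compile time, e.g.
// val wrong: Tuple.Map[(Int, String), List] = (List("a"), List(1, 2)) // does not compile
```

By the same reduction, `collectTypes[(Shape.Circle, Shape.Square, Shape.Polygon.Triangle)]` returns a value whose type matches `(List[Shape.Circle], List[Shape.Square], List[Shape.Polygon.Triangle])`.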
The Code the Macro Will Write
The core purpose of this macro is to have the compiler replace the collectTypes call with code that is even more efficient than our manual foldLeft version. When the compiler sees shapes.collectTypes[(...)], it generates the following code right in its place.
The process that I want is:
- Create a Builder for each type.
- Iterate through the collection only once, using a match statement to add elements to the appropriate Builder.
- Return a tuple containing the results from all Builders.
// At compile time, the call is expanded into the following code:
{
  // 1. Create a Builder for each type
  val b1: scala.collection.mutable.Builder[Shape.Circle, List[Shape.Circle]] =
    scala.List.newBuilder[Shape.Circle]
  val b2: scala.collection.mutable.Builder[Shape.Square, List[Shape.Square]] =
    scala.List.newBuilder[Shape.Square]
  val b3: scala.collection.mutable.Builder[Shape.Polygon.Triangle, List[Shape.Polygon.Triangle]] =
    scala.List.newBuilder[Shape.Polygon.Triangle]
  // 2. Iterate once and use a match statement to add elements to the correct Builder
  shapes.foreach { item =>
    item match {
      case x: Shape.Circle           => b1.addOne(x)
      case x: Shape.Square           => b2.addOne(x)
      case x: Shape.Polygon.Triangle => b3.addOne(x)
      case _                         => ()
    }
  }
  // 3. Return a tuple of the results from all Builders
  (b1.result(), b2.result(), b3.result())
}
This generated code is exactly what I want: clean, efficient, and hidden behind a simple, reusable method call.
Final Solution: A Macro with Compile-Time Safety
After getting this far, I noticed a critical flaw. What would happen if I tried to filter an Iterable[Shape] for a type that has nothing to do with Shape, like NotShape?
val (circles, notShapes) = shapes.collectTypes[(Shape.Circle, NotShape)]
The current macro doesn't check whether the types passed to collectTypes are actually subtypes of the collection's element type (Shape). In this case, the code compiles successfully, but the NotShape case in the generated match can never be reached, so notShapes is always an empty list: a silent bug that only shows up at runtime. Catching issues like this at compile time is one of the main motivations for using macros, after all.
So, I added subtype validation logic to the macro. First, when IterableOps calls the macro, it now also passes along the collection's element type A.
// IterableOps.scala
implicit class IterableOps[A](iterable: Iterable[A]) extends AnyVal {
  inline def collectTypes[Ts <: Tuple]: Tuple.Map[Ts, List] =
    // Pass type A to the macro for validation
    macros.CollectTypesMacro.collectTypes[A, Ts](iterable)
}
Then, the macro implementation (collectTypesImpl) receives this type information and performs the check.
// CollectTypesMacro.scala
def collectTypesImpl[A: Type, Ts <: Tuple: Type](...) = {
  // ...
  // --- Added validation logic ---
  val iterableElemTpe: TypeRepr = TypeRepr.of[A]
  types.foreach { t =>
    // Check the subtype relationship using the '<:<' operator
    if (!(t <:< iterableElemTpe)) {
      report.errorAndAbort(
        s"Type ${t.show} is not a subtype of the Iterable element type ${iterableElemTpe.show}"
      )
    }
  }
  // ------------------------------
  // ... rest of the macro implementation ...
}
Now the macro knows about type A and confirms at compile time that every type passed in the Tuple is a subtype of A. If the relationship doesn't hold, we get a helpful compile error.
// When compiling code like shapes.collectTypes[(Shape.Circle, NotShape)]
[error] -- Error: .../Main.scala:39:35
[error] 39 | shapes.collectTypes[(Shape.Circle, NotShape)]
...
[error] |Type com.example.NotShape is not a subtype of the Iterable element type com.example.Shape.
This prevents runtime bugs at the source, creating a truly robust and safe solution.
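For comparison, a similar compile-time guarantee can be sketched without quoted reflection, using `inline` recursion and `scala.compiletime.summonInline` to demand a `<:<` subtype witness for every element of `Ts`. This is an alternative technique, not the post's macro: `requireSubtypes` is a hypothetical name, the shape hierarchy is trimmed for brevity, and a failing check produces the compiler's generic "no given instance" error rather than the custom message above.

```scala
import scala.compiletime.{erasedValue, summonInline}

// Trimmed version of the post's hierarchy, enough for the check
sealed trait Shape
object Shape {
  case class Circle(radius: Double) extends Shape
  case class Square(side: Double) extends Shape
}
sealed trait NotShape

// Hypothetical helper (not from the post): expands into one summonInline[h <:< A]
// per element type h of Ts and returns how many element types were checked.
// If some h is not a subtype of A, no given instance exists and compilation fails.
inline def requireSubtypes[A, Ts <: Tuple]: Int =
  inline erasedValue[Ts] match {
    case _: EmptyTuple => 0
    case _: (h *: t) =>
      val _ = summonInline[h <:< A]
      1 + requireSubtypes[A, t]
  }

// Compiles: both element types are subtypes of Shape
val checkedCount = requireSubtypes[Shape, (Shape.Circle, Shape.Square)]

// Would not compile, because no given instance of NotShape <:< Shape exists:
// requireSubtypes[Shape, (Shape.Circle, NotShape)]
```

The trade-off is mostly in diagnostics: the quoted macro can phrase the error in terms of the collection's element type, while this version relies on the compiler's standard message for a missing given.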
Conclusion and Summary
- Starting from a simple collection filtering problem, my journey went through the following stages:
  - Problem Identification: Multiple collect calls are inefficient.
  - First Improvement: foldLeft is efficient but verbose and unintuitive.
  - Second Improvement: An implicit class is user-friendly but lacks scalability for N types.
  - Third Improvement: A Scala macro can generate efficient code for N types, but it initially lacked compile-time safety checks.
  - Final Solution: By adding subtype validation, the Scala 3 macro became the ultimate solution, achieving efficiency, type safety, and reusability.
- Macros are more than just a tool for reducing code; they are a powerful way to build the exact abstractions you want without any runtime cost. By combining Scala's strong type system with macros, I could solve complex requirements with code that is concise, safe, and efficient.