Safer Scala with WartRemover Part 1: Seq.apply and head
This is the first of a series of posts about using WartRemover to improve the quality of Scala codebases. WartRemover is a static analysis tool for Scala.
I find it useful because it disables unsafe parts of Scala by default. You get some of the benefits of a stricter language like Rust, while still being able to use the @SuppressWarnings annotation as an escape hatch.
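For example, a single definition can opt out of a specific check via the @SuppressWarnings annotation, which takes the fully qualified wart name (the method below is just an illustration):
// Suppress the SeqApply wart for this one method only
@SuppressWarnings(Array("org.wartremover.warts.SeqApply"))
def firstElement(xs: Seq[Int]): Int = xs(0)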
I also like that the warnings often present as a mini programming puzzle. Will I be clever enough to fix it, perhaps even without AI help?
Here’s a very short reference on setting up WartRemover for your Scala project:
// plugins.sbt
addSbtPlugin("org.wartremover" % "sbt-wartremover" % "3.3.3")
// any other plugins...
// build.sbt
// Set Warts as warnings (compile can still succeed)
wartremoverWarnings += Wart.Any
// Or set Warts as errors (compile fails)
wartremoverErrors += Wart.SeqApply
// Or if you want to include all warts with the exception of a few:
wartremoverWarnings ++= Warts.allBut(
Wart.ImplicitParameter,
...
)
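If you need to exempt whole files, such as generated sources, sbt-wartremover also provides an exclusion setting; a minimal sketch with an illustrative path:
// Exclude specific files or directories from WartRemover entirely
wartremoverExcluded += baseDirectory.value / "src" / "main" / "scala" / "generated"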
Check out the WartRemover Setup Guide for more details.
Two closely related Warts are SeqApply and IterableOps: the first bans Seq.apply (i.e. indexing into a sequence with parentheses) and the second bans partial operations such as Seq.head, since both can throw exceptions at runtime.
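As a quick reminder of why these calls are banned, both fail at runtime when the sequence is too short:
val xs = Seq(1, 2, 3)
xs(10)              // throws IndexOutOfBoundsException
Seq.empty[Int].head // throws NoSuchElementException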
Basic Example
Here’s an example of real code from my Sparklet project that was caught by this Wart:
private[execution] def materialize(ops: Vector[Operation]): Stage[Any, Any] = {
return ops.head match { // WartRemover warning raised: .head is disabled
case MapOp(f) => Stage.map(f.asInstanceOf[Any => Any])
case FilterOp(p) => Stage.filter(p.asInstanceOf[Any => Boolean])
case FlatMapOp(f) => Stage.flatMap(f.asInstanceOf[Any => IterableOnce[Any]])
...
The fix is simple for this case: check that the sequence is non-empty, and then use headOption.get
, which is safe because of the prior check.
private[execution] def materialize(ops: Vector[Operation]): Stage[Any, Any] = {
require(ops.nonEmpty, "Cannot materialize empty operation vector")
if (ops.length == 1) {
return ops.headOption.get match {
case MapOp(f) => Stage.map(f.asInstanceOf[Any => Any])
case FilterOp(p) => Stage.filter(p.asInstanceOf[Any => Boolean])
case FlatMapOp(f) => Stage.flatMap(f.asInstanceOf[Any => IterableOnce[Any]])
...
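An alternative sketch, if you prefer to avoid .get entirely, is to pattern match on the sequence with the +: extractor (the error message below is illustrative):
private[execution] def materialize(ops: Vector[Operation]): Stage[Any, Any] =
  ops match {
    case MapOp(f) +: _    => Stage.map(f.asInstanceOf[Any => Any])
    case FilterOp(p) +: _ => Stage.filter(p.asInstanceOf[Any => Boolean])
    // ... remaining cases as above
    case _ => throw new IllegalArgumentException("Cannot materialize empty operation vector")
  }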
Similarly, below is an example of replacing Seq.apply
with safe alternatives:
val shuffleInputSources = meta.kind match {
case WideOpKind.Join | WideOpKind.CoGroup =>
Seq(
// Indexing with meta.sides(0) and meta.sides(1) uses Seq.apply;
// this is unsafe if sides has fewer than 2 elements
ShuffleInput(upstreamIds(0), Some(meta.sides(0)), meta.numPartitions),
ShuffleInput(upstreamIds(1), Some(meta.sides(1)), meta.numPartitions)
)
case _ =>
// Single-input operations
require(upstreamIds.length == 1, s"${meta.kind} requires exactly 1 upstream stage")
Seq(ShuffleInput(upstreamIds.head, None, meta.numPartitions)) // using head
}
We can avoid the apply by first checking the input sequence with require
, and then using take
:
val shuffleInputSources = meta.kind match {
case WideOpKind.Join | WideOpKind.CoGroup =>
require(upstreamIds.length == 2, s"${meta.kind} requires exactly 2 upstream stages")
require(meta.sides.length == 2, s"${meta.kind} requires exactly 2 sides")
val List(upstream1, upstream2) = upstreamIds.take(2).toList
val List(side1, side2) = meta.sides.take(2).toList
Seq(
ShuffleInput(upstream1, Some(side1), meta.numPartitions),
ShuffleInput(upstream2, Some(side2), meta.numPartitions)
)
case _ =>
// Single-input operations
require(upstreamIds.nonEmpty, s"${meta.kind} requires exactly 1 upstream stage")
Seq(ShuffleInput(upstreamIds.headOption.get, None, meta.numPartitions))
}
Of course, the require can still throw an exception for an empty sequence, but it fails fast with a descriptive message, which is easier to debug than a NoSuchElementException.
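For instance, with an illustrative empty input, the failed require surfaces as an IllegalArgumentException carrying the message we supplied:
val upstreamIds = Vector.empty[Int] // illustrative input
require(upstreamIds.length == 2, "Join requires exactly 2 upstream stages")
// java.lang.IllegalArgumentException: requirement failed: Join requires exactly 2 upstream stages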
More Complex Example
The value is perhaps best demonstrated in a more complex case, where having WartRemover restrict head
and apply
forces a larger refactor, improving safety significantly.
Consider this Spark pipeline:
import org.apache.spark.sql.{DataFrame, Column}
import org.apache.spark.sql.functions._
case class JoinSpec(leftCol: String, rightCol: String, joinType: String)
def buildDataPipeline(tables: Vector[DataFrame],
joinSpecs: Vector[JoinSpec],
groupByCols: Vector[String]): DataFrame = {
// Unsafe: could throw if empty
val baseDf = tables.head
val joinDfs = tables.tail
// Build joins - unsafe indexing
val joinedDf = joinSpecs.zipWithIndex.foldLeft(baseDf) { case (df, (spec, idx)) =>
df.join(joinDfs(idx), col(spec.leftCol) === col(spec.rightCol), spec.joinType) // Unsafe indexing
}
// Group by - unsafe head access
val primaryGroupCol = groupByCols.head // Unsafe
val restGroupCols = groupByCols.tail.map(col)
joinedDf
.groupBy((col(primaryGroupCol) +: restGroupCols): _*)
.agg(count("*").as("record_count"), sum(col(groupByCols(0))).as("total")) // Unsafe indexing
}
We can improve this by removing the unsafe head calls and apply indexing, and by using a functional for-comprehension with Either to handle the failure cases cleanly:
// Safe, concise implementation
def buildDataPipeline(
tables: Vector[DataFrame],
joinSpecs: Vector[JoinSpec],
groupByCols: Vector[String]
): Either[String, DataFrame] = {
for {
baseDf <- tables.headOption.toRight("Pipeline requires at least one DataFrame")
_ <- Either.cond(tables.tail.length == joinSpecs.length, (), "Mismatch: tables and join specs")
// Safe zip-based joining (no indexing needed)
joinedDf = tables.tail.zip(joinSpecs).foldLeft(baseDf) { case (df, (joinDf, spec)) =>
df.join(joinDf, col(spec.leftCol) === col(spec.rightCol), spec.joinType)
}
// Safe pattern matching for grouping
result <- groupByCols match {
case Vector() =>
Right(joinedDf.agg(count("*").as("record_count")))
case single +: rest =>
val cols = (single +: rest).map(col)
Right(joinedDf.groupBy(cols: _*).agg(count("*").as("record_count")))
}
} yield result
}
// Clean usage
buildDataPipeline(
Vector(usersDf, ordersDf),
Vector(JoinSpec("user_id", "user_id", "inner")),
Vector("department", "status")
) match {
case Left(error) => throw new RuntimeException(s"Pipeline failed due to invalid job spec: $error")
case Right(df) => df.show()
}
I plan to write about functional for-comprehensions in a future post, as I find them an interesting topic as well!