seahrh / spark-util

Utility for common use cases and bug workarounds in Apache Spark 2
Getting started

Add the following to your build.sbt:

libraryDependencies += "com.sgcharts" %% "spark-util" % "0.4.1"

Handling the imbalanced class problem with SMOTE

There are a number of ways to deal with the imbalanced class problem, but they have drawbacks:

  • Under-sample the majority class - loses data
  • Over-sample the minority class - risks overfitting

Alternatively, SMOTE over-samples the minority class by creating “synthetic” examples rather than by over-sampling with replacement.

Synthetic examples

Synthetic examples are generated in the following way:

  • Take the difference between the feature vector (sample) under consideration and its nearest neighbour
  • Multiply this difference by a random number between 0 and 1, and add it to the sample

For discrete attributes, the synthetic example randomly picks either the sample or the neighbour, and copies that value.
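The steps above can be sketched in plain Scala. This is a minimal illustration under stated assumptions, not the library's Smote implementation: the nearest neighbour is assumed to be already found, and `Example`, `synthesize` and `SmoteSketchDemo` are hypothetical names. One random gap is shared by all continuous attributes, so the synthetic point lies on the line segment between the sample and its neighbour; drawing a separate gap per attribute is a common variant.

```scala
import scala.util.Random

object SmoteSketch {
  // Hypothetical record: continuous attributes as doubles, discrete attributes as strings
  final case class Example(continuous: Vector[Double], discrete: Vector[String])

  // Builds one synthetic example from a sample and its (assumed already found) nearest neighbour
  def synthesize(sample: Example, neighbour: Example, rng: Random): Example = {
    // One gap in [0, 1), shared by all continuous attributes
    val gap = rng.nextDouble()
    val continuous = sample.continuous.zip(neighbour.continuous).map {
      case (s, n) => s + gap * (n - s) // sample + gap * (neighbour - sample)
    }
    // Each discrete attribute copies the sample's or the neighbour's value, chosen at random
    val discrete = sample.discrete.zip(neighbour.discrete).map {
      case (s, n) => if (rng.nextBoolean()) s else n
    }
    Example(continuous, discrete)
  }
}

object SmoteSketchDemo {
  def main(args: Array[String]): Unit = {
    // Continuous attributes: age, rent_amount; discrete attribute: name
    val sample = SmoteSketch.Example(Vector(30.0, 1200.0), Vector("Alice"))
    val neighbour = SmoteSketch.Example(Vector(34.0, 1500.0), Vector("Bob"))
    // age falls between 30 and 34, rent_amount between 1200 and 1500, name is Alice or Bob
    println(SmoteSketch.synthesize(sample, neighbour, new Random(7)))
  }
}
```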

By forcing the decision region of the minority class to become more general, SMOTE reduces overfitting.

Based on Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P. (2002). SMOTE: synthetic minority over-sampling technique. Journal of Artificial Intelligence Research, 16, 321-357.

SMOTE in Spark

Nearest neighbours are approximated with the Locality Sensitive Hashing (LSH) model introduced in Spark ML 2.1.
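For intuition on how LSH approximates nearest neighbours, here is a plain-Scala sketch of bucketed random projection hashing for Euclidean distance, the idea behind Spark ML's BucketedRandomProjectionLSH. Each hash function projects a vector onto a random unit vector and buckets the projection by a bucket length; two points become candidate neighbours if they share a bucket in at least one hash table. The names `HashFn`, `randomHashFn` and `isCandidate` are hypothetical, and whether this library uses this particular LSH family is an assumption.

```scala
import scala.util.Random

object RandomProjectionLshSketch {
  // One hash function: h(v) = floor((direction . v) / bucketLength)
  final case class HashFn(direction: Vector[Double], bucketLength: Double) {
    def apply(v: Vector[Double]): Long = {
      val projection = direction.zip(v).map { case (a, b) => a * b }.sum
      math.floor(projection / bucketLength).toLong
    }
  }

  // Random unit vector: a Gaussian sample normalised to length 1
  def randomHashFn(dim: Int, bucketLength: Double, rng: Random): HashFn = {
    val g = Vector.fill(dim)(rng.nextGaussian())
    val norm = math.sqrt(g.map(x => x * x).sum)
    HashFn(g.map(_ / norm), bucketLength)
  }

  // Candidate neighbours share a bucket in at least one hash table (OR-amplification)
  def isCandidate(tables: Seq[HashFn], a: Vector[Double], b: Vector[Double]): Boolean =
    tables.exists(h => h(a) == h(b))

  def main(args: Array[String]): Unit = {
    val rng = new Random(1)
    val tables = Seq.fill(3)(randomHashFn(dim = 2, bucketLength = 2.0, rng = rng))
    println(isCandidate(tables, Vector(1.0, 1.0), Vector(1.1, 0.9)))   // a close pair
    println(isCandidate(tables, Vector(1.0, 1.0), Vector(50.0, -40.0))) // a distant pair
  }
}
```

Close points usually share a bucket in some table while distant points rarely do, so only candidate pairs need exact distance checks.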

Data is deliberately never collected to the driver, so the driver needs no additional memory.


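import org.apache.spark.sql.DataFrame
import com.sgcharts.sparkutil.Smote

val minorityClassDf: DataFrame = ??? // rows of the minority class only
val df: DataFrame = Smote(
    sample = minorityClassDf,
    discreteStringAttributes = Seq("name"),
    discreteLongAttributes = Seq("house_id", "house_zip"),
    continuousAttributes = Seq("age", "rent_amount")
    // ... (rest of the call not shown)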

Dataset union bug

Union of Datasets is buggy (SPARK-21109): internally, union resolves columns by position, not by name.

Solution: convert each Dataset to a DataFrame, then reorder its columns to match the column order of the first operand.

The union function also accepts more than two DataFrames in one call (varargs), unlike the Spark API, which takes only two at a time.

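import org.apache.spark.sql.{Dataset, SparkSession}
import com.sgcharts.sparkutil.union

val spark: SparkSession = ???
import spark.implicits._ // Encoder for MyCaseClass, needed by .as[MyCaseClass]

val ds1: Dataset[MyCaseClass] = ???
val ds2: Dataset[MyCaseClass] = ???
val ds3: Dataset[MyCaseClass] = ???

// Columns of ds2 and ds3 are reordered to match ds1 before the union
val res: Dataset[MyCaseClass] = union(ds1.toDF, ds2.toDF, ds3.toDF).as[MyCaseClass]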
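To see why positional resolution goes wrong and how reordering fixes it, here is a plain-Scala model in which a table is just a list of column names plus rows. It is a hypothetical illustration of the workaround, including the varargs form; `Table`, `reorder`, `unionByPosition` and `unionByName` are not part of spark-util.

```scala
object UnionByNameSketch {
  // Minimal model of a table: column names plus rows of values
  final case class Table(columns: Vector[String], rows: Vector[Vector[Any]])

  // Reorders t's columns to follow `order`; both must contain the same column names
  def reorder(t: Table, order: Vector[String]): Table = {
    val idx = order.map(t.columns.indexOf(_))
    require(t.columns.size == order.size && idx.forall(_ >= 0),
      s"column mismatch: ${t.columns} vs $order")
    Table(order, t.rows.map(r => idx.map(r)))
  }

  // The buggy behaviour: rows are concatenated by position, column names are ignored
  def unionByPosition(first: Table, rest: Table*): Table =
    Table(first.columns, first.rows ++ rest.flatMap(_.rows))

  // The workaround: align every operand to the first operand's column order, then concatenate
  def unionByName(first: Table, rest: Table*): Table =
    unionByPosition(first, rest.map(reorder(_, first.columns)): _*)

  def main(args: Array[String]): Unit = {
    val a = Table(Vector("id", "name"), Vector(Vector(1, "x")))
    val b = Table(Vector("name", "id"), Vector(Vector("y", 2)))
    println(unionByPosition(a, b).rows) // Vector(Vector(1, x), Vector(y, 2)): "y" lands under id
    println(unionByName(a, b).rows)     // Vector(Vector(1, x), Vector(2, y))
  }
}
```

With Spark DataFrames, the same reordering can be written as `df.select(first.columns.map(col): _*)` using `org.apache.spark.sql.functions.col`. Spark 2.3 also added `Dataset#unionByName`, which resolves columns by name.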

Saving partitioned Hive table bug

Hive partitions written by the DataFrameWriter#saveAsTable API are not registered in the Hive metastore (SPARK-14927), so the partitions are not accessible in Hive.

As of Spark 2.3.1, the DataFrameWriter API is still buggy: insertInto and saveAsTable each cause different problems on Hive. For example, saveAsTable always overwrites existing partitions, so only one partition can exist at any one time.

Solution: instead of using the saveAsTable API, register the partition explicitly.

TablePartition contains a Dataset to be written to a single partition. Optionally, set the number of files to write per partition (default=1).
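Registering a partition explicitly comes down to a metastore DDL statement issued after the files have been written to the partition's directory. The sketch below only builds the HiveQL `ALTER TABLE ... ADD IF NOT EXISTS PARTITION ... LOCATION ...` string. The helper name `addPartitionSql` is hypothetical, values are not escaped, and this is one possible approach rather than a description of how ParquetTablePartition works internally.

```scala
object AddPartitionSketch {
  // Hypothetical helper: builds the statement that registers one partition in the metastore.
  // Partition values are quoted but NOT escaped; this is a sketch, not production code.
  def addPartitionSql(table: String, partition: Seq[(String, String)], location: String): String = {
    val spec = partition.map { case (k, v) => s"$k='$v'" }.mkString(", ")
    s"ALTER TABLE $table ADD IF NOT EXISTS PARTITION ($spec) LOCATION '$location'"
  }

  def main(args: Array[String]): Unit =
    println(addPartitionSql(
      "my_db.my_table",
      Seq("dt" -> "2018-07-01"),
      "/warehouse/my_db.db/my_table/dt=2018-07-01"))
    // ALTER TABLE my_db.my_table ADD IF NOT EXISTS PARTITION (dt='2018-07-01') LOCATION '/warehouse/my_db.db/my_table/dt=2018-07-01'
}
```

In a Spark job, such a string could be passed to `spark.sql(...)` after writing the files, for example with `ds.coalesce(numFiles).write.parquet(location)`, where `numFiles` plays the role of the files-per-partition setting.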

Usage: saving a single partition

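import org.apache.spark.sql.Dataset
import com.sgcharts.sparkutil.ParquetTablePartition

val ds: Dataset[MyTableSchema] = ???

val part = ParquetTablePartition[MyTableSchema](
    // ... (constructor arguments not shown)

// Overwrite partition
// ... (call not shown)

// Or append data to an existing partition
// ... (call not shown)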

Count by key

CountAccumulator extends org.apache.spark.util.AccumulatorV2. It can count keys of any type that has an Ordering, and returns a SortedMap of the keys and their counts.
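The counting contract can be modelled without Spark: each partition counts its own keys, and the partial counts are merged into one SortedMap. The class below is a hypothetical plain-Scala model that mirrors the AccumulatorV2 method names (`isZero`, `copy`, `reset`, `add`, `merge`, `value`); it is not the library's CountAccumulator and is not registered with a SparkContext.

```scala
import scala.collection.immutable.SortedMap

// Hypothetical model of a counting accumulator; mirrors AccumulatorV2 method names only
final class SortedCountModel[K](implicit ord: Ordering[K]) {
  private var counts: SortedMap[K, Long] = SortedMap.empty[K, Long]

  def isZero: Boolean = counts.isEmpty
  def copy(): SortedCountModel[K] = { val c = new SortedCountModel[K]; c.counts = counts; c }
  def reset(): Unit = counts = SortedMap.empty[K, Long]

  // Increment the count of one key
  def add(k: K): Unit = counts = counts.updated(k, counts.getOrElse(k, 0L) + 1L)

  // Fold another (e.g. per-partition) counter into this one
  def merge(other: SortedCountModel[K]): Unit =
    other.counts.foreach { case (k, n) => counts = counts.updated(k, counts.getOrElse(k, 0L) + n) }

  def value: SortedMap[K, Long] = counts
}

object SortedCountDemo {
  def main(args: Array[String]): Unit = {
    // Two "partitions" count their own keys independently...
    val p1 = new SortedCountModel[String]
    Seq("b", "a", "b").foreach(p1.add)
    val p2 = new SortedCountModel[String]
    Seq("c", "a").foreach(p2.add)
    // ...and the partial results are merged, keys in sorted order
    val total = new SortedCountModel[String]
    total.merge(p1)
    total.merge(p2)
    println(total.value.toList) // List((a,2), (b,2), (c,1))
  }
}
```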

Unlike HashMap, SortedMap uses compareTo instead of equals to determine whether two keys are the same. For example, the BigDecimal class has a compareTo method that is inconsistent with equals. If the only two keys are BigDecimal("1.0") and BigDecimal("1.00"), the resulting SortedMap contains a single entry, because the two keys are equal according to compareTo.
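The BigDecimal case can be checked in plain Scala with java.math.BigDecimal, whose equals also compares scale. Counting the same two keys in a HashMap (equals and hashCode) and in a TreeMap, a SortedMap ordered here by compareTo, gives different numbers of entries. The object name is hypothetical.

```scala
import java.math.{BigDecimal => JBigDecimal}
import scala.collection.immutable.{HashMap, TreeMap}

object CompareToVsEquals {
  // Ordering backed by compareTo, which ignores scale: 1.0 and 1.00 compare as equal
  val byCompareTo: Ordering[JBigDecimal] = new Ordering[JBigDecimal] {
    def compare(x: JBigDecimal, y: JBigDecimal): Int = x.compareTo(y)
  }

  val keys: Seq[JBigDecimal] = Seq(new JBigDecimal("1.0"), new JBigDecimal("1.00"))

  // HashMap: keys are distinct because equals also compares scale
  val hashCounts: HashMap[JBigDecimal, Long] =
    keys.foldLeft(HashMap.empty[JBigDecimal, Long]) { (m, k) => m.updated(k, m.getOrElse(k, 0L) + 1L) }

  // TreeMap: both keys collapse into one entry because compareTo returns 0
  val sortedCounts: TreeMap[JBigDecimal, Long] =
    keys.foldLeft(TreeMap.empty[JBigDecimal, Long](byCompareTo)) { (m, k) => m.updated(k, m.getOrElse(k, 0L) + 1L) }

  def main(args: Array[String]): Unit = {
    println(hashCounts.size)   // 2
    println(sortedCounts.size) // 1
  }
}
```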

On creation, CountAccumulator automatically registers itself with the SparkContext.


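import scala.collection.SortedMap
import org.apache.spark.SparkContext
import org.apache.spark.sql.DataFrame
import com.sgcharts.sparkutil.CountAccumulator

val sc: SparkContext = ???
val a = CountAccumulator[String](sc, Option("my accumulator name")) // Counting String keys
val df: DataFrame = ???
// Each element of a DataFrame is a Row, so extract the String key first
// (here assumed to be in the first column)
df.foreach(row => a.add(row.getString(0)))
val result: SortedMap[String, Long] = a.value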

See CountAccumulatorSpec for more examples.

Based on hammerlab's spark-util.


Logging

Spark uses log4j (not logback).

Log output goes to the console (stderr), which is the default configuration in spark/conf.


import com.sgcharts.sparkutil.Log4jLogging

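object MySparkApp extends Log4jLogging {
  // Log through the logger provided by Log4jLogging, e.g. an info-level "Hello World!"
  // (the logger member's name is not shown in this example)
}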