Ensemble Estimators for Apache Spark


A library of meta-estimators for ensemble learning in Apache Spark ML, à la scikit-learn.

Setup

Add the dependency from Maven Central to your build:

SBT

libraryDependencies += "com.github.pierrenodet" % "spark-ensemble_2.11" % "0.5.0"

Maven

<dependency>
  <groupId>com.github.pierrenodet</groupId>
  <artifactId>spark-ensemble_2.11</artifactId>
  <version>0.5.0</version>
</dependency>

What's inside

This Spark ML library contains meta-estimators for ensemble learning, including the BaggingClassifier used in the examples below.

How to use

Loading Features

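import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}

// Assumes an active SparkSession named `spark`
val raw = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("src/test/resources/data/iris/train.csv")

// Assemble every column except the label column "class" into a feature vector
val vectorAssembler = new VectorAssembler()
  .setInputCols(raw.columns.filter(_ != "class"))
  .setOutputCol("features")

// Encode the string label "class" as a numeric "label" column
val stringIndexer = new StringIndexer()
  .setInputCol("class")
  .setOutputCol("label")

val data = stringIndexer.fit(raw).transform(vectorAssembler.transform(raw))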
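StringIndexer turns the string classes into numeric labels ordered by frequency: with its default ordering, the most frequent class becomes 0.0, the next 1.0, and so on. The plain-Scala sketch below (no Spark involved) reproduces that mapping on a list of strings. LabelIndexSketch is a hypothetical name, and breaking ties alphabetically is an assumption: recent Spark versions document that behaviour, but check it for the version you run, since balanced datasets such as iris produce ties.

```scala
object LabelIndexSketch {
  // Map each distinct label to a Double index, most frequent label first.
  // Ties are broken alphabetically (an assumption of this sketch).
  def fitLabels(labels: Seq[String]): Map[String, Double] =
    labels
      .groupBy(identity)                                        // label -> its occurrences
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }        // descending frequency
      .zipWithIndex
      .map { case ((label, _), index) => label -> index.toDouble }
      .toMap

  // Illustrative input with unequal class counts:
  // setosa -> 0.0, virginica -> 1.0, versicolor -> 2.0
  val example: Map[String, Double] =
    fitLabels(Seq("setosa", "virginica", "setosa", "versicolor", "virginica", "setosa"))
}
```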

Base Learner Settings

import org.apache.spark.ml.classification.DecisionTreeClassifier

val baseClassifier = new DecisionTreeClassifier()
  .setMaxDepth(20)
  .setMaxBins(30)

Meta Estimator Settings

val baggingClassifier = new BaggingClassifier()
.setBaseLearner(baseClassifier)
.setMaxIter(10)
.setParallelism(4)
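Conceptually, a bagging meta-estimator fits several copies of the base learner, each on a random sample of the training rows drawn with replacement, and combines their predictions by majority vote. The following is a minimal, Spark-free sketch of that idea, not spark-ensemble's implementation. BaggingSketch, its Model type (a function from a Scala Vector[Double] to a label) and its parameters are hypothetical; numModels and ratio only loosely play the roles we assume maxIter and sampleRatio play in the library, and feature subsampling (sampleRatioFeatures) and parallel fitting are left out.

```scala
import scala.util.Random

object BaggingSketch {
  // A fitted model maps a feature vector to a class label.
  type Model = Vector[Double] => Double
  type LabeledRow = (Vector[Double], Double) // (features, label)

  // Bootstrap sample: round(ratio * n) rows drawn uniformly WITH replacement.
  def bootstrap[A](data: IndexedSeq[A], ratio: Double, rng: Random): IndexedSeq[A] = {
    val size = math.max(1, math.round(ratio * data.size).toInt)
    IndexedSeq.fill(size)(data(rng.nextInt(data.size)))
  }

  // Fit numModels base learners, each on its own bootstrap sample.
  def fit(
      data: IndexedSeq[LabeledRow],
      baseLearner: IndexedSeq[LabeledRow] => Model,
      numModels: Int,
      ratio: Double,
      seed: Long
  ): Seq[Model] = {
    val rng = new Random(seed)
    Seq.fill(numModels)(baseLearner(bootstrap(data, ratio, rng)))
  }

  // Majority vote over the base models; ties go to the smallest label.
  def predict(models: Seq[Model], features: Vector[Double]): Double = {
    val votes = models.map(m => m(features)).groupBy(identity).toSeq
      .map { case (label, ballots) => (label, ballots.size) }
    votes.reduce { (best, next) =>
      if (next._2 > best._2 || (next._2 == best._2 && next._1 < best._1)) next else best
    }._1
  }
}
```

Because each sample is drawn with replacement, every base learner sees a slightly different dataset, which is what makes averaging their votes reduce variance compared with a single deep tree.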

Train and Test

val Array(train, test) = data.randomSplit(Array(0.7, 0.3))

val model = baggingClassifier.fit(train)

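import org.apache.spark.ml.classification.DecisionTreeClassificationModel

// Inspect the fitted base learners as decision tree models
val trees = model.models.map(_.asInstanceOf[DecisionTreeClassificationModel])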

val predicted = model.transform(test)
predicted.show()

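import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// The default metric is "f1"; use .setMetricName("accuracy") for plain accuracy
val evaluator = new MulticlassClassificationEvaluator()
println(evaluator.evaluate(predicted))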
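Since MulticlassClassificationEvaluator is used here without setMetricName, it reports its default metric, "f1", which Spark computes as a weighted F1 score: the per-class F1 scores averaged with weights proportional to each class's share of the true labels. A small plain-Scala sketch of that formula follows; WeightedF1Sketch is a hypothetical name, and it assumes unweighted rows given as (prediction, label) pairs.

```scala
object WeightedF1Sketch {
  // pairs are (prediction, label), like the evaluator's prediction and label columns.
  def weightedF1(pairs: Seq[(Double, Double)]): Double = {
    val n = pairs.size.toDouble
    pairs.map(_._2).distinct.map { c =>
      val truePositives = pairs.count { case (p, y) => p == c && y == c }
      val predictedAsC = pairs.count(_._1 == c)
      val actuallyC = pairs.count(_._2 == c) // > 0, since c comes from the true labels
      val precision = if (predictedAsC == 0) 0.0 else truePositives.toDouble / predictedAsC
      val recall = truePositives.toDouble / actuallyC
      val f1 =
        if (precision + recall == 0.0) 0.0
        else 2 * precision * recall / (precision + recall)
      (actuallyC / n) * f1 // weight = share of rows whose true label is c
    }.sum
  }

  // 4 of 6 predictions correct, errors split evenly between classes: about 0.667
  val example: Double = weightedF1(
    Seq((0.0, 0.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (0.0, 1.0)))
}
```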

Cross Validation

import org.apache.spark.ml.tuning.ParamGridBuilder

val paramGrid = new ParamGridBuilder()
  .addGrid(baggingClassifier.sampleRatioFeatures, Array(0.7, 1.0))
  .addGrid(baggingClassifier.replacementFeatures, Array(false))
  .addGrid(baggingClassifier.replacement, Array(true))
  .addGrid(baggingClassifier.sampleRatio, Array(0.7, 1.0))
  .addGrid(baseClassifier.maxDepth, Array(1, 10))
  .addGrid(baseClassifier.maxBins, Array(30, 40))
  .build()
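ParamGridBuilder.build() returns every combination of the listed values, so this grid holds 2 × 1 × 1 × 2 × 2 × 2 = 16 parameter maps. With 5 folds, CrossValidator therefore fits 16 × 5 = 80 bagging models during the search, plus one final fit of the best combination on the whole dataset, and each of those fits trains several base learners. The Cartesian product itself can be sketched in plain Scala as below; ParamGridSketch is hypothetical and uses parameter names as plain strings instead of Spark Param objects.

```scala
object ParamGridSketch {
  // Cartesian product: one Map per combination, choosing one value per parameter.
  def grid(params: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    params.foldLeft(Seq(Map.empty[String, Any])) { case (combos, (name, values)) =>
      for (combo <- combos; value <- values) yield combo + (name -> value)
    }

  // Same value lists as the ParamGridBuilder above: 2 * 1 * 1 * 2 * 2 * 2 = 16 maps.
  val readmeGrid: Seq[Map[String, Any]] = grid(Seq(
    "sampleRatioFeatures" -> Seq(0.7, 1.0),
    "replacementFeatures" -> Seq(false),
    "replacement" -> Seq(true),
    "sampleRatio" -> Seq(0.7, 1.0),
    "maxDepth" -> Seq(1, 10),
    "maxBins" -> Seq(30, 40)
  ))
}
```

Because the grid grows multiplicatively, adding a single extra value to any parameter list can double the cost of the search.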

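import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
  .setEstimator(baggingClassifier)
  .setEvaluator(new MulticlassClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)
  .setParallelism(4)

val cvModel = cv.fit(data)

// The best combination is refit on the whole dataset; cast to access bagging-specific members
val bestModel = cvModel.bestModel.asInstanceOf[BaggingClassificationModel]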
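k-fold cross-validation splits the data into k folds, trains each candidate on k - 1 folds, scores it on the held-out fold, and keeps the candidate with the best average score. The following Spark-free sketch (KFoldSketch is hypothetical) shows that selection loop under two assumptions: a larger score is better, as it is for the "f1" metric, and rows are dealt into folds round-robin after a seeded shuffle. Spark's CrossValidator also splits randomly but uses its own sampling, so its fold sizes are only approximately equal.

```scala
import scala.util.Random

object KFoldSketch {
  // Shuffle with a fixed seed, then deal rows into k folds round-robin.
  // Returns one (training, validation) pair per fold.
  def folds[A](data: IndexedSeq[A], k: Int, seed: Long): Seq[(IndexedSeq[A], IndexedSeq[A])] = {
    val shuffled = new Random(seed).shuffle(data).zipWithIndex
    (0 until k).map { i =>
      val (validation, training) = shuffled.partition { case (_, j) => j % k == i }
      (training.map(_._1), validation.map(_._1))
    }
  }

  // Average each candidate's score over the folds; return the index of the best one.
  // Assumes larger scores are better; ties keep the earlier candidate.
  def selectBest[A, P](data: IndexedSeq[A], candidates: Seq[P], k: Int, seed: Long)(
      score: (P, IndexedSeq[A], IndexedSeq[A]) => Double): Int = {
    val splits = folds(data, k, seed)
    val averages = candidates.map { p =>
      splits.map { case (training, validation) => score(p, training, validation) }.sum / k
    }
    averages.indices.reduce((best, i) => if (averages(i) > averages(best)) i else best)
  }
}
```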

Contributing

Feel free to open an issue or make a pull request to contribute to the repository.

Authors

See also the list of contributors who participated in this project.

License

This project is licensed under the Apache License, Version 2.0. See the LICENSE file for details.