Classifying documents using Naive Bayes on Apache Spark / MLlib

In recent years, Apache Spark has gained popularity as a faster alternative to Hadoop, and it reached a major milestone last month with the release of the production-ready version 1.0.0. It claims to be up to 100 times faster by leveraging the distributed memory of the cluster and by not being tied to the multi-stage execution of Map/Reduce. Like Hadoop, it offers an ecosystem of tools built on top of it: a database (Shark SQL), a machine learning library (MLlib), a graph library (GraphX) and many others. Finally, Spark integrates well with Scala: one can manipulate distributed collections just like regular Scala collections, and Spark takes care of distributing the processing to the different workers.

In this post, we describe how we used Spark / MLlib to classify HTML documents, using as a training set the popular Reuters-21578 collection of documents that appeared on the Reuters newswire in 1987.

The Reuters collection can be obtained from

The collection has 123 categories (see this page for more details). To simplify, we will only keep the ones that have between 100 and 1000 documents.

We’re going to do the following steps:

  1. parse XML documents (extract topic and content)
  2. tokenize and stem the documents
  3. create a dictionary out of all the words in the collection of documents and compute the IDF (Inverse Document Frequency) for each term
  4. vectorize documents using TF-IDF scores
  5. train the Naive Bayes classifier
  6. classify HTML documents

In order to parse the XML documents we use a simple event-based parser (Scala's XMLEventReader):

import scala.collection.mutable
import scala.io.Source
import scala.xml.pull._

object ReutersParser {
  // Only topics listed here are kept as labels
  val PopularCategories = Seq("money", "fx", "crude", "grain", "trade", "interest",
                              "wheat", "ship", "corn", "oil", "dlr", "gas", "oilseed",
                              "supply", "sugar", "gnp", "coffee", "veg", "gold", "nat",
                              "soybean", "bop", "livestock", "cpi")

  def parseAll(xmlFiles: Iterable[String]) = xmlFiles flatMap parse

  def parse(xmlFile: String) = {
    val docs = mutable.ArrayBuffer.empty[Document]
    val xml = new XMLEventReader(Source.fromFile(xmlFile, "latin1"))
    var currentDoc: Document = null
    // flags tracking which element the reader is currently inside
    var inTopics = false
    var inLabel = false
    var inBody = false
    for (event <- xml) {
      event match {
        case EvElemStart(_, "REUTERS", attrs, _) =>
          currentDoc = Document(attrs.get("NEWID").get.head.text)

        case EvElemEnd(_, "REUTERS") =>
          // keep only documents that have at least one popular label
          if (currentDoc.labels.nonEmpty) {
            docs += currentDoc
          }

        case EvElemStart(_, "TOPICS", _, _) => inTopics = true

        case EvElemEnd(_, "TOPICS") => inTopics = false

        case EvElemStart(_, "D", _, _) => inLabel = true

        case EvElemEnd(_, "D") => inLabel = false

        case EvElemStart(_, "BODY", _, _) => inBody = true

        case EvElemEnd(_, "BODY") => inBody = false

        case EvText(text) =>
          if (text.trim.nonEmpty) {
            if (inTopics && inLabel && PopularCategories.contains(text)) {
              currentDoc = currentDoc.copy(labels = currentDoc.labels + text)
            } else if (inBody) {
              currentDoc = currentDoc.copy(body = currentDoc.body + text.trim)
            }
          }

        case _ =>
      }
    }
    docs
  }
}

case class Document(docId: String, body: String = "", labels: Set[String] = Set.empty)

Then we tokenize the documents using an English stemmer (e.g., “meets” and “meeting” become “meet”):

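  import java.io.StringReader
  import scala.collection.mutable
  import org.apache.lucene.analysis.en.EnglishAnalyzer
  import org.apache.lucene.analysis.tokenattributes.CharTermAttribute

  def tokenize(content: String): Seq[String] = {
    val tReader = new StringReader(content)
    val analyzer = new EnglishAnalyzer(LuceneVersion)
    val tStream = analyzer.tokenStream("contents", tReader)
    val term = tStream.addAttribute(classOf[CharTermAttribute])
    // Lucene requires reset() before the first call to incrementToken()
    tStream.reset()

    val result = mutable.ArrayBuffer.empty[String]
    while (tStream.incrementToken()) {
      val termValue = term.toString
      if (!(termValue matches ".*[\\d\\.].*")) {
        result += termValue
      }
    }
    result
  }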
Now let’s create the dictionary, vectorize the documents and train the Naive Bayes classifier on Spark.

We initialize the Spark context to run locally with 4 worker threads:

  val sc = new SparkContext("local[4]", "naivebayes")

Then we convert the collection of documents to a Resilient Distributed Dataset (RDD) using sc.parallelize():

    val termDocsRdd = sc.parallelize[TermDoc](termDocs.toSeq)

Then, in order to vectorize the documents, we create a dictionary containing every word that appears in the collection. This is achieved with a simple transformation:

    val terms = termDocsRdd.flatMap(_.terms).distinct().collect().sortBy(identity)

Spark takes care of distributing the work to the different workers, and collect() brings the results back to the driver program.
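
Because RDD transformations mirror the Scala collections API, the same dictionary logic can be tried on a local collection first. The sketch below is only an illustration: TermDictSketch is a hypothetical stand-in for the term dictionary used later in this post, not the actual class from our project.

```scala
// Hypothetical stand-in for the term dictionary; not the class from the project
class TermDictSketch(docs: Seq[Seq[String]]) {
  // Same shape as the RDD version: flatten, deduplicate, sort
  val terms: Array[String] = docs.flatten.distinct.sorted.toArray

  // Position of each term in the sorted dictionary, used as the vector index
  private val index: Map[String, Int] = terms.zipWithIndex.toMap

  def indexOf(term: String): Int = index(term)
  def count: Int = terms.length
}
```

Sorting the terms makes the index of each word stable from one run to the next, which matters once the indices are used as vector positions.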

Based on the dictionary, we compute the IDF score for each term. There are different formulas to calculate IDF scores. It’s usually:

idf(term, docs) = log[(number of documents) / (number of documents containing term)]

However, the Naive Bayes implementation in MLlib already works with logarithms, so we drop the log from the formula:

idf(term, docs) = (number of documents) / (number of documents containing term)

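We also exclude words that are present in 3 documents or fewer (an arbitrary threshold) to remove terms that are too specific:

    // numDocs is the total number of documents in the collection
    // assumption: TermDoc exposes its document id as `doc`
    val idfs = (termDocsRdd.flatMap(termDoc => termDoc.terms.map((termDoc.doc, _))).distinct().groupBy(_._2) collect {
      // keep only terms that appear in more than 3 documents
      case (term, docs) if docs.size > 3 =>
        term -> (numDocs.toDouble / docs.size.toDouble)
    }).collect().toMap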
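
As a sanity check, the same IDF computation can be run on a small in-memory corpus. This is a sketch under assumptions: documents are given as (docId, terms) pairs, and minDocs is a hypothetical parameter that replaces the hard-coded threshold.

```scala
object IdfSketch {
  // docs: (docId, terms) pairs; returns term -> numDocs / docFrequency (no log applied).
  // minDocs is a hypothetical parameter: a term must appear in more than minDocs documents.
  def idfs(docs: Seq[(String, Seq[String])], minDocs: Int): Map[String, Double] = {
    val numDocs = docs.size
    docs
      .flatMap { case (docId, terms) => terms.map(term => (docId, term)) }
      .distinct        // count each term at most once per document
      .groupBy(_._2)   // term -> all (docId, term) pairs
      .collect {
        case (term, pairs) if pairs.size > minDocs =>
          term -> (numDocs.toDouble / pairs.size.toDouble)
      }
  }
}
```

With four documents and minDocs = 1, a term found in three of them gets a score of 4 / 3, a term found in two gets 2.0, and terms found in a single document are dropped.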

We then vectorize each document by computing the TF-IDF score of each term it contains:

    (filteredTerms.groupBy(identity).map {
      case (term, instances) =>
        (indexOf(term), (instances.size.toDouble / filteredTerms.size.toDouble) * idfs(term))
    }).toSeq.sortBy(_._1) // sort by termId

We then convert the vectorized documents into a collection of LabeledPoints. Each LabeledPoint represents a training document associated with a label id (a double) and a sparse vector:

    val tfidfs = termDocsRdd flatMap {
      termDoc =>
        val termPairs = termDict.tfIdfs(termDoc.terms, idfs)
        // assumption: a document is assigned to its first label only
        termDoc.labels.headOption.map {
          label =>
            val labelId = labelDict.indexOf(label).toDouble
            val vector = Vectors.sparse(termDict.count, termPairs)
            LabeledPoint(labelId, vector)
        }
    }
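
The (index, score) pairs passed to Vectors.sparse can be reproduced without Spark. The sketch below assumes a term-to-index map and an IDF map built as described earlier; TfIdfSketch and its parameter names are invented for this example.

```scala
object TfIdfSketch {
  // Returns (termIndex, tf * idf) pairs sorted by index, the layout a sparse vector expects
  def tfIdfs(terms: Seq[String],
             indexOf: Map[String, Int],
             idfs: Map[String, Double]): Seq[(Int, Double)] = {
    // Ignore terms that were dropped from the IDF map
    val filteredTerms = terms.filter(idfs.contains)
    filteredTerms
      .groupBy(identity)
      .map { case (term, instances) =>
        val tf = instances.size.toDouble / filteredTerms.size.toDouble
        (indexOf(term), tf * idfs(term))
      }
      .toSeq
      .sortBy(_._1)
  }
}
```

Note that the term frequency is normalized by the number of retained terms, not by the raw document length, so words removed by the IDF threshold do not dilute the scores.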

We then train the Naive Bayes classifier:

    val model = NaiveBayes.train(tfidfs)
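
For intuition about what train and predict do, here is a minimal multinomial Naive Bayes on plain Scala collections. It is not MLlib's implementation: the names, the dense Array[Double] features and the lambda smoothing parameter are assumptions made for this sketch. It does show where the logarithms come in: each class score is a sum of log-probabilities weighted by the feature values.

```scala
object NaiveBayesSketch {
  // logPrior(c) = log P(class c); logLikelihood(c)(j) = log P(term j | class c)
  case class Model(labels: Array[Double],
                   logPrior: Array[Double],
                   logLikelihood: Array[Array[Double]]) {
    // Pick the class with the highest log-posterior (up to a constant)
    def predict(features: Array[Double]): Double = {
      val scores = labels.indices.map { c =>
        logPrior(c) + features.indices.map(j => features(j) * logLikelihood(c)(j)).sum
      }
      labels(scores.indexOf(scores.max))
    }
  }

  // data: (label, feature vector) pairs; lambda: additive (Laplace) smoothing
  def train(data: Seq[(Double, Array[Double])], lambda: Double = 1.0): Model = {
    val numFeatures = data.head._2.length
    val byLabel = data.groupBy(_._1).toSeq.sortBy(_._1)
    val labels = byLabel.map(_._1).toArray
    val logPrior = byLabel.map { case (_, rows) =>
      math.log((rows.size + lambda) / (data.size + labels.length * lambda))
    }.toArray
    val logLikelihood = byLabel.map { case (_, rows) =>
      // Sum the feature values per term over all documents of this class
      val sums = Array.tabulate(numFeatures)(j => rows.map(_._2(j)).sum)
      val total = sums.sum + numFeatures * lambda
      sums.map(s => math.log((s + lambda) / total))
    }.toArray
    Model(labels, logPrior, logLikelihood)
  }
}
```

Trained on a handful of two-feature vectors, the model assigns a document dominated by the first feature to the class whose training documents were also dominated by it.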

We provide a console that expects a URL as input. We then extract the text content using Goose, developed by Gravity Labs, and run the Naive Bayes classifier to predict the label:

    // extract content from HTML
    val config = new Configuration
    val goose = new Goose(config)
    val content = goose.extractContent(url).cleanedArticleText
    // tokenize content and stem it
    val tokens = Tokenizer.tokenize(content)
    // compute TFIDF vector
    val tfIdfs = naiveBayesAndDictionaries.termDictionary.tfIdfs(tokens, naiveBayesAndDictionaries.idfs)
    val vector = naiveBayesAndDictionaries.termDictionary.vectorize(tfIdfs)
    // classify document
    val labelId = model.predict(vector)
    // convert label from double
    println("Label: " + naiveBayesAndDictionaries.labelDictionary.valueOf(labelId.toInt))

Running the example

Install GravityLabs Goose (we forked their project to update their dependencies to use Scala 2.10) in your local Maven repository:

$ git clone
$ cd goose
$ mvn install

Build and run the Naive Bayes classifier:

$ git clone
$ cd blog-spark-naive-bayes-reuters
$ ./
$ sbt run

You will be prompted to enter some URLs. For instance, you can use the following ones:

You can also see the jobs that were run on a web interface at: http://localhost:4040/.


In this post, we showed with a simple example how Spark can be used to classify documents using Naive Bayes. Many other aspects of Spark are also interesting: the ability to broadcast variables to workers, cache results, ingest data streams, and more.

Even though MLlib is still very young and offers far fewer algorithm implementations than Mahout, it is faster and its team is working on adding more algorithms. On the other hand, Mahout is moving to Spark to offer better performance. So it should be interesting to see how things evolve over the next few months.



Responses to Classifying documents using Naive Bayes on Apache Spark / MLlib

  1. Jatin says:

    Thanks for the detailed tutorial. This works great, but I have to implement it using Java 8 and am stuck on creating the TF-IDF vectors. Can you please provide any direction on implementing the same logic in Java 8 (including lambda expressions)?

  2. qingwufu says:

    I ran the commands following your blog, but an exception like the one below always comes up:
    Exception in thread "main" java.lang.NoClassDefFoundError: scala/reflect/ClassManifest
    at com.gravity.goose.Configuration.<init>(Configuration.scala:118)
    Could you help me? Thank you.
