Classifying documents using Naive Bayes on Apache Spark / MLlib
2014/06/11
In recent years, Apache Spark has gained popularity as a faster alternative to Hadoop, and it reached a major milestone last month with the release of the production-ready version 1.0.0. It claims to be up to 100 times faster by leveraging the distributed memory of the cluster and by not being tied to the multi-stage execution of Map/Reduce. Like Hadoop, it comes with an ecosystem of tools built on top of it: a database (Shark SQL), a machine learning library (MLlib), a graph library (GraphX) and many others. Finally, Spark integrates well with Scala: one can manipulate distributed collections just like regular Scala collections, and Spark takes care of distributing the processing to the different workers.
In this post, we describe how we used Spark / MLlib to classify HTML documents, using as a training set the popular Reuters-21578 collection of documents that appeared on the Reuters newswire in 1987.
The Reuters collection can be obtained from http://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/.
The collection has 123 categories (see this page for more details). To simplify, we will only keep the ones that have between 100 and 1000 documents.
We’re going to do the following steps:
- parse XML documents (extract topic and content)
- tokenize and stem the documents
- create a dictionary out of all the words in the collection of documents and compute the IDF (Inverse Document Frequency) of each term
- vectorize documents using TF-IDF scores
- train the Naive Bayes classifier
- classify HTML documents
In order to parse the XML documents we use a simple event-based parser (scala.xml.pull.XMLEventReader):
import scala.collection.mutable
import scala.io.Source
import scala.xml.pull.{EvElemEnd, EvElemStart, EvText, XMLEventReader}

object ReutersParser {
  // categories kept for training
  val PopularCategories = Seq("money", "fx", "crude", "grain", "trade", "interest", "wheat", "ship", "corn", "oil", "dlr", "gas", "oilseed", "supply", "sugar", "gnp", "coffee", "veg", "gold", "nat", "soybean", "bop", "livestock", "cpi")

  def parseAll(xmlFiles: Iterable[String]) = xmlFiles flatMap parse

  def parse(xmlFile: String) = {
    val docs = mutable.ArrayBuffer.empty[Document]
    val xml = new XMLEventReader(Source.fromFile(xmlFile, "latin1"))
    var currentDoc: Document = null
    // flags tracking which element the reader is currently inside
    var inTopics = false
    var inLabel = false
    var inBody = false
    for (event <- xml) {
      event match {
        case EvElemStart(_, "REUTERS", attrs, _) =>
          currentDoc = Document(attrs.get("NEWID").get.head.text)
        case EvElemEnd(_, "REUTERS") =>
          // keep only documents that carry at least one popular category
          if (currentDoc.labels.nonEmpty) {
            docs += currentDoc
          }
        case EvElemStart(_, "TOPICS", _, _) => inTopics = true
        case EvElemEnd(_, "TOPICS") => inTopics = false
        case EvElemStart(_, "D", _, _) => inLabel = true
        case EvElemEnd(_, "D") => inLabel = false
        case EvElemStart(_, "BODY", _, _) => inBody = true
        case EvElemEnd(_, "BODY") => inBody = false
        case EvText(text) =>
          if (text.trim.nonEmpty) {
            if (inTopics && inLabel && PopularCategories.contains(text)) {
              currentDoc = currentDoc.copy(labels = currentDoc.labels + text)
            } else if (inBody) {
              currentDoc = currentDoc.copy(body = currentDoc.body + text.trim)
            }
          }
        case _ =>
      }
    }
    docs
  }
}

case class Document(docId: String, body: String = "", labels: Set[String] = Set.empty)
Then we tokenize the documents using an English stemmer (e.g., “meets” and “meeting” become “meet”):
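import java.io.StringReader
import scala.collection.mutable
import org.apache.lucene.analysis.en.EnglishAnalyzer
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute

// LuceneVersion is defined elsewhere in the project
def tokenize(content: String): Seq[String] = {
  val tReader = new StringReader(content)
  val analyzer = new EnglishAnalyzer(LuceneVersion)
  val tStream = analyzer.tokenStream("contents", tReader)
  val term = tStream.addAttribute(classOf[CharTermAttribute])
  tStream.reset()
  val result = mutable.ArrayBuffer.empty[String]
  while (tStream.incrementToken()) {
    val termValue = term.toString
    // skip tokens that contain a digit or a dot
    if (!(termValue matches ".*[\\d\\.].*")) {
      result += termValue
    }
  }
  // finish and release the token stream
  tStream.end()
  tStream.close()
  result
}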
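To make the filtering step in tokenize() concrete, here is a minimal sketch that applies the same regular expression to a list of tokens, without Lucene. The object name and the sample tokens are made up for illustration; stemming is not reproduced here.

```scala
object TokenFilterSketch {
  // Same pattern as in tokenize(): matches any token containing a digit or a dot.
  private val DigitOrDot = ".*[\\d\\.].*"

  /** Keeps only the tokens that contain neither a digit nor a dot. */
  def dropNumericTokens(tokens: Seq[String]): Seq[String] =
    tokens.filterNot(_.matches(DigitOrDot))

  def main(args: Array[String]): Unit = {
    // Hypothetical tokens, as they might come out of the analyzer.
    val tokens = Seq("oil", "1987", "price", "3.5", "u.s", "barrel")
    println(dropNumericTokens(tokens).mkString(", ")) // prints "oil, price, barrel"
  }
}
```

Note that abbreviations such as "u.s" are dropped along with the numbers, since the pattern rejects any token with a dot.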
Now let’s create the dictionary, vectorize the documents and train the Naive Bayes classifier on Spark.
We initialize the SparkContext to run locally with 4 worker threads:
val sc = new SparkContext("local[4]", "naivebayes")
Then we convert the collection of documents to a Resilient Distributed Dataset (RDD) using sc.parallelize():
val termDocsRdd = sc.parallelize[TermDoc](termDocs.toSeq)
Then, in order to vectorize the documents, we create a dictionary of all the words that appear in any document. This takes a single chain of transformations:
val terms = termDocsRdd.flatMap(_.terms).distinct().collect().sortBy(identity)
Spark takes care of distributing the work to the different workers, and collect() gathers the results from the workers back into the driver program.
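Since RDD transformations mirror those of regular Scala collections, the dictionary step can be tried out without a cluster. The sketch below runs the same flatMap / distinct / sort chain on an in-memory Seq; the TermDoc here is a simplified, hypothetical stand-in for the one used in the post (it omits labels), and the sample documents are invented.

```scala
object DictionarySketch {
  // Simplified stand-in for the post's TermDoc: a document id plus its stemmed tokens.
  case class TermDoc(doc: String, terms: Seq[String])

  /** All distinct terms across the documents, sorted so each term gets a stable index. */
  def buildDictionary(docs: Seq[TermDoc]): Vector[String] =
    docs.flatMap(_.terms).distinct.sorted.toVector

  def main(args: Array[String]): Unit = {
    val docs = Seq(
      TermDoc("1", Seq("wheat", "price", "rise")),
      TermDoc("2", Seq("oil", "price"))
    )
    val dictionary = buildDictionary(docs)
    println(dictionary.mkString(", "))   // prints "oil, price, rise, wheat"
    println(dictionary.indexOf("price")) // prints "1"
  }
}
```

Sorting matters because the position of a term in the dictionary later serves as its index in the document vectors.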
Based on the dictionary, we compute the IDF score for each term. There are different formulas to calculate IDF scores. It’s usually:
idf(term, docs) = log[(number of documents) / (number of documents containing term)]
However, the MLlib implementation of Naive Bayes already applies a logarithm, so we can drop it from the formula:
idf(term, docs) = (number of documents) / (number of documents containing term)
We also exclude words that appear in 3 documents or fewer (an arbitrary threshold) to remove overly specific terms:
val idfs = (termDocsRdd
  .flatMap(termDoc => termDoc.terms.map((termDoc.doc, _)))
  .distinct()
  .groupBy(_._2)
  .collect {
    // keep only terms that appear in more than 3 documents
    case (term, docs) if docs.size > 3 =>
      term -> (numDocs.toDouble / docs.size.toDouble)
  }).collect.toMap
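The IDF computation can be checked by hand on a tiny corpus. The following sketch uses plain Scala collections rather than an RDD and the log-free formula from above; the minDocFreq parameter, the object name and the sample documents are assumptions made for this illustration.

```scala
object IdfSketch {
  /** idf(term) = numDocs / docFreq(term), kept only when docFreq(term) > minDocFreq. */
  def idfs(docs: Seq[Seq[String]], minDocFreq: Int): Map[String, Double] = {
    val numDocs = docs.size.toDouble
    docs
      .flatMap(_.distinct) // count each term at most once per document
      .groupBy(identity)
      .collect {
        case (term, occurrences) if occurrences.size > minDocFreq =>
          term -> (numDocs / occurrences.size)
      }
  }

  def main(args: Array[String]): Unit = {
    val docs = Seq(
      Seq("oil", "price"),
      Seq("oil", "oil", "export"),
      Seq("wheat", "price"),
      Seq("oil", "wheat")
    )
    // "export" appears in a single document and is filtered out with minDocFreq = 1.
    idfs(docs, minDocFreq = 1).toSeq.sortBy(_._1).foreach(println)
  }
}
```

With 4 documents, "price" appears in 2 of them and gets a score of 4 / 2 = 2.0, while "oil" appears in 3 and gets 4 / 3. Rarer terms therefore weigh more.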
We then vectorize each document by computing the TF-IDF score of each term it contains:
(filteredTerms.groupBy(identity).map {
  case (term, instances) =>
    (indexOf(term), (instances.size.toDouble / filteredTerms.size.toDouble) * idfs(term))
}).toSeq.sortBy(_._1) // sort by termId
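As a worked example of the TF-IDF step, the sketch below turns one tokenized document into (term index, score) pairs. It assumes, as the snippet above suggests, that terms missing from the IDF map are filtered out first; the function signature, the dictionary and the IDF values are invented for illustration.

```scala
object TfIdfSketch {
  /** Returns (termIndex, tf * idf) pairs sorted by index; terms without an IDF score are skipped. */
  def tfIdfs(terms: Seq[String],
             dictionary: Vector[String],
             idfs: Map[String, Double]): Seq[(Int, Double)] = {
    val filtered = terms.filter(idfs.contains)
    filtered
      .groupBy(identity)
      .toSeq
      .map { case (term, instances) =>
        // tf = occurrences of the term / number of retained terms in the document
        (dictionary.indexOf(term), instances.size.toDouble / filtered.size * idfs(term))
      }
      .sortBy(_._1)
  }

  def main(args: Array[String]): Unit = {
    val dictionary = Vector("oil", "price", "wheat")
    val idfs = Map("oil" -> 2.0, "price" -> 4.0)
    // "the" has no IDF score, so 4 terms remain: tf(oil) = 3/4, tf(price) = 1/4.
    println(tfIdfs(Seq("oil", "price", "oil", "oil", "the"), dictionary, idfs))
  }
}
```

Here "oil" scores 0.75 * 2.0 = 1.5 and "price" scores 0.25 * 4.0 = 1.0. The pairs are sorted by index because sparse vectors are usually built from indices in increasing order.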
We then convert them into a collection of LabeledPoints. Each LabeledPoint represents a training document associated with a label id (a Double) and a sparse vector:
val tfidfs = termDocsRdd flatMap { termDoc =>
  val termPairs = termDict.tfIdfs(termDoc.terms, idfs)
  // documents without a label are dropped; only the first label is used
  termDoc.labels.headOption.map { label =>
    val labelId = labelDict.indexOf(label).toDouble
    val vector = Vectors.sparse(termDict.count, termPairs)
    LabeledPoint(labelId, vector)
  }
}
We then train the Naive Bayes classifier:
val model = NaiveBayes.train(tfidfs)
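To give an idea of what training and prediction involve, here is a small sketch of textbook multinomial Naive Bayes with additive smoothing, written on dense arrays. It is not MLlib's code and makes no claim to match its internals; the names, the lambda default and the toy data are assumptions for illustration. It does show where the logarithm enters: both the class priors and the per-feature weights are stored as logs, and prediction sums them.

```scala
object NaiveBayesSketch {
  /** Trained model: one log prior and one array of log feature weights per label. */
  case class Model(labels: Vector[Double],
                   logPriors: Vector[Double],
                   logThetas: Vector[Array[Double]]) {
    /** Returns the label maximizing logPrior + sum_i(x_i * logTheta_i). */
    def predict(features: Array[Double]): Double = {
      val scores = labels.indices.map { c =>
        logPriors(c) + features.indices.map(i => features(i) * logThetas(c)(i)).sum
      }
      labels(scores.indexOf(scores.max))
    }
  }

  /** Trains on (label, dense feature vector) pairs with additive smoothing `lambda`. */
  def train(data: Seq[(Double, Array[Double])], lambda: Double = 1.0): Model = {
    val numFeatures = data.head._2.length
    val byLabel = data.groupBy(_._1).toVector.sortBy(_._1)
    val labels = byLabel.map(_._1)
    // log of the fraction of training documents carrying each label
    val logPriors = byLabel.map { case (_, rows) => math.log(rows.size.toDouble / data.size) }
    val logThetas = byLabel.map { case (_, rows) =>
      // sum the feature values of all documents carrying this label
      val sums = Array.tabulate(numFeatures)(i => rows.map(_._2(i)).sum)
      val total = sums.sum
      sums.map(s => math.log((s + lambda) / (total + lambda * numFeatures)))
    }
    Model(labels, logPriors, logThetas)
  }

  def main(args: Array[String]): Unit = {
    // Toy data: feature 0 stands for "oil", feature 1 for "wheat".
    val data = Seq(
      (0.0, Array(3.0, 0.0)), (0.0, Array(2.0, 1.0)),
      (1.0, Array(0.0, 4.0)), (1.0, Array(1.0, 2.0))
    )
    val model = train(data)
    println(model.predict(Array(2.0, 0.0))) // prints "0.0"
    println(model.predict(Array(0.0, 3.0))) // prints "1.0"
  }
}
```

Working in log space turns the product of per-term probabilities into a sum, which avoids numerical underflow on long documents.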
We provide a REPL console that expects a URL as input. We then extract the text content using Goose, developed by Gravity Labs, and run the Naive Bayes classifier to predict the label:
// extract content from HTML
val config = new Configuration
config.setEnableImageFetching(false)
val goose = new Goose(config)
val content = goose.extractContent(url).cleanedArticleText

// tokenize content and stem it
val tokens = Tokenizer.tokenize(content)

// compute TFIDF vector
val tfIdfs = naiveBayesAndDictionaries.termDictionary.tfIdfs(tokens, naiveBayesAndDictionaries.idfs)
val vector = naiveBayesAndDictionaries.termDictionary.vectorize(tfIdfs)

// classify document
val labelId = model.predict(vector)

// convert label from double
println("Label: " + naiveBayesAndDictionaries.labelDictionary.valueOf(labelId.toInt))
Running the example
Install GravityLabs Goose (we forked their project to update their dependencies to use Scala 2.10) in your local Maven repository:
$ git clone http://github.com/chimpler/goose
$ cd goose
$ mvn install
Build and run the Naive Bayes classifier:
$ git clone http://github.com/chimpler/blog-spark-naive-bayes-reuters
$ cd blog-spark-naive-bayes-reuters
$ ./download_reuters.sh
$ sbt run
You will be prompted to enter some URLs. For instance, you can use the following ones:
- http://www.coinflation.com/coins/1942-1945-Silver-War-Nickel-Value.html
- http://www.businessweek.com/news/2014-06-10/china-using-dubai-style-fake-islands-to-reshape-south-china-sea
- http://en.wikipedia.org/wiki/Soybean
- http://en.wikipedia.org/wiki/Whole_wheat_bread
- http://en.wiktionary.org/wiki/cow
- http://www.businessweek.com/articles/2014-02-13/to-stop-the-coffee-apocalypse-starbucks-buys-a-farm
You can also see the jobs that were run on a web interface at: http://localhost:4040/.
Conclusion
In this post, we showed with a simple example how Spark can be used to classify documents with Naive Bayes. There are many other aspects of Spark that are also interesting: the ability to broadcast variables to workers, cache results, ingest data streams, and more.
Even though MLlib is still very young and offers far fewer algorithm implementations than Mahout, it is faster, and its team is working on adding more algorithms. On the other hand, Mahout is moving to Spark to offer better performance. It will be interesting to see how things evolve over the next few months.
Thanks for the detailed tutorial, this works great but I have to implement it using Java 8 and am stuck in creating TFIDF vectors. Can you please provide any direction on implementing the same logic using Java 8 (including lambda expressions).
I ran the commands from your blog post, but the following exception always occurs:
Exception in thread "main" java.lang.NoClassDefFoundError: scala/reflect/ClassManifest
at com.gravity.goose.network.HtmlFetcher$.<init>(HtmlFetcher.scala:66)
at com.gravity.goose.network.HtmlFetcher$.<clinit>(HtmlFetcher.scala)
at com.gravity.goose.Configuration.<init>(Configuration.scala:118)
Could you help me? Thank you.