Health Informatics and Survival Prediction of Cancer with Apache Spark Machine Learning Library


Apache Spark is an open-source, cluster-based computing engine for large-scale data processing. Spark applications can be deployed on a stand-alone server or in a Hadoop cluster. There is comprehensive literature covering the Spark framework; see, for example, the official Apache Spark documentation, review articles such as Big Data Processing with Apache Spark, and the presentation Unified Big Data Processing with Apache Spark.

Our interest in this article is the Spark MLlib library. MLlib provides implementations of various machine learning and statistical algorithms including summary statistics, correlation, clustering, classification and regression. This article will focus on classification algorithms and the corresponding performance metrics.

To start off the discussion, I will introduce an example where survival prediction of colorectal cancer is formulated as a multi-class classification problem. Then I will discuss four different machine learning algorithms that solve that problem through the MLlib Java API and study the related performance metrics. The Test Setup section outlines the Java code for the algorithms and covers how the training and test data sets are created.

Then we will look at the Java code details and finally discuss the findings for each of the algorithms. Here, we also look at the binary classification problem, which is a special case of multi-class classification, and indicate that all algorithms yield improved results in that special case.

Problem Statement

The Surveillance, Epidemiology, and End Results (SEER) Program of the National Cancer Institute publishes cancer incidence and survival statistics from various population-based cancer registries in the United States. We will consider SEER colorectal cancer statistics and try to predict cancer survival months from a set of clinically important input variables. Note that we do not aim to obtain scientifically acceptable results. Rather, this article is a tutorial for the Spark MLlib library, in particular multi-class classification problem and performance metrics API, utilizing real health informatics data.

Multi-class Classification Problem

For a given colorectal cancer patient, try to predict which of the following periods a patient's survival belongs to:

  • "survival <= 36 months"
  • "36 months < survival <= 72 months"
  • "72 months < survival"

We will use machine learning algorithms to solve that problem. Each of the survival periods will be represented by a "label" (each label uniquely identifies a distinct "class") and each of the input variables will be represented by a "feature". We will train the algorithm with a training data set that consists of instances where each instance has known features, i.e. values for input variables, and the corresponding known label. Once the model is trained, we will measure its performance with a test data set. For each instance in test data, features and the corresponding label are known. By supplying the features to the model we will obtain the predicted label and compare it to the actual known label.




Label | Description
0 | Patient survival is greater than or equal to 0 and less than 36 months after diagnosis.
1 | Patient survival is greater than or equal to 36 and less than 72 months after diagnosis.
2 | Patient survival is greater than or equal to 72 months after diagnosis.

Table 1. Label Descriptions

Features are explained in the following table.





Item Name | Description
Regional Nodes Positive | Number of regional lymph nodes found to contain metastases in pathology test.
CS Tumor Size | Size of the primary tumor.
Age Recode | Age group the patient belongs to, e.g. age 0, ages 1-4, ages 5-9 etc.
Histologic Type ICD-O-3 | Microscopic composition of cells and/or tissue of the primary tumor, represented by ICD-O-3 histologic type codes.
Derived AJCC Stage Group | Stage (level of progression) of cancer in tumor determined according to the American Joint Committee on Cancer (AJCC) staging system, 6th edition.
Summary stage 2000 (1998+) | Stage of cancer in tumor determined according to a simplified classification system by SEER Cancer Statistics Review.
Derived SS2000 | Stage of cancer in tumor determined according to a classification system by SEER called derived 'SEER Summary Stage 2000'.
CS Mets at Dx | Distant metastasis, i.e. spread of cancer from primary site to other organs, determined according to CS cancer schema.

Table 2. Feature Descriptions

We used the colorectal cancer data file from the SEER 1973-2012 (November 2014 Submission) database. This is a fixed-width ASCII file where each row corresponds to the record of a unique patient. The file format is described in the accompanying SEER documentation. Individual data fields in each record are identified by their column positions. Each "Item Name" in the Feature table above can be looked up in that document to get more information about the corresponding feature.

Each feature is categorical and has been measured at the time of diagnosis. The label, patient survival duration, corresponds to item name "Survival months – presumed alive" in the data file. We converted the survival months into one of the 3 labels according to the interval it belonged to. For example, the value 0048 i.e. 48 months is converted to label 1, because it falls in range "36 months < survival <= 72 months".
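The conversion from survival months to a label can be sketched as follows. This is a minimal illustration, not the code used for the article: the class and method names are hypothetical, and the interval boundaries follow the "<=" convention of the bullet list in the problem statement.

```java
// Hypothetical helper (not from the article's code base): maps the raw
// survival-months field of a record to one of the three class labels.
final class SurvivalLabel {

    // Inclusive upper bounds, in months, of labels 0 and 1 (assumed convention).
    private static final int LABEL_0_MAX_MONTHS = 36;
    private static final int LABEL_1_MAX_MONTHS = 72;

    static int toLabel(String survivalMonthsField) {
        // The field is zero-padded text such as "0048".
        int months = Integer.parseInt(survivalMonthsField.trim());
        if (months < 0) {
            throw new IllegalArgumentException("negative survival: " + months);
        }
        if (months <= LABEL_0_MAX_MONTHS) {
            return 0; // survival <= 36 months
        }
        if (months <= LABEL_1_MAX_MONTHS) {
            return 1; // 36 months < survival <= 72 months
        }
        return 2;     // 72 months < survival
    }

    public static void main(String[] args) {
        System.out.println(toLabel("0048")); // prints 1
    }
}
```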

Machine Learning Algorithms

Naïve Bayes

Naive Bayes is a group of probabilistic techniques for classification, with many real-life applications including medical diagnostics, e.g. see "Inductive and Bayesian learning in medical diagnosis". From the MLlib library we will employ a particular type of Bayesian technique called Multinomial Naive Bayes.

Multinomial Logistic Regression

Multinomial Logistic Regression is a classification technique built on a linear model, i.e. it utilizes linear predictor functions for modeling the relationship between the outcome and the input variables. A particular application area for logistic regression is the medical sciences, e.g. cancer survival and trauma and injury severity scoring.

MLlib supports various logistic regression algorithms, of which we have used Limited-memory BFGS. That algorithm is based on Broyden–Fletcher–Goldfarb–Shanno (BFGS) algorithm, with techniques to reduce use of computer memory.

Decision Tree and Random Forest

Decision trees are used in many types of machine learning problems including multi-class classification. MLlib supports both the basic decision tree algorithm and ensembles of trees, which are composed of multiple tree models. MLlib provides two ensemble algorithms, Gradient-Boosted Trees and Random Forests. At the time of writing this article, MLlib Gradient-Boosted Trees did not support multi-class classification and therefore we only focused on Decision Tree and Random Forest.

Performance Metrics

A detailed discussion of MLlib performance metrics is given in the MLlib documentation. For the example application we use the following metrics.

Individual Label Statistics

Confusion Matrix

In multi-class classification, Confusion Matrix is a square matrix with as many rows and columns as the number of labels. Each entry of the matrix is a nonnegative integer and each row represents a label. In a given row, the total of numbers across the columns is equal to the number of instances of the particular label in the data set. Each column represents a 'predicted' label. In a given column, the total of numbers across the rows is equal to the number of times the particular label was predicted by the algorithm.

As an example consider the confusion matrix below. (This is the confusion matrix for the trained Random Forest model in the later section.)


 | Predicted label 0 "survival <= 36" | Predicted label 1 "36 < survival <= 72" | Predicted label 2 "72 < survival"
Actual label 0 "survival <= 36" | 358 | 64 | 11
Actual label 1 "36 < survival <= 72" | 58 | 282 | 396
Actual label 2 "72 < survival" | 19 | 200 | 553

Table 3. Confusion Matrix

Consider the row for label 0 "survival <= 36". The total number of instances for the label is 433 (= 358 + 64 + 11). The algorithm has predicted 358 instances of that label correctly. However, it incorrectly predicted 64 instances as label 1, i.e. "36 < survival <= 72", and 11 instances as label 2, i.e. "72 < survival".

Similarly, consider the column for predicted label 0. The algorithm has made 435 (= 358 + 58 + 19) predictions for label 0 "survival <= 36". Of those, 358 are correct and 77 (= 58 + 19) are wrong. Of the wrong predictions, 58 correspond to label 1 "36 < survival <= 72" and 19 correspond to label 2 "72 < survival".
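To make the row and column conventions concrete, the following sketch builds a confusion matrix from paired actual and predicted labels in plain Java. It is only an illustration: the class name and the small input arrays are hypothetical, and the article itself obtains the matrix from the MLlib metrics API shown later.

```java
import java.util.Arrays;

// Hypothetical illustration: rows are actual labels, columns are predicted labels.
final class ConfusionMatrixSketch {

    static long[][] build(int numLabels, int[] actual, int[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("arrays must have equal length");
        }
        long[][] matrix = new long[numLabels][numLabels];
        for (int i = 0; i < actual.length; i++) {
            // One more instance of label actual[i] that was predicted as predicted[i].
            matrix[actual[i]][predicted[i]]++;
        }
        return matrix;
    }

    public static void main(String[] args) {
        // Made-up labels for six records, used only to exercise the method.
        int[] actual    = {0, 0, 1, 2, 2, 1};
        int[] predicted = {0, 1, 1, 2, 0, 1};
        for (long[] row : build(3, actual, predicted)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

Each row sum equals the number of actual instances of that label, and each column sum equals the number of times that label was predicted.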

Precision and Recall

Precision of a label is the number of times the label is correctly predicted divided by the total number of times that label is predicted, whether correctly or not. The confusion matrix can be used to calculate the precision. In the above example, consider the column for predicted label 0 "survival <= 36". The precision for that label is 0.82 (= 358/435).

Recall of a label is the number of times the label is correctly predicted divided by the number of actual instances of the label. The confusion matrix can also be used to calculate the recall. In the above example, consider the row for label 0 "survival <= 36". The recall for that label is 0.83 (= 358/433).

Both precision and recall are numbers between 0 and 1. If both are close to 1, we consider the prediction of the label successful, although closeness is a relative term that depends on context. In the ideal case where precision and recall are both 1 for each label, the confusion matrix will have zeros in all non-diagonal entries.
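The two definitions can be checked against the Random Forest confusion matrix discussed above with a few lines of plain Java. This is a sketch with hypothetical class and method names; returning 0 when a label is never predicted or never occurs is an assumption made here, not something the article specifies.

```java
import java.util.Locale;

// Hypothetical illustration of per-label precision and recall.
final class PrecisionRecall {

    // Correct predictions of the label divided by all predictions of that label (column sum).
    static double precision(long[][] m, int label) {
        long predicted = 0;
        for (long[] row : m) {
            predicted += row[label];
        }
        return predicted == 0 ? 0.0 : (double) m[label][label] / predicted;
    }

    // Correct predictions of the label divided by all actual instances of it (row sum).
    static double recall(long[][] m, int label) {
        long actual = 0;
        for (long count : m[label]) {
            actual += count;
        }
        return actual == 0 ? 0.0 : (double) m[label][label] / actual;
    }

    public static void main(String[] args) {
        // Random Forest confusion matrix from the article (rows: actual, columns: predicted).
        long[][] m = {{358, 64, 11}, {58, 282, 396}, {19, 200, 553}};
        System.out.println(String.format(Locale.ROOT, "%.2f", precision(m, 0))); // prints 0.82
        System.out.println(String.format(Locale.ROOT, "%.2f", recall(m, 0)));    // prints 0.83
    }
}
```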

Overall Statistics

The concepts of precision and recall can be extended to overall statistics.

Weighted Precision is calculated by weighting the precision of each label by the number of instances of that label, summing the results, and finally dividing the sum by the total number of instances across all labels. Weighted Recall is calculated similarly.
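As a rough check of these definitions, the sketch below recomputes both overall statistics from the Random Forest confusion matrix reported in this article. The class and method names are hypothetical, and the zero-denominator handling is an assumption.

```java
import java.util.Locale;

// Hypothetical illustration of weighted precision and weighted recall.
final class WeightedMetrics {

    static double weightedPrecision(long[][] m) {
        return weighted(m, true);
    }

    static double weightedRecall(long[][] m) {
        return weighted(m, false);
    }

    private static double weighted(long[][] m, boolean usePrecision) {
        long total = 0;
        double weightedSum = 0.0;
        for (int label = 0; label < m.length; label++) {
            long actual = 0;    // row sum: actual instances of the label
            long predicted = 0; // column sum: times the label was predicted
            for (int j = 0; j < m.length; j++) {
                actual += m[label][j];
                predicted += m[j][label];
            }
            long denominator = usePrecision ? predicted : actual;
            double metric = denominator == 0 ? 0.0 : (double) m[label][label] / denominator;
            weightedSum += actual * metric; // weight by the number of actual instances
            total += actual;
        }
        return weightedSum / total;
    }

    public static void main(String[] args) {
        long[][] m = {{358, 64, 11}, {58, 282, 396}, {19, 200, 553}};
        System.out.println(String.format(Locale.ROOT, "%.2f", weightedPrecision(m))); // prints 0.61
        System.out.println(String.format(Locale.ROOT, "%.2f", weightedRecall(m)));    // prints 0.61
    }
}
```

Note that weighted recall reduces to the sum of the diagonal divided by the total number of instances, i.e. the fraction of records whose label was predicted correctly.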

Test Setup

For creating the data set, we filtered the SEER 1973-2012 (November 2014 Submission) colorectal cancer data file as follows.

  • Records that supplied 'unknown', 'not examined' or white space as the corresponding value for the considered features were eliminated.
  • To avoid skewing of results, we retained only those records with an observation period of at least 4 years. The observation period for every patient in the data file ends at the end of 2012. Therefore, records where the year of diagnosis is 2009 or later were eliminated.
  • Records for patients who expired for reasons other than colorectal cancer were eliminated to avoid skewing of survival period.

After the initial data filtering, we created a new data file by extracting the particular features as explained in Features table. All features are categorical, i.e. each feature takes discrete values belonging to a known set. If the value set for a feature had too many categorical values, we compacted the value set into a smaller one by grouping closely related values into a single value. The MLlib Decision Tree and Random Forest algorithms require that a categorical variable assume values in the form of 0, 1, 2, …. Therefore, we transformed value of each feature to one of the values in the set {0, 1, 2, …}.

The final data file has LIBSVM format, a commonly used format to represent data to be processed by machine learning algorithms. We used the MLlib utility API to parse SVM data files while running the programs.
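In the LIBSVM text format each line holds a label followed by index:value pairs, with one-based feature indices in ascending order. The sketch below writes one record in that form; the class name and the feature values are hypothetical and only illustrate the layout, not the article's actual preprocessing code.

```java
// Hypothetical illustration of writing one record in LIBSVM text format.
final class LibSvmLine {

    // features[i] is the integer-coded category of the i-th feature (0, 1, 2, ...).
    static String format(int label, int[] features) {
        StringBuilder line = new StringBuilder().append(label);
        for (int i = 0; i < features.length; i++) {
            // The format is sparse: zero-valued entries may be omitted and
            // are read back as 0. Indices in the file start at 1.
            if (features[i] != 0) {
                line.append(' ').append(i + 1).append(':').append(features[i]);
            }
        }
        return line.toString();
    }

    public static void main(String[] args) {
        // Label 1 with three made-up feature codes.
        System.out.println(format(1, new int[] {2, 0, 5})); // prints "1 1:2 3:5"
    }
}
```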

While training the models, we used the k-fold cross validation technique where data is split into 10 (k = 10 in our case) equally sized groups so that 9 of the groups are allocated for training and the remaining group is allocated for validation. The training and validation steps are repeated 10 times so that each of the data groups is used for validation once. MLlib also has an API for splitting data to be used in k-fold cross validation.
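The idea behind k-fold splitting can be sketched in plain Java, independently of Spark. The class below is hypothetical and works on record indices only; it is not how MLlib implements its splitting API, which the article uses later.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of k-fold splitting over record indices 0..n-1.
final class KFoldSketch {

    // Randomly assigns the indices to k folds whose sizes differ by at most one.
    static List<List<Integer>> split(int n, int k, long seed) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("need 2 <= k <= n");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));

        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < k; f++) {
            folds.add(new ArrayList<>());
        }
        for (int pos = 0; pos < n; pos++) {
            folds.get(pos % k).add(indices.get(pos));
        }
        return folds;
    }

    // Training set for one run: every fold except the validation fold.
    static List<Integer> trainingIndices(List<List<Integer>> folds, int validationFold) {
        List<Integer> training = new ArrayList<>();
        for (int f = 0; f < folds.size(); f++) {
            if (f != validationFold) {
                training.addAll(folds.get(f));
            }
        }
        return training;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = split(25, 10, 12345L);
        for (int run = 0; run < folds.size(); run++) {
            System.out.println("run " + run + ": validation=" + folds.get(run).size()
                + ", training=" + trainingIndices(folds, run).size());
        }
    }
}
```

Every record is used for validation exactly once across the k runs and for training in the remaining k - 1 runs.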

The tests were performed in a single server with a single-node Hadoop installation, version 2.7.1.

The Spark test programs were executed in a separate JVM. The data file resided in the Hadoop node.

For each of the classification algorithms we had a separate test run. Each test run processed the same data set. Each test program consisted of the following steps:

  • Initialize Spark configuration & context.
  • Load the data file from Hadoop and parse it.
  • Randomly split data for 10-fold cross validation. Then, repeat the next two steps 10 times for each of the training and validation sets.
  • Train the predictive model with the training data based on the particular algorithm.
  • Once model is created, obtain the performance metrics for both the training data and test data. To do that, in each case, process the particular data set and predict the label for each record. At the end, compare the predicted labels with actual labels and evaluate the performance using the metrics we discussed above.
  • Finally, pick the best model out of 10 based on performance.

MLlib API Code Review

The programs for Naïve Bayes, Multinomial Logistic Regression, Decision Tree and Random Forest are very similar and have the following common code.

Each program starts with creating a new Spark configuration and then a Spark context. Next, the data file from Hadoop is loaded and parsed. Then, data is split into training and test sets.

Those steps are shown below.

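// Imports used by this fragment
import org.apache.hadoop.fs.Path; // assumed type of the debug path
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

// Set application name, e.g. for Naive Bayes:
String appName = "NaiveBayesClassifier";
// For Multinomial Logistic Regression:
// String appName = "MultinomialLogisticClassifier";
// For Decision Tree:
// String appName = "DecisionTreeClassifier";
// For Random Forest:
// String appName = "RandomForestClassifier";

// Initialize Spark configuration & context
SparkConf conf = new SparkConf().setAppName(appName)
   .setMaster("local[1]").set("spark.executor.memory", "1g");
SparkContext sc = new SparkContext(conf);

// Load data file from Hadoop and parse.
String path = "hdfs://localhost:9000/user/konur/COLRECT_SVM.txt";
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

// Obtain 10 sets of training and test data.
Tuple2<RDD<LabeledPoint>,RDD<LabeledPoint>>[] myTuple =
   MLUtils.kFold(data.rdd(), 10, 12345, data.classTag());

// Train/validate the algorithm once for each set.
for(int i = 0; i < myTuple.length; i++){
   JavaRDD<LabeledPoint> trainingData =
      new JavaRDD<LabeledPoint>(myTuple[i]._1(), data.classTag());
   JavaRDD<LabeledPoint> testData =
      new JavaRDD<LabeledPoint>(myTuple[i]._2(), data.classTag());

   // Directory location to write the performance stats for this run
   String debugPathName = ...;

   // Train and evaluate on this fold via kRun(), which is discussed next.
   kRun(trainingData, testData, new Path(debugPathName));
}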


We finalize each program by stopping the Spark context.

sc.stop();

The implementation of kRun() is very similar for different algorithms. For Naive Bayes:

private static final void kRun(JavaRDD<LabeledPoint> trainingData, 
   JavaRDD<LabeledPoint> testData, Path dbg){

   // Train Naive Bayes model
   final NaiveBayesModel model = NaiveBayes.train(trainingData.rdd());

   // Obtain performance metrics and write into debug file. The dbg
   // variable is a path in file system to write results.
   debug(trainingData,testData,"Training Data","Test Data",model,dbg);
}

For Multinomial Logistic Regression we pass the number of labels to the algorithm via setNumClasses(). For Naive Bayes there is no need to define number of labels.

private static final void kRun(JavaRDD<LabeledPoint> trainingData, 
   JavaRDD<LabeledPoint> testData, Path dbg){
   // Train Logistic Regression Model with 3 classes (labels 0, 1 and 2)
   final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
      .setNumClasses(3).run(trainingData.rdd());

   // Obtain performance metrics and write into debug file
   debug(trainingData,testData,"Training Data","Test Data",model,dbg);
}

For Decision Tree and Random Forest, the kRun() method is slightly different. In both cases, the method starts as follows.

private static final void kRun(JavaRDD<LabeledPoint> trainingData, 
   JavaRDD<LabeledPoint> testData, Path dbg){
   // Set the number of classes.
   Integer numClasses = 3;

   // Empty categoricalFeaturesInfo indicates all features are continuous.
   // In our case all features are categorical. For better performance, 
   // we will enter #categorical values for each feature.
   HashMap<Integer, Integer> categoricalFeaturesInfo =
      new HashMap<Integer, Integer>();
   // Regional Nodes Positive: 6 categorical values.
   categoricalFeaturesInfo.put(0, 6);
   // CS Tumor Size: 9 categorical values.
   categoricalFeaturesInfo.put(1, 9);
   // Age Group: 16 categorical values.
   categoricalFeaturesInfo.put(2, 16);
   // Histologic Type ICD-O-3 code: 27 categorical values.
   categoricalFeaturesInfo.put(3, 27);
   // Derived AJCC Stage Group: 9 categorical values.
   categoricalFeaturesInfo.put(4, 9);
   // Summary Stage: 4 categorical values.
   categoricalFeaturesInfo.put(5, 4);
   // Derived Summary Stage: 6 categorical values.
   categoricalFeaturesInfo.put(6, 6);
   // CS Mets at Diagnosis: 12 categorical values.
   categoricalFeaturesInfo.put(7, 12);

   // Impurity measure used when splitting tree nodes
   String impurity = "gini";

   // Set depth of the tree
   Integer maxDepth = 30;
   // This must be at least the maximum number of categories for any 
   // categorical feature. 
   Integer maxBins = 300;

Notes regarding above code fragment:

  • Number of classes is defined explicitly.
  • A decision tree would accept both categorical and continuous features. The API would always accept a feature as continuous even if it is categorical. However, MLlib documentation states that better performance can be achieved by indicating to the algorithm which features are actually categorical. For that reason, we define a HashMap<Integer, Integer> object named categoricalFeaturesInfo, and populate it with the number of categorical values for every categorical feature.
  • "Impurity" is a measure used during decision making while creating the decision tree. MLlib supports two distinct impurity measures called Gini and Entropy. We did not observe any notable difference in performance between the two. The sample code above uses Gini as the impurity measure.
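For readers unfamiliar with the two impurity measures named in the notes above, the following sketch computes them for the class counts at a single tree node using the standard textbook definitions: Gini impurity is one minus the sum of squared class proportions, and entropy is the negative sum of each proportion times its base-2 logarithm. The class name and the counts are hypothetical, and this is not MLlib's internal code.

```java
// Hypothetical illustration of the Gini and entropy impurity measures.
final class Impurity {

    // Gini impurity: 1 - sum over classes of p^2.
    static double gini(long[] classCounts) {
        long total = sum(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double sumOfSquares = 0.0;
        for (long count : classCounts) {
            double p = (double) count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    // Entropy: - sum over classes of p * log2(p); empty classes contribute 0.
    static double entropy(long[] classCounts) {
        long total = sum(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double entropy = 0.0;
        for (long count : classCounts) {
            if (count > 0) {
                double p = (double) count / total;
                entropy -= p * (Math.log(p) / Math.log(2.0));
            }
        }
        return entropy;
    }

    private static long sum(long[] values) {
        long total = 0;
        for (long v : values) {
            total += v;
        }
        return total;
    }

    public static void main(String[] args) {
        long[] evenSplit = {5, 5}; // made-up node with two equally frequent classes
        System.out.println(gini(evenSplit));    // 0.5
        System.out.println(entropy(evenSplit)); // 1.0
    }
}
```

Both measures are 0 for a node that contains a single class and grow as the classes become more evenly mixed, which is why a split that lowers impurity is preferred.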

At this point, we can train the Decision Tree model.

   final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
      numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

For Random Forest, there are a few more steps needed.

   // Random seed for bootstrapping and choosing feature subsets.
   Integer seed = 12345;
   // Total number of trees in the forest
   Integer numTrees = 15;
   // This parameter determines the number of features to consider while 
   // partitioning the feature space at a tree node. The API supports all, 
   // sqrt, log2, onethird and auto. With all, we observed the best 
   // performance.
   String featureSubsetStrategy = "all"; 

Now, we can train the Random Forest model.

   final RandomForestModel model = RandomForest.trainClassifier(trainingData,
      numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy,
      impurity, maxDepth, maxBins, seed);

For both Decision Tree and Random Forest, kRun() finishes the same way as before.

   // Obtain performance metrics and write into debug file. The dbg
   // variable is a path in file system to write results.
   debug(trainingData,testData,"Training Data","Test Data",model,dbg);
}

Now that the models are trained, we can discuss the debug() method, which is common to all algorithms.

The debug() method executes the following code twice, once for trainingData and once for testData, via a utility method. In the following excerpts, the data object refers to either trainingData or testData.

First, define a Function that takes a LabeledPoint p, i.e. a particular row in data file, predicts the label, denoted by prediction, and returns a Tuple2 object that consists of the prediction (the predicted label) and the actual label. Next, the Function is passed to the map() method of the data object to obtain the predicted and the actual labels for all rows in the data file. The resulting data store is called predictionAndLabels.

JavaRDD<Tuple2<Object, Object>> predictionAndLabels = data
   .map(new Function<LabeledPoint, Tuple2<Object, Object>>() {
      public Tuple2<Object, Object> call(LabeledPoint p) {
         // Predicted label for this record, paired with the actual label
         Double prediction = model.predict(p.features());
         return new Tuple2<Object, Object>(prediction, p.label());
      }
   });

Then, process predictionAndLabels to get performance metrics. The metrics object below will calculate the confusion matrix, weighted precision and weighted recall.

MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());

Matrix confusion = metrics.confusionMatrix();
byte[] dt = ("\nConfusion matrix: \n" + confusion + "\n\n").getBytes();
out.write(dt, 0, dt.length);
dt = ("\nWeighted precision = " + metrics.weightedPrecision() + "\n").getBytes();
out.write(dt, 0, dt.length);
dt = ("Weighted recall = " + metrics.weightedRecall() + "\n").getBytes();
out.write(dt, 0, dt.length);

Next, we will use the metrics object to calculate precision and recall for each of the individual labels.

for (int i = 0; i < metrics.labels().length; i++) {
   dt = ("\nClass " + metrics.labels()[i] + " precision = " +
      metrics.precision(metrics.labels()[i]) + "\n").getBytes();
   out.write(dt, 0, dt.length);
   dt = ("Class " + metrics.labels()[i] + " recall = " + 
      metrics.recall(metrics.labels()[i]) + "\n").getBytes();
   out.write(dt, 0, dt.length);
}

Discussion of Results

For each algorithm, we selected the best model according to overall statistics (weighted recall and precision) of test data out of 10 cross-validation runs. Results on test data are summarized below where algorithms are sorted in increasing value of the overall statistics from left to right.

Metric | Multinomial Logistic Regression | Naïve Bayes | Decision Tree | Random Forest
Confusion matrix, row for actual label 0 | 301 50 82 | 252 74 67 | 372 57 27 | 358 64 11
Confusion matrix, row for actual label 1 | 38 104 594 | 48 209 458 | 67 254 390 | 58 282 396
Confusion matrix, row for actual label 2 | 8 79 685 | 17 144 606 | 29 186 511 | 19 200 553
Weighted precision | 0.56 | 0.57 | 0.59 | 0.61
Weighted recall | 0.56 | 0.57 | 0.60 | 0.61
Label_0 precision | 0.87 | 0.79 | 0.79 | 0.82
Label_0 recall | 0.70 | 0.64 | 0.82 | 0.83
Label_1 precision | 0.45 | 0.49 | 0.51 | 0.52
Label_1 recall | 0.14 | 0.29 | 0.36 | 0.38
Label_2 precision | 0.50 | 0.54 | 0.55 | 0.58
Label_2 recall | 0.89 | 0.79 | 0.70 | 0.72

Table 4. Results (Performance improves from left to right.)

Random Forest produced the best overall statistics and Decision Tree performed similarly. Multinomial Logistic Regression performed worst in the group.

For any model, predicting label 1 "36 months < survival <= 72 months" is far less successful than predicting the two opposite labels, i.e. "survival <= 36 months" and "72 months < survival" (labels 0 and 2, respectively). Because of that disparity, we would like to explore survival prediction with only 2 labels to see if we could get better results. This is discussed next.

Binary Classification

With only the 2 labels below, we consider a binary classification problem, a special case of multi-class classification.




Label | Description
0 | Patient survival is greater than or equal to 0 and less than 66 months after diagnosis.
1 | Patient survival is greater than or equal to 66 months after diagnosis.

Table 5. Labels

For the binary classification problem, a new data set is created according to the two new labels. We kept the same feature set as before. For Random Forest, Decision Tree and Multinomial Logistic Regression the number of classes is set to 2. Each of the algorithms is run using 10-fold cross validation to solve the problem.

We observed that for each algorithm the overall statistics (weighted recall and precision) are better for binary classification than those of multi-class classification with 3 labels. Decision Tree and Random Forest performed best with identical results, which are shown below (for simplicity, confusion matrices are not shown).

Weighted precision = 0.76

Weighted recall = 0.73

Label_0 precision = 0.85

Label_0 recall = 0.67

Label_1 precision = 0.63

Label_1 recall = 0.83

With respect to weighted recall, Decision Tree and Random Forest correctly predict the survival period in 73% of the cases, which is a significant improvement over multi-class classification with 3 labels. This suggests that the available inputs did not provide sufficient information to accurately solve a classification problem with more than 2 labels. For example, none of the inputs included patient treatment information. With such information provided as input, the algorithms might obtain more accurate results.


In this article, we discussed Naive Bayes, Multinomial Logistic Regression, Decision Tree and Random Forest algorithms from the Apache Spark Machine Learning Library (MLlib). To demonstrate the corresponding MLlib Java API we utilized an example where survival prediction of colorectal cancer is formulated as a multi-class classification problem. Then, those algorithms were used to solve the problem.

For each of those algorithms, we showed how the MLlib Java API can be used to construct the input data set, initialize the algorithm and evaluate the results.

The table below lists the algorithms from left to right in increasing order of execution time for training and prediction.

Algorithm | Naive Bayes | Multinomial Logistic Regression | Decision Tree | Random Forest
Execution time | 11 seconds | 21 seconds | 32 seconds | 230 seconds

Table 6. Execution Times

The results above should not be considered conclusive because our test environment employed a single-node Hadoop installation. In a multi-node cluster, Spark MLlib will partition data across nodes and process it in parallel. The observation above, based on a single node, should not be generalized without proper testing in a multi-node cluster.

About the Author

Konur Unyelioglu is a software architect with iCare, an Enterprise Cloud Electronic Health Record (EHR) company. He has experience in designing and implementing IT solutions for diverse industries including health care, automotive, telecommunications, retail and transportation. His current focus areas include health IT, enterprise Java and cloud technologies.
