Identify the Successful Completion of Weight-Lifting Exercises
Human Activity Recognition (HAR) data has become ubiquitous with the advent of devices like the Jawbone Up, Nike FuelBand, Fitbit and even smartphones. Although users of these devices tend to quantify how much of an activity they do, they rarely consider “how well” they perform it.
Results
Using RStudio, we import HAR data for weight-lifting exercises provided by a Groupware@LES study to see whether we can predict how well exercises were performed by users of HAR devices. After some initial data wrangling, we analyze the data, visualize the results, and develop a machine learning model. This deep-dive analysis walks through the typical steps for this type of machine learning project.
Conclusion
RStudio is an excellent platform for this type of analytical project and helps analysts tell a story with data. In combination with RMarkdown, RStudio readily facilitates analysis, sharing, and reproducibility. Analysts can weave narrative text and code together to seamlessly produce elegantly formatted documents in a variety of output formats (e.g. HTML, PDF, MS Word).
Implementation
With devices like the Jawbone Up, Nike FuelBand and Fitbit, it is feasible to inexpensively collect and analyze a large quantity of HAR data. The Groupware@LES HAR data covers six study participants who were asked to perform dumbbell exercises with correct form (Class A) and with the most common mistakes: throwing the elbows to the front (Class B), lifting the dumbbell only halfway (Class C), lowering the dumbbell only halfway (Class D) and throwing the hips to the front (Class E). Using this class-labeled HAR data, we leverage machine learning algorithms to develop a model that distinguishes participants who exercised correctly from those who did not, and identifies which mistakes they made. We then use this model to assess how accurately we can classify the exercises performed by other individuals based solely on their HAR data.
Technical Notes: This analysis is statically generated with RStudio 3.3.2 using RMarkdown 1.3 for reproducibility. The complete codebase is available on my GitHub page. Many code blocks are cached to speed up iterative analysis. When updating a code block, pay special attention to subsequent cached blocks that may need to be refreshed as a result. To learn more, have a peek at the RStudio/RMarkdown documentation and cheatsheets.
Data Processing
Getting the Data
Download the training and test data, if not already in the project data folder.
setwd("/home/rstudio/code")
trnFile = "data/pml-har-trn.csv"
tstFile = "data/pml-har-tst.csv"
if (!file.exists(trnFile)) {
  trnfileUrl <- "https://d396qusza40orc.cloudfront.net/predmachlearn/pml-training.csv"
  download.file(trnfileUrl, destfile = trnFile, method = "curl")
}
if (!file.exists(tstFile)) {
  tstfileUrl <- "https://d396qusza40orc.cloudfront.net/predmachlearn/pml-testing.csv"
  download.file(tstfileUrl, destfile = tstFile, method = "curl")
}
Data Wrangling and Preparation
The goal of data preparation for machine learning is to identify the variables that explain as much of the class variance as possible. In other words, we want the smallest set of variables that still captures what is going on. To improve model performance and simplify our analysis, we discard the variables not related to exercise performance.
# Load libraries
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
library(knitr)
library(corrplot)
library(doMC)
## Loading required package: foreach
## Loading required package: iterators
## Loading required package: parallel
library(ggplot2)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
# Define number of parallel processes
registerDoMC(cores = 8)
# Load the training file
trn.dat = read.csv(trnFile, na.strings=c("NA","NaN", "", " "), stringsAsFactors=FALSE); rm("trnFile")
# Quick exploratory analysis
dim.init <- dim(trn.dat); dim.init
## [1] 19622 160
str(trn.dat)
## 'data.frame': 19622 obs. of 160 variables:
## $ X : int 1 2 3 4 5 6 7 8 9 10 ...
## $ user_name : chr "carlitos" "carlitos" "carlitos" "carlitos" ...
## $ raw_timestamp_part_1 : int 1323084231 1323084231 1323084231 1323084232 1323084232 1323084232 1323084232 1323084232 1323084232 1323084232 ...
## $ raw_timestamp_part_2 : int 788290 808298 820366 120339 196328 304277 368296 440390 484323 484434 ...
## $ cvtd_timestamp : chr "05/12/2011 11:23" "05/12/2011 11:23" "05/12/2011 11:23" "05/12/2011 11:23" ...
## $ new_window : chr "no" "no" "no" "no" ...
## $ num_window : int 11 11 11 12 12 12 12 12 12 12 ...
## $ roll_belt : num 1.41 1.41 1.42 1.48 1.48 1.45 1.42 1.42 1.43 1.45 ...
## $ pitch_belt : num 8.07 8.07 8.07 8.05 8.07 8.06 8.09 8.13 8.16 8.17 ...
## $ yaw_belt : num -94.4 -94.4 -94.4 -94.4 -94.4 -94.4 -94.4 -94.4 -94.4 -94.4 ...
## $ total_accel_belt : int 3 3 3 3 3 3 3 3 3 3 ...
## $ kurtosis_roll_belt : chr NA NA NA NA ...
## $ kurtosis_picth_belt : chr NA NA NA NA ...
## $ kurtosis_yaw_belt : chr NA NA NA NA ...
## $ skewness_roll_belt : chr NA NA NA NA ...
## $ skewness_roll_belt.1 : chr NA NA NA NA ...
## $ skewness_yaw_belt : chr NA NA NA NA ...
## $ max_roll_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ max_picth_belt : int NA NA NA NA NA NA NA NA NA NA ...
## $ max_yaw_belt : chr NA NA NA NA ...
## $ min_roll_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ min_pitch_belt : int NA NA NA NA NA NA NA NA NA NA ...
## $ min_yaw_belt : chr NA NA NA NA ...
## $ amplitude_roll_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ amplitude_pitch_belt : int NA NA NA NA NA NA NA NA NA NA ...
## $ amplitude_yaw_belt : chr NA NA NA NA ...
## $ var_total_accel_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ avg_roll_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ stddev_roll_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ var_roll_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ avg_pitch_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ stddev_pitch_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ var_pitch_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ avg_yaw_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ stddev_yaw_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ var_yaw_belt : num NA NA NA NA NA NA NA NA NA NA ...
## $ gyros_belt_x : num 0 0.02 0 0.02 0.02 0.02 0.02 0.02 0.02 0.03 ...
## $ gyros_belt_y : num 0 0 0 0 0.02 0 0 0 0 0 ...
## $ gyros_belt_z : num -0.02 -0.02 -0.02 -0.03 -0.02 -0.02 -0.02 -0.02 -0.02 0 ...
## $ accel_belt_x : int -21 -22 -20 -22 -21 -21 -22 -22 -20 -21 ...
## $ accel_belt_y : int 4 4 5 3 2 4 3 4 2 4 ...
## $ accel_belt_z : int 22 22 23 21 24 21 21 21 24 22 ...
## $ magnet_belt_x : int -3 -7 -2 -6 -6 0 -4 -2 1 -3 ...
## $ magnet_belt_y : int 599 608 600 604 600 603 599 603 602 609 ...
## $ magnet_belt_z : int -313 -311 -305 -310 -302 -312 -311 -313 -312 -308 ...
## $ roll_arm : num -128 -128 -128 -128 -128 -128 -128 -128 -128 -128 ...
## $ pitch_arm : num 22.5 22.5 22.5 22.1 22.1 22 21.9 21.8 21.7 21.6 ...
## $ yaw_arm : num -161 -161 -161 -161 -161 -161 -161 -161 -161 -161 ...
## $ total_accel_arm : int 34 34 34 34 34 34 34 34 34 34 ...
## $ var_accel_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ avg_roll_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ stddev_roll_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ var_roll_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ avg_pitch_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ stddev_pitch_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ var_pitch_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ avg_yaw_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ stddev_yaw_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ var_yaw_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ gyros_arm_x : num 0 0.02 0.02 0.02 0 0.02 0 0.02 0.02 0.02 ...
## $ gyros_arm_y : num 0 -0.02 -0.02 -0.03 -0.03 -0.03 -0.03 -0.02 -0.03 -0.03 ...
## $ gyros_arm_z : num -0.02 -0.02 -0.02 0.02 0 0 0 0 -0.02 -0.02 ...
## $ accel_arm_x : int -288 -290 -289 -289 -289 -289 -289 -289 -288 -288 ...
## $ accel_arm_y : int 109 110 110 111 111 111 111 111 109 110 ...
## $ accel_arm_z : int -123 -125 -126 -123 -123 -122 -125 -124 -122 -124 ...
## $ magnet_arm_x : int -368 -369 -368 -372 -374 -369 -373 -372 -369 -376 ...
## $ magnet_arm_y : int 337 337 344 344 337 342 336 338 341 334 ...
## $ magnet_arm_z : int 516 513 513 512 506 513 509 510 518 516 ...
## $ kurtosis_roll_arm : chr NA NA NA NA ...
## $ kurtosis_picth_arm : chr NA NA NA NA ...
## $ kurtosis_yaw_arm : chr NA NA NA NA ...
## $ skewness_roll_arm : chr NA NA NA NA ...
## $ skewness_pitch_arm : chr NA NA NA NA ...
## $ skewness_yaw_arm : chr NA NA NA NA ...
## $ max_roll_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ max_picth_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ max_yaw_arm : int NA NA NA NA NA NA NA NA NA NA ...
## $ min_roll_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ min_pitch_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ min_yaw_arm : int NA NA NA NA NA NA NA NA NA NA ...
## $ amplitude_roll_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ amplitude_pitch_arm : num NA NA NA NA NA NA NA NA NA NA ...
## $ amplitude_yaw_arm : int NA NA NA NA NA NA NA NA NA NA ...
## $ roll_dumbbell : num 13.1 13.1 12.9 13.4 13.4 ...
## $ pitch_dumbbell : num -70.5 -70.6 -70.3 -70.4 -70.4 ...
## $ yaw_dumbbell : num -84.9 -84.7 -85.1 -84.9 -84.9 ...
## $ kurtosis_roll_dumbbell : chr NA NA NA NA ...
## $ kurtosis_picth_dumbbell : chr NA NA NA NA ...
## $ kurtosis_yaw_dumbbell : chr NA NA NA NA ...
## $ skewness_roll_dumbbell : chr NA NA NA NA ...
## $ skewness_pitch_dumbbell : chr NA NA NA NA ...
## $ skewness_yaw_dumbbell : chr NA NA NA NA ...
## $ max_roll_dumbbell : num NA NA NA NA NA NA NA NA NA NA ...
## $ max_picth_dumbbell : num NA NA NA NA NA NA NA NA NA NA ...
## $ max_yaw_dumbbell : chr NA NA NA NA ...
## $ min_roll_dumbbell : num NA NA NA NA NA NA NA NA NA NA ...
## $ min_pitch_dumbbell : num NA NA NA NA NA NA NA NA NA NA ...
## $ min_yaw_dumbbell : chr NA NA NA NA ...
## $ amplitude_roll_dumbbell : num NA NA NA NA NA NA NA NA NA NA ...
## [list output truncated]
## Removing unnecessary covariates
# Columns not related to exercise performance.
col.nrel <- c("X", "user_name", "raw_timestamp_part_1", "raw_timestamp_part_2",
"cvtd_timestamp", "new_window", "num_window")
# Columns with near zero variance
nsv <- nearZeroVar(trn.dat, saveMetrics = TRUE)
col.nzvs <- rownames(nsv[nsv$nzv == TRUE, ])
# Update the training set with the remaining fields
col.drops <- c(col.nrel, col.nzvs)
trn.dat.all <- trn.dat[,!(names(trn.dat) %in% col.drops)]
# Identify fields with large quantity of missing values.
col.names = colnames(trn.dat.all)
cnt.min = sum(complete.cases(trn.dat.all))
trn.cnt.df <- data.frame(Field=character(), FieldCnt=integer(), stringsAsFactors=FALSE)
# Create data-frame with field names and counts.
for (fld in 1:length(col.names)) {
trn.cnt.df[fld,] <- c(col.names[fld], sum(!is.na(trn.dat.all[[col.names[fld]]])))
}
# Filter out the fields with a high quantity of missing values
# (keep only the fields that are populated for every observation).
trn.cnt.df$FieldCnt <- as.numeric(trn.cnt.df$FieldCnt)
col.keep <- as.vector(subset(trn.cnt.df, FieldCnt == nrow(trn.dat.all))$Field)
# Update the training set with the remaining fields
trn.dat.all <- trn.dat.all[col.keep]
## Cleanup and default model settings
# Set the classe variable as a factor.
trn.dat.all$classe <- as.factor(trn.dat.all$classe)
dim.finl <- dim(trn.dat.all); dim.finl
## [1] 19622 53
# house-cleaning
rm("cnt.min", "col.drops", "col.keep", "col.names", "col.nrel", "col.nzvs", "fld", "trn.cnt.df", "nsv","trn.dat")
# Cutoff parameters
corrRt = 0.80 # Correlation Cutoff
dSplit = 0.60 # Training Cutoff
## Set model parameters for model tuning and training
cntFld = 10 # Number of cross-validation folds
cntRpt = 8 # Increase the Repeat count for improved cross-validation accuracy
cntTun = 5 # Parameter for tuning accuracy vs resource consumption
Our exploratory analysis shows that the Groupware@LES file is just shy of twenty thousand records and that some of the variables have near-zero variance or consist largely of missing values. After discarding them, our variable count drops from 160 down to 53. The class variable in the provided data is relatively balanced, which simplifies this analysis somewhat.
kable(t(table(trn.dat.all$classe)), caption = "Class Variable Frequency:")
A | B | C | D | E |
---|---|---|---|---|
5580 | 3797 | 3422 | 3216 | 3607 |
Handling Multicollinearity
Multicollinearity occurs when model variables are highly correlated not only with the response variable but also with each other. This inflates the standard errors, which can make variables appear statistically insignificant when they should be significant. By calculating the Variance Inflation Factor (VIF), we can quantify the severity of this effect. We implemented a stepwise VIF function that removes multicollinear variables until all remaining variables fall below a specified threshold. A VIF greater than 10 is generally considered highly collinear; the threshold may be set to the tolerance required for your analysis.
We can reduce the number of predictors to a smaller set of uncorrelated components using Partial Least Squares Regression (PLS) or Principal Components Analysis (PCA), or simply remove them from the model. Removing predictors may introduce some bias (the difference between the expected prediction and the truth), but may also reduce the prediction variance (increase accuracy). The goal is to optimize the trade-off between bias and variance. Either option reduces our variable count and speeds up model calculation.
The goal of PCA is to explain the maximum amount of variance with the fewest number of principal components. PCA creates these components by transforming the variables into a smaller sub-space of variables (dimensionality reduction) that are uncorrelated with each other. A challenge with the components is that it can be difficult to explain what is driving model performance. Using the Caret R package, we apply a Box-Cox Transformation to correct for skewness, center and scale each variable and then apply PCA in one call to pre-process the variables.
We are going to approach this three different ways to get a sense of effort and impact:
- remove the multicollinear predictors,
- apply PCA to the multicollinear predictors, and
- apply PCA to all the predictors.
# Compile stepwise VIF selection function for reducing collinearity among explanatory variables
vif_func <- function(in_frame, thresh = 10, trace = TRUE, ...) {
  require(fmsb)
  if (!is.data.frame(in_frame)) in_frame <- data.frame(in_frame)
  # Get initial VIF value for all comparisons of variables
  vif_init <- NULL
  var_names <- names(in_frame)
  for (val in var_names) {
    regressors <- var_names[-which(var_names == val)]
    form <- paste(regressors, collapse = '+')
    form_in <- formula(paste(val, '~', form))
    vif_init <- rbind(vif_init, c(val, VIF(lm(form_in, data = in_frame, ...))))
  }
  vif_max <- max(as.numeric(vif_init[, 2]))
  if (vif_max < thresh) {
    if (trace) { # print output of each iteration
      prmatrix(vif_init, collab = c('var', 'vif'), rowlab = rep('', nrow(vif_init)), quote = FALSE)
      cat('\n')
      cat(paste('All variables have VIF < ', thresh, ', max VIF ', round(vif_max, 2), sep = ''), '\n\n')
    }
    return(var_names)
  } else {
    in_dat <- in_frame
    # Backwards selection of explanatory variables; stops when all VIF values are below 'thresh'
    while (vif_max >= thresh) {
      vif_vals <- NULL
      var_names <- names(in_dat)
      for (val in var_names) {
        regressors <- var_names[-which(var_names == val)]
        form <- paste(regressors, collapse = '+')
        form_in <- formula(paste(val, '~', form))
        vif_add <- VIF(lm(form_in, data = in_dat, ...))
        vif_vals <- rbind(vif_vals, c(val, vif_add))
      }
      max_row <- which(vif_vals[, 2] == max(as.numeric(vif_vals[, 2])))[1]
      vif_max <- as.numeric(vif_vals[max_row, 2])
      if (vif_max < thresh) break
      if (trace) { # print output of each iteration
        prmatrix(vif_vals, collab = c('var', 'vif'), rowlab = rep('', nrow(vif_vals)), quote = FALSE)
        cat('\n')
        cat('removed: ', vif_vals[max_row, 1], vif_max, '\n\n')
        flush.console()
      }
      in_dat <- in_dat[, !names(in_dat) %in% vif_vals[max_row, 1]]
    }
    return(names(in_dat))
  }
}
Select the tabs below to switch between the plots, which highlight multicollinearity across variables in each approach. Review the summary below each plot to see the change in variable count and understand the impact on multicollinearity overall.
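The code behind those tabbed plots is not reproduced here. As a rough sketch of how the candidate datasets might be assembled (the object names pred.cols, col.vif, trn.dat.vif and prep.pca are assumptions; only trn.dat.pca is referenced in the modeling code below):
# Rough sketch (assumed object names); not the original tabbed code.
pred.cols <- setdiff(names(trn.dat.all), "classe")
# Correlation matrix used for the correlation plots
corrMatx <- cor(trn.dat.all[, pred.cols])
# Approach 1: drop the multicollinear predictors flagged by the stepwise VIF function
col.vif     <- vif_func(in_frame = trn.dat.all[, pred.cols], thresh = 10, trace = FALSE)
trn.dat.vif <- trn.dat.all[, c(col.vif, "classe")]
# Approach 2 (PCA applied only to the multicollinear predictors) follows the same pattern and is omitted here.
# Approach 3: Box-Cox, center, scale, and PCA applied to all predictors in a single preProcess call
prep.pca    <- preProcess(trn.dat.all[, pred.cols], method = c("BoxCox", "center", "scale", "pca"))
trn.dat.pca <- data.frame(predict(prep.pca, trn.dat.all[, pred.cols]), classe = trn.dat.all$classe)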
Machine Learning
Tuning Preparation
We are using a 60% data split to define our training set. To provide a meaningful estimation of performance, the remaining 40% will be divided equally to define our validation and test sets. The individual models are tuned with the training set, and assessed against the validation set. The test set is for estimating the ensemble model performance.
With the Caret R package, we use repeated k-fold cross-validation with 10 folds across 8 repetitions to improve our accuracy estimates. We preprocess the data by centering and scaling each predictor to better expose the underlying structure and relationships. We pre-calculate a vector of seeds for reproducibility across multiple parallel model runs.
# Select the PCA dataset for training.
trn.dat <- trn.dat.pca
# Partition the data with training, validation, and test sets.
set.seed(1732)
inTrain <- createDataPartition(y=trn.dat$classe, p=dSplit, list=FALSE)
trn.trn <- trn.dat[ inTrain,]
trn.vld <- trn.dat[-inTrain,]
inValid <- createDataPartition(y=trn.vld$classe, p=0.50, list=FALSE)
trn.tst <- trn.vld[-inValid,]
trn.vld <- trn.vld[ inValid,]
# Setup seeds for running fully reproducible model in parallel mode
ivect = cntRpt*cntFld # quantity of integer vectors
seeds <- vector(mode = "list", length = ivect+1) # length equals the integer vectors plus one
for(i in 1:ivect) seeds[[i]] <- sample.int(n=1000, cntTun*3) # seeds for tuning attempts (*3 for Radial SVM)
seeds[[length(seeds)]] <- sample.int(1000, 1) # seed for the final model
# Create cross-validation folds to use for model tuning
myMetr <- "Kappa" # ROC, Kappa, Accuracy
myFlds <- createMultiFolds(y=trn.trn$classe, k = cntFld, times = cntRpt)
myCtrl <- trainControl(method = "repeatedCV", seeds = seeds, allowParallel = TRUE,
#savePredictions=TRUE, classProbs=TRUE, summaryFunction=multiClassSummary,
number = cntFld, repeats = cntRpt)
# Standardize the training subset.
myPrep = c("center", "scale") # pre-process with center and scaling
# house-cleaning
rm("i", "cntFld", "cntRpt", "inTrain", "corrMatx", "vif_func")
Grid Searching
There are two ways to tune an algorithm using the Caret R package. The first is to let the system do it automatically by setting the tuneLength parameter, which indicates the number of different values to try for each algorithm parameter. This makes a crude guess at what values to try, but can often get you within a reasonable range; it is the approach we take on our first tuning pass. The second approach involves manually setting the tuning grid to search, which is what our second tuning pass uses. We start with the best performing parameters from the first pass and create a more focused tuning grid to further refine them.
We have selected three models to tune:
- Support Vector Machine (SVM) with a Radial Basis Function (RBF) kernel,
- Random Forest, and
- Stochastic Gradient Boosting Machine (GBM).
To improve SVM performance we doubled its grid-search area during its first pass. Also, since the GBM algorithm is computationally expensive, we did significant manual tuning during its second pass to attain reasonable performance. For good measure, we created an Ensemble (Stack) model using the three separate algorithms to improve overall performance.
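The tuning calls themselves are not reproduced in this section. As a minimal sketch of the two-pass approach for the radial SVM, assuming the second-pass object is named tune.svr2 (as referenced later) and that its grid is centered on the best first-pass parameters:
# Minimal sketch of the two-pass SVM tuning (assumed names; grid multipliers are illustrative).
set.seed(1732)
# Pass 1: let caret pick a coarse grid; the SVM search area is doubled via tuneLength.
tune.svr1 <- train(classe ~ ., data = trn.trn, method = "svmRadial",
                   metric = myMetr, trControl = myCtrl, preProcess = myPrep,
                   tuneLength = cntTun * 2)
# Pass 2: search a manual grid focused around the best parameters from pass 1.
svmGrid <- expand.grid(sigma = tune.svr1$bestTune$sigma * c(0.75, 1.00, 1.25),
                       C     = tune.svr1$bestTune$C     * c(0.50, 1.00, 2.00))
set.seed(1732)
tune.svr2 <- train(classe ~ ., data = trn.trn, method = "svmRadial",
                   metric = myMetr, trControl = myCtrl, preProcess = myPrep,
                   tuneGrid = svmGrid)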
Technical Notes: Since this process can be very resource intensive, we use OpenBLAS, an optimized replacement for the standard BLAS library used by R. In addition, we use the doMC library to set the number of cores, which corresponds to the maximum number of R jobs to run in parallel. Depending on your tuning parameters and training set size, this process can easily consume all of your computer's memory and may take hours to run. Running the machine learning on an H2O.ai cluster helps to alleviate these challenges. For more details, select the “Distributed Random Forest” tab below.
Tuning Model Accuracy
After tuning these models we can identify the parameters that generate the best performance per model and compare how well the models perform against each other.
Fitting the Models
Now that we have the tuned parameters, we select and fit the final models; this is fast compared to the tuning steps. The code below extracts the training error estimates and reduces the random forest's tree count based on OOB error, and a sketch of the final fitting calls follows.
# Retrieve the training error for the SVM model
err.svr <- tune.svr2$finalModel@error; err.svr
## [1] 0.003226902
# Retrieve the OOB error for the Random Forest model based on number of trees used
ntr <- tune.rft2$finalModel$ntree; ntr
## [1] 500
err.rft <- tune.rft2$finalModel$err.rate[ntr,1]; err.rft
## OOB
## 0.0303159
# Identify the lowest error rate by number of trees
treerr.df <- as.data.frame(tune.rft2$finalModel$err.rate[,1])
treerr.df$trees <- row(treerr.df)
colnames(treerr.df) <- c("err","trees")
# Reduce the number of trees for the Random Forest based on error rate
ntr = as.vector(head(treerr.df[with(treerr.df, order(err)), ], 1)$trees); ntr
## [1] 374
err.rft <- tune.rft2$finalModel$err.rate[ntr, 1]; err.rft
## OOB
## 0.03014606
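The final fitting calls are not shown above. As a minimal sketch, assuming each final model is refit on the training set with its bestTune parameters (tune.gbm2 is an assumed object name, by analogy with tune.svr2 and tune.rft2), and using the reduced tree count for the random forest:
# Minimal sketch of the final fits (assumed names); each model reuses its tuned parameters.
set.seed(1732)
fit.svr <- train(classe ~ ., data = trn.trn, method = "svmRadial", preProcess = myPrep,
                 metric = myMetr, trControl = myCtrl, tuneGrid = tune.svr2$bestTune)
fit.rft <- train(classe ~ ., data = trn.trn, method = "rf", ntree = ntr, preProcess = myPrep,
                 metric = myMetr, trControl = myCtrl, tuneGrid = tune.rft2$bestTune)
fit.gbm <- train(classe ~ ., data = trn.trn, method = "gbm", verbose = FALSE, preProcess = myPrep,
                 metric = myMetr, trControl = myCtrl, tuneGrid = tune.gbm2$bestTune)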
Separate Model Analysis & Validation
During this step we run the tuned models against our validation set to assess performance. The performance statistics indicate reasonable accuracy and suggest that we have not overfit the models.
# Assess the separate model accuracy on the validation set.
yhat.svr <- predict(fit.svr, trn.vld)
## Loading required package: kernlab
##
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
##
## alpha
yhat.rft <- predict(fit.rft, trn.vld)
yhat.gbm <- predict(fit.gbm, trn.vld)
## Loading required package: gbm
## Loading required package: survival
##
## Attaching package: 'survival'
## The following object is masked from 'package:caret':
##
## cluster
## Loading required package: splines
## Loaded gbm 2.1.1
## Loading required package: plyr
cm.svr <- confusionMatrix(yhat.svr, trn.vld$classe)
cm.rft <- confusionMatrix(yhat.rft, trn.vld$classe)
cm.gbm <- confusionMatrix(yhat.gbm, trn.vld$classe)
Ensemble Model Analysis & Comparison
Although the separate models perform well individually, we built this ensemble model to highlight the process and the potential benefit to predictive performance. Different algorithms will find some variables more predictive than others, and will rank them differently. This shows up in the variables of importance (see below). To get additional lift, it is important that the models being stacked have varying variables of importance, because the ensemble model will capitalize on those differences to boost predictive performance.
After applying the separate models to our validation set, we use their predictions as “modelled predictors” to train the ensemble model. We then assess the ensemble's performance against our test set. This “stacked” model should have improved predictive performance over the separate models.
set.seed(1732)
# Dataframe with predictions of the separate models, combined with the classe variable (from the validation set)
# in preparation for training the ensemble model against the "validation" data.
pred.vld <- data.frame(yhat.svr, yhat.rft, yhat.gbm, classe=trn.vld$classe)
# Train the ensemble model on the "stacked" predictors
# Note: since it was the fastest of the trio, the SVM algorithm is used to build the ensemble
fit.stack <- train(classe ~ ., data=pred.vld, method="svmRadial",
metric = myMetr, trControl = myCtrl)
# Compare the ensemble model training performance to the separate models.
rbind(getTrainPerf(fit.stack),
getTrainPerf(fit.svr),
getTrainPerf(fit.rft),
getTrainPerf(fit.gbm))
## TrainAccuracy TrainKappa method
## 1 0.9862357 0.9825861 svmRadial
## 2 0.9824536 0.9778050 svmRadial
## 3 0.9674126 0.9587721 rf
## 4 0.9650669 0.9558121 gbm
# To test the ensemble performance we create "stacked" predictors based
# on the test set for ensemble prediction assessment.
yhat.svr <- predict(fit.svr, trn.tst)
yhat.rft <- predict(fit.rft, trn.tst)
yhat.gbm <- predict(fit.gbm, trn.tst)
# Create "stacked" dataframe of modelled predicters and the classe variable (from the test set)
# in preparation for training the ensemble model against the "test" data.
pred.tst <- data.frame(yhat.svr, yhat.rft, yhat.gbm, classe=trn.tst$classe)
# Predict the model classe variable using the "stacked" predictors
yhat.stk <- predict(fit.stack, pred.tst)
# Compare the predicted class against the actual class variable.
cm.stk <- confusionMatrix(yhat.stk, trn.tst$classe)
str(cm.stk$byClass)
## num [1:5, 1:11] 0.988 0.989 0.972 0.972 0.993 ...
## - attr(*, "dimnames")=List of 2
## ..$ : chr [1:5] "Class: A" "Class: B" "Class: C" "Class: D" ...
## ..$ : chr [1:11] "Sensitivity" "Specificity" "Pos Pred Value" "Neg Pred Value" ...
Conclusion
Our statistics, including accuracy, precision and recall, indicate that the ensemble model is well generalized and not overfit to the training data. The ensemble model’s estimated training accuracy (98.62%) shows a 0.38% predictive improvement over the SVM model (98.25%).
Let’s see how the ensemble performs compared to the separate models against the test data (a sketch showing how these figures are extracted follows the list):
- Ensemble: 98.39%
- SVM-Radial: 98.34%
- Random Forest: 97.43%
- Gradient Boost: 97.12%
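These test-set accuracies can be read straight off the confusion matrices; a minimal sketch using the test-set predictions generated above (acc.tst is an illustrative name):
# Minimal sketch: pull the overall test-set accuracy for each model from its confusion matrix.
acc.tst <- c(Ensemble       = unname(confusionMatrix(yhat.stk, trn.tst$classe)$overall["Accuracy"]),
             SVM.Radial     = unname(confusionMatrix(yhat.svr, trn.tst$classe)$overall["Accuracy"]),
             Random.Forest  = unname(confusionMatrix(yhat.rft, trn.tst$classe)$overall["Accuracy"]),
             Gradient.Boost = unname(confusionMatrix(yhat.gbm, trn.tst$classe)$overall["Accuracy"]))
round(acc.tst * 100, 2)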
In addition to the improved estimated mean performance, the ensemble model has estimated per-class performance that is equal-to or better than the separate models:
# Assess how the ensemble gets higher accuracy
# Compare per-class F1 scores (the harmonic mean of precision and recall) across the models
pred.perf <- rbind(confusionMatrix(yhat.svr, trn.tst$classe)$byClass[, 7],
confusionMatrix(yhat.rft, trn.tst$classe)$byClass[, 7],
confusionMatrix(yhat.gbm, trn.tst$classe)$byClass[, 7],
confusionMatrix(yhat.stk, trn.tst$classe)$byClass[, 7])
row.names(pred.perf) <- c("SVM.Radial", "Random.Forest", "Gradient.Boost", "Ensemble.Stack")
# Class Summary
kable(pred.perf, caption = "Per-class F1 Score (Harmonic Mean of Precision and Recall):")
Model | Class: A | Class: B | Class: C | Class: D | Class: E |
---|---|---|---|---|---|
SVM.Radial | 0.9901257 | 0.9810581 | 0.9700948 | 0.9772549 | 0.9937543 |
Random.Forest | 0.9839429 | 0.9648308 | 0.9561467 | 0.9693637 | 0.9909281 |
Gradient.Boost | 0.9829443 | 0.9606815 | 0.9486623 | 0.9702194 | 0.9867411 |
Ensemble.Stack | 0.9910153 | 0.9810581 | 0.9700948 | 0.9788567 | 0.9937543 |
Technical Notes: Receiver Operating Characteristic (ROC) curve analysis was not designed for multi-class problems; however, using the pROC library, we calculate a mean area under the curve (AUC) of 99.37% on the test data (a multi-class AUC has no single curve to plot). To learn more about ROC curve analysis, have a look at this BioMed Central (BMC) research paper: “Xavier Robin, Natacha Turck, Alexandre Hainard, Natalia Tiberti, Frédérique Lisacek, Jean-Charles Sanchez and Markus Müller (2011). pROC: an open-source package for R and S+ to analyze and compare ROC curves. BMC Bioinformatics, 12, p. 77. DOI: 10.1186/1471-2105-12-77”.
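The AUC calculation itself is not shown above. As a minimal sketch of how it might be computed with pROC's multiclass.roc(), using the discrete class predictions coded as integers (a common approximation when class probabilities are not retained):
# Minimal sketch: mean multi-class AUC on the test set via pROC.
# multiclass.roc() averages the pairwise AUCs, so there is no single curve to plot.
roc.multi <- multiclass.roc(response  = trn.tst$classe,
                            predictor = as.numeric(yhat.stk))
auc(roc.multi)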
The code and content in this post are licensed by Roberto Rivera to the public under a Creative Commons Attribution 4.0 License.