Wednesday, May 15, 2013


The default plot function in glmnet (`plot.glmnet()`) is quite basic and plain. The following function tries to improve on its plots.
# This function serves the same purpose as plot.glmnet(),
# with the following enhancements:

# 1) label each coefficient path with its variable name instead of its column number
# 2) show the value of lambda that gives the smallest cross-validation error with a solid line
# 3) show the largest lambda whose error is within one standard error (1 SE) of the minimal cross-validation error with a dashed line
 
# Input variables: a glmnet object from glmnet(), a cv.glmnet object from cv.glmnet(), and an optional plot title
 
# Plot glmnet coefficient paths against log(lambda) using ggplot2.
#
# Args:
#   glmnetObj: object inheriting from "glmnet"; its `beta` matrix
#              (one row per predictor, one column per lambda) and its
#              `lambda` vector are plotted.
#   cvObj:     cross-validation result (e.g. from cv.glmnet()) that must
#              provide a numeric `lambda.min`, drawn as a solid vertical
#              line. If it also provides `lambda.1se`, that is drawn as a
#              dashed vertical line.
#   title1:    optional plot title.
#
# Returns the ggplot object invisibly, after printing it.
plot_glmnet_beta <- function(glmnetObj, cvObj, title1 = "") {
  stopifnot(inherits(glmnetObj, "glmnet"))
  stopifnot(is.numeric(cvObj$lambda.min))

  b1 <- as.matrix(glmnetObj$beta)
  l1 <- log(glmnetObj$lambda)
  stopifnot(length(l1) > 0, length(l1) == ncol(b1))

  # Fall back to generic names if beta carries no row names.
  predictors <- row.names(b1)
  if (is.null(predictors)) {
    predictors <- paste0("V", seq_len(nrow(b1)))
  }

  # Pad the x-axis on the left by 10% of the log(lambda) range to leave
  # room for the predictor labels drawn at the left edge.
  max1 <- max(l1)
  span <- max1 - min(l1)
  if (span == 0) {
    span <- 1  # single lambda: avoid a zero-width axis
  }
  min1 <- min(l1) - span / 10

  # Label each path at its coefficient for the smallest lambda, which is
  # the left end of the plotted path.
  label.data <- data.frame(var = predictors,
                           value = b1[, which.min(l1)],
                           stringsAsFactors = FALSE)

  # Long format: one row per (predictor, lambda) pair. as.vector() reads
  # the matrix column by column, i.e. one lambda at a time.
  long.data <- data.frame(coefficient = as.vector(b1),
                          logLambda = rep(l1, each = nrow(b1)),
                          predictor = rep(predictors, times = ncol(b1)),
                          stringsAsFactors = FALSE)

  p1 <- ggplot2::ggplot(long.data,
                        ggplot2::aes(x = logLambda, y = coefficient)) +
    ggplot2::geom_line(ggplot2::aes(group = predictor, color = predictor)) +
    ggplot2::theme_bw() +
    ggplot2::scale_x_continuous(limits = c(min1, max1)) +
    ggplot2::theme(legend.position = "none") +
    ggplot2::ggtitle(title1)

  # Solid line: lambda with minimal cross-validation error.
  p2 <- p1 +
    ggplot2::geom_text(data = label.data,
                       ggplot2::aes(x = -Inf, y = value, label = var),
                       size = 2, hjust = 0) +
    ggplot2::geom_vline(xintercept = log(cvObj$lambda.min))

  # Dashed line: largest lambda within 1 SE of the minimal error.
  if (!is.null(cvObj$lambda.1se)) {
    p2 <- p2 +
      ggplot2::geom_vline(xintercept = log(cvObj$lambda.1se),
                          linetype = "dashed")
  }

  print(p2)
  invisible(p2)
}
 
 
Created by Pretty R at inside-R.org

No comments: