The default plot function in glmnet is quite basic and plain. The following function tries to improve on its plots.
# plot_glmnet_beta(): coefficient-path plot for a glmnet fit, serving the same
# purpose as plot.glmnet() with the following enhancements:
#   1) each coefficient path is labelled with its variable name instead of a
#      column number
#   2) a solid vertical line marks the lambda that gives the smallest
#      cross-validation error (cvObj$lambda.min)
#   3) a dashed vertical line marks the largest lambda whose error is within
#      1 se of the minimal cross-validation error (cvObj$lambda.1se)
#
# Arguments:
#   glmnetObj  a fitted object inheriting from class 'glmnet' (from glmnet())
#   cvObj      a cv.glmnet object from cv.glmnet(); only its lambda.min and
#              lambda.1se elements are read
#   title1     optional plot title
#
# Side effect / value: prints the plot and returns the ggplot object invisibly.
#
# Relies on ggplot2 being attached and on the helper toLong(), which is defined
# elsewhere (presumably it reshapes the lambda-by-predictor matrix into a long
# data frame with the given column names -- confirm against its definition).
plot_glmnet_beta <- function(glmnetObj, cvObj, title1 = '') {
  stopifnot(inherits(glmnetObj, 'glmnet'))

  # Rows of b1 are predictors, columns follow the lambda sequence.
  b1 <- as.matrix(glmnetObj$beta)
  l1 <- log(glmnetObj$lambda)
  stopifnot(length(l1) == ncol(b1))

  # Extend the left edge of the x axis by 10% of the range so the variable
  # labels drawn at the left border do not sit on top of the paths.
  max1 <- max(l1)
  min1 <- min(l1)
  min1 <- min1 - (max1 - min1) / 10

  # Each label is placed at the height of that predictor's coefficient in the
  # last column of the path.
  label.data <- data.frame(
    var = row.names(b1),
    value = b1[, ncol(b1)],
    stringsAsFactors = FALSE
  )

  p1 <- ggplot(
    toLong(t(b1), l1, c('coefficient', 'logLambda', 'predictor')),
    aes(x = logLambda, y = coefficient)
  ) +
    geom_line(aes(group = predictor, color = predictor)) +
    theme_bw() +
    scale_x_continuous(limits = c(min1, max1)) +
    theme(legend.position = 'none') +
    ggtitle(title1)

  p2 <- p1 +
    geom_text(
      data = label.data,
      aes(x = -Inf, y = value, label = var),
      size = 2,
      hjust = 0
    ) +
    geom_vline(xintercept = log(cvObj$lambda.min))

  # Dashed line for the 1-se lambda; skipped when cvObj does not carry one so
  # callers passing a reduced object still get the original plot.
  lambda_1se <- cvObj[['lambda.1se']]
  if (!is.null(lambda_1se)) {
    p2 <- p2 + geom_vline(xintercept = log(lambda_1se), linetype = 'dashed')
  }

  print(p2)
  invisible(p2)
}
No comments:
Post a Comment