summaryrefslogblamecommitdiff
path: root/CakeMcmcFit.R
blob: be57cef0bdc6a4066239a43f72db9447261b7236 (plain) (tree)
1
2
3
4
5
6
                                                
                                          

                                                                                    
 
                                                                                









                                                                         
                                                                          
 












                                                                                                          
                                                                                  














                                                                                                     
                                                


                                                
 
                                        

                                                                 

                                       
                                                   




                                            


                                           
        


                                                                    
                                                            
         
        

                              

                                     
                                    

     






                                                               
    




                                                                








                                                                                                      






                                                               

                 
 


                                                                      
                               
     
    
                            

                  
                                                               
                              
                      
     

                                                                                                                                                 

                                   
                                                                                                                                               



         



                                                                                                                                                 
                                            
     
    










                                                                             
     















                                                                                       
                                                                                                                                                                                                                                     
  

                                                                      
                                                               



                                                                    
                                                               


                                           

                                 
   




                                                     
   

                                        
                                                   
                                                                 
                                                               


                                                                                                                                                                


                                                                                                 
     
    

                                                                  
  

                                          
   


                                                                            

                                    
                                                                            
 
                                                               
                                                   
                                                                                                                                         
 







                                                                         
                                       
 

                                       
   
                    







                                                       
                                                                                                                     























                                                                                  








                                                                   
                       






                                           
                 





                                                                                                                                 




                                                           









                                                             


                           
                                                         



                           

                                   





                                     
                                                 
                            


                                                                         
# Some of the CAKE R modules are based on mkin, 
# Based on mcmckinfit as modified by Bayer
# Modifications developed by Tessella for Syngenta: Copyright (C) 2011-2020 Syngenta
# Tessella Project Reference: 6245, 7247, 8361, 7414, 10091

#    The CAKE R modules are free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
# 
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
# 
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Performs a Markov chain Monte Carlo (MCMC) fit on a given CAKE model.
#
# The fit proceeds in three steps:
#  1. An initial least squares fit (modFit) is made with unit observation errors.
#  2. The standard deviation of each compartment's residuals from that fit is passed
#     to the cost functions as its observation error (costFunctions$set.error), and
#     the model is refitted once.
#  3. The refitted parameters and their scaled covariance seed the MCMC run (modMCMC).
# The returned fit reports the MCMC chain means as its parameter estimates (fit$par).
#
# cake.model: The model to perform the fit on (as generated by CakeModel.R).
# observed: Observation data to fit to.
# parms.ini: Initial values for the parameters being fitted.
# state.ini: Initial state (i.e. initial values for concentration, the dependent variable being modelled).
# lower: Lower bounds to apply to parameters.
# upper: Upper bounds to apply to parameters.
# fixed_parms: A vector of names of parameters that are fixed to their initial values.
# fixed_initials: A vector of compartments with fixed initial concentrations.
# quiet: Whether the internal cost functions should execute more quietly than normal (less output).
# niter: The number of MCMC iterations to apply.
# verbose: Passed to modMCMC as its `verbose` argument.
# seed: Random seed for the MCMC run. If NULL, a random seed is drawn so that one can
#       still be reported; the seed used is returned as fit$seed.
# atol: The tolerance to apply to the ODE solver.
# dfopDtMaxIter: The maximum number of iterations to apply to DFOP DT calculation.
# control: Control list passed to modFit for the initial fit when useExtraSolver is TRUE.
# useExtraSolver: Whether to use the extra solver for this fit (only used for the initial first fit).
# ...: Further arguments passed to modFit (the initial fit when useExtraSolver is FALSE,
#      and the reweighted fit) and to modMCMC.
#
# Returns a fit object of class c("CakeMcmcFit", "mkinfit", "modFit").
CakeMcmcFit <- function (cake.model, 
                         observed, 
                         parms.ini, 
                         state.ini, 
                         lower, 
                         upper, 
                         fixed_parms = NULL, 
                         fixed_initials, 
                         quiet = FALSE, 
                         niter = 1000, 
                         verbose = TRUE, 
                         seed=NULL, 
                         atol=1e-6, 
                         dfopDtMaxIter = 10000, 
                         control=list(),
                         useExtraSolver = FALSE,
                         ...) 
{
    mod_vars <- names(cake.model$diffs)
    # Only observations of compartments present in the model map are fitted.
    observed <- subset(observed, name %in% names(cake.model$map))
    # Start with unit errors; these are re-estimated after the first fit.
    ERR <- rep(1, nrow(observed))
    observed <- cbind(observed, err = ERR)
    
    if (is.null(names(parms.ini))) {
        names(parms.ini) <- cake.model$parms
    }
    
    # The differential equations are stored as strings. Parse them once here
    # instead of on every evaluation of the ODE right-hand side.
    parsed_diffs <- lapply(cake.model$diffs, function(d) parse(text = d))
    
    # ODE right-hand side: evaluates each compartment's differential equation
    # with the current state and parameters in scope.
    mkindiff <- function(t, state, parms) {
        time <- t
        diffs <- vector()
        
        for (box in mod_vars) {
            diffname <- paste("d", box, sep = "_")
            # A list envir is converted to an environment whose enclosure is this
            # frame, so names resolve in the same order as with(list, eval(...)).
            diffs[diffname] <- eval(parsed_diffs[[box]],
                                    envir = as.list(c(time, state, parms)))
        }
        
        return(list(c(diffs)))
    }
    
    if (is.null(names(state.ini))) { 
        names(state.ini) <- mod_vars
    }
    
    # Split parameters and initial states into fixed and optimised sets.
    parms.fixed <- parms.ini[fixed_parms]
    optim_parms <- setdiff(names(parms.ini), fixed_parms)
    parms.optim <- parms.ini[optim_parms]
    state.ini.fixed <- state.ini[fixed_initials]
    optim_initials <- setdiff(names(state.ini), fixed_initials)
    state.ini.optim <- state.ini[optim_initials]
    state.ini.optim.boxnames <- names(state.ini.optim)
    
    # Optimised initial states are renamed "<box>_0".
    if (length(state.ini.optim) > 0) {
        names(state.ini.optim) <- paste(names(state.ini.optim), 
            "0", sep = "_")
    }
   
    costFunctions <- CakeInternalCostFunctions(cake.model, state.ini.optim, state.ini.optim.boxnames, 
                        state.ini.fixed, parms.fixed, observed, mkindiff, quiet, atol=atol)
    bestIteration <- -1
    
    # Cost function handed to modMCMC. It wraps costFunctions$cost and prints
    # progress. modMCMC forwards its extra arguments here; maxCallNo caps the
    # number of calls for which "MCMC Call no." is printed.
    costWithStatus <- function(P, ...){
        r <- costFunctions$cost(P)
        if(r$cost == costFunctions$get.best.cost()) {
            bestIteration <<- costFunctions$get.calls()
             cat(' MCMC best so far: c', r$cost, 'it:', bestIteration, '\n')
        }
        
        arguments <- list(...)
        if (costFunctions$get.calls() <= arguments$maxCallNo)
        {
          cat("MCMC Call no.", costFunctions$get.calls(), '\n')
        }

        return(r)
    }

    # Set the seed
    if ( is.null(seed) ) {
      # No seed so create a random one so there is something to report
      seed <- runif(1,0,2^31-1)
    }
    
    seed <- as.integer(seed)
    set.seed(seed)
    
    ## ############# Get Initial Paramtervalues   #############
    ## Start with no weighting
    if(useExtraSolver)
    {
        a <- try(fit <- modFit(costFunctions$cost, c(state.ini.optim, parms.optim), lower = lower, upper = upper, method='Port',control=control))
        
        # Fall back to L-BFGS-B when the Port fit fails.
        if(inherits(a, "try-error"))
        {
            fit <- modFit(costFunctions$cost, c(state.ini.optim, parms.optim), lower = lower, upper = upper, method='L-BFGS-B',control=control)
        }
    }
    else
    {
        # modFit parameter transformations can explode if you put in parameters that are equal to a bound, so we move them away by a tiny amount.
        all.optim <- ShiftAwayFromBoundaries(c(state.ini.optim, parms.optim), lower, upper)
        
        fit <- modFit(costFunctions$cost, all.optim, lower = lower, 
                          upper = upper,...)
    }
    
    ## ############## One reweighted estimation ############################ 
    ## Estimate the error standard deviation of each compartment from the
    ## residuals of the first fit.
    tmpres <- fit$residuals
    map_names <- names(cake.model$map)
    err <- rep(NA, length(map_names))
    
    for (i in seq_along(map_names)) {
        err[i] <- sd(tmpres[which(names(tmpres) == map_names[i])])
    }
    
    names(err) <- map_names
    ERR <- err[as.character(observed$name)]
    observed$err <- ERR
    costFunctions$set.error(ERR)
    
    ## At least do one iteration step to get a weighted LS
    fit <- modFit(costFunctions$cost, fit$par, lower = lower, upper = upper, ...)
    ## Use this as the Input for MCMC algorithm
    ## ##########################
    fs <- summary(fit)
    # Proposal covariance: the scaled covariance of the weighted fit times
    # 2.4^2 / (number of parameters). If that covariance is entirely NA,
    # jump = NULL is passed to modMCMC instead.
    cov0 <- if(all(is.na(fs$cov.scaled))) NULL else fs$cov.scaled*2.4^2/length(fit$par)
    var0 <- fit$var_ms_unweighted
    costFunctions$set.calls(0); costFunctions$reset.best.cost()
    res <- modMCMC(costWithStatus, maxCallNo=niter, fit$par,...,jump=cov0,lower=lower,upper=upper,prior=NULL,var0=var0,wvar0=0.1,niter=niter,outputlength=niter,burninlength=0,updatecov=niter,ntrydr=1,drscale=NULL,verbose=verbose)
  
    # Construct the fit object to return
    fit$start <- data.frame(initial = c(state.ini.optim, parms.optim))
    fit$start$type <- c(rep("state", length(state.ini.optim)), 
        rep("deparm", length(parms.optim)))
    fit$start$lower <- lower
    fit$start$upper <- upper
    fit$fixed <- data.frame(value = c(state.ini.fixed, parms.fixed))
    fit$fixed$type <- c(rep("state", length(state.ini.fixed)), 
        rep("deparm", length(parms.fixed)))

    fit$mkindiff <- mkindiff
    fit$map <- cake.model$map
    fit$diffs <- cake.model$diffs
   
    # Replace mean from modFit with mean from modMCMC
    fit$par <- vapply(colnames(res$pars), function(x) mean(res$pars[, x]), numeric(1))
    fit$bestpar <- res$bestpar
    fit$costfn <- costWithStatus
   
    # Disappearence times
    parms.all <- c(fit$par, parms.fixed)
    obs_vars <- unique(as.character(observed$name))
    fit$distimes <- data.frame(DT50 = rep(NA, length(obs_vars)), 
        DT90 = rep(NA, length(obs_vars)), row.names = obs_vars)
    fit$extraDT50<- data.frame(k1 = rep(NA, length(names(cake.model$map))), k2 = rep(NA, length(names(cake.model$map))), row.names = names(cake.model$map))     
    
    for (compartment.name in names(cake.model$map)) {
        type <- names(cake.model$map[[compartment.name]])[1]
        fit$distimes[compartment.name, ] <- CakeDT(type,compartment.name,parms.all,dfopDtMaxIter)
        fit$extraDT50[compartment.name, ] <- CakeExtraDT(type, compartment.name, parms.all)
    }
    
    fit$ioreRepDT <- CakeIORERepresentativeDT("Parent", parms.all)
    fit$fomcRepDT <- CakeFOMCBackCalculatedDT(parms.all)
  
    # Ensure initial state is at time 0
    obstimes <- unique(c(0,observed$time))
   
    # Solve the system
    out_predicted <- CakeOdeSolve(fit, obstimes, solution = "deSolve", atol)
    out_transformed <- PostProcessOdeOutput(out_predicted, cake.model, atol)
   
    fit$predicted <- out_transformed
    fit$penalties <- CakePenaltiesLong(parms.all, out_transformed, observed)

    predicted_long <- wide_to_long(out_transformed,time="time")
    fit$errmin <- CakeChi2(cake.model, observed, predicted_long, obs_vars, parms.optim, state.ini.optim, state.ini, parms.ini, fit$fixed)

    # Pair each observation with its prediction and compute residuals.
    data<-observed
    data$err<-rep(NA,length(data$time))
    data<-merge(data, predicted_long, by=c("time","name"))
    names(data)<-c("time", "variable", "observed","err-var", "predicted")
    data$residual<-data$observed-data$predicted
    data$variable <- ordered(data$variable, levels = obs_vars)
    fit$data <- data[order(data$variable, data$time), ]
    fit$atol <- atol
    fit$topology <- cake.model$topology

    sq <- data$residual * data$residual
    fit$ssr <- sum(sq)
   
    fit$seed <- seed
   
    fit$res <- res
    class(fit) <- c("CakeMcmcFit", "mkinfit", "modFit")
    return(fit)
}


# Summarise an MCMC fit.
#
# Builds a summary object for a fit returned by CakeMcmcFit. The parameter
# estimates are the MCMC chain means stored in object$par. The standard errors
# are the standard deviations of the chain columns, and the unscaled covariance
# is the sample covariance of the chain.
#
# object: A CakeMcmcFit object.
# data: Unused; the fitted data are always copied into the summary.
# distimes: Whether to include the DT50/DT90, extra DT50 and representative DT tables.
# halflives: Whether to copy object$halflives into the summary.
# ff: Whether to copy object$ff into the summary.
# cov: Unused.
# ...: Unused; accepted for compatibility with the summary() generic.
#
# Returns a list of class c("summary.CakeFit", "summary.mkinfit", "summary.modFit").
summary.CakeMcmcFit <- function(object, data = TRUE, distimes = TRUE, halflives = TRUE, ff = TRUE, cov = FALSE,...) {
  param  <- object$par
  pnames <- names(param)
  p      <- length(param)
  rdf    <- object$df.residual
  mcmc   <- object$res

  # Covariance of the MCMC chain. Call stats::cov explicitly, because the
  # `cov` argument shadows the function name inside this function.
  covar <- stats::cov(mcmc$pars)
  rownames(covar) <- colnames(covar) <- pnames

  # Each parameter's standard error is the standard deviation of its chain.
  se <- vapply(colnames(mcmc$pars), function(x) sd(mcmc$pars[, x]), numeric(1))
  tval <- param / se

  if (!all(object$start$lower >=0)) {
    stop_message <- "Note that the one-sided t-test may not be appropriate if
      parameter values below zero are possible."
    warning(stop_message)
  } else {
    stop_message <- "ok"
  }

  # Only parameters whose names contain k<digits> or k_ get the one-sided
  # t-test; the p-value of every other parameter is NA.
  t.names  <- grep("k(\\d+)|k_(.*)", names(tval), value = TRUE)
  t.rest   <- setdiff(names(tval), t.names)
  t.values <- c(tval)
  t.values[t.rest] <- NA
  t.result <- pt(t.values, rdf, lower.tail = FALSE)

  # Confidence intervals are computed only for the parameters that were NOT
  # t-tested; the t-tested rows get NA bounds.
  t.param <- c(param)
  t.param[t.names] <- NA
  two_sided_ci <- function(alpha) {
    list(lower = t.param + qt(alpha / 2, rdf) * se,
         upper = t.param + qt(1 - alpha / 2, rdf) * se)
  }
  ci90 <- two_sided_ci(0.10)
  ci95 <- two_sided_ci(0.05)

  param <- cbind(param, se, tval, t.result,
                 ci90$lower, ci90$upper, ci95$lower, ci95$upper)
  dimnames(param) <- list(pnames, c("Estimate", "Std. Error",
                                    "t value", "Pr(>t)", "Lower CI (90%)", "Upper CI (90%)", "Lower CI (95%)", "Upper CI (95%)"))

  # Residuals from mean of MCMC fit
  resvar <- object$ssr / rdf
  modVariance <- object$ssr / length(object$data$residual)

  ans <- list(ssr = object$ssr,
              # The fitted data store residuals in the `residual` column;
              # `$residuals` cannot partially match it and would yield NULL.
              residuals = object$data$residual,
              residualVariance = resvar,
              sigma = sqrt(resvar),
              modVariance = modVariance,
              df = c(p, rdf), cov.unscaled = covar,
              cov.scaled = covar * resvar,
              info = object$info, niter = object$iterations,
              stopmess = stop_message,
              par = param)

  ans$diffs <- object$diffs
  ans$data <- object$data
  ans$additionalstats <- CakeAdditionalStats(object$data)
  ans$seed <- object$seed

  ans$start <- object$start
  ans$fixed <- object$fixed
  ans$errmin <- object$errmin
  ans$penalties <- object$penalties
  if (distimes) {
    ans$distimes <- object$distimes
    ans$extraDT50 <- object$extraDT50
    ans$ioreRepDT <- object$ioreRepDT
    ans$fomcRepDT <- object$fomcRepDT
  }
  if (halflives) ans$halflives <- object$halflives
  if (ff) ans$ff <- object$ff
  class(ans) <- c("summary.CakeFit", "summary.mkinfit", "summary.modFit")
  return(ans)
}

Contact - Imprint