summaryrefslogblamecommitdiff
path: root/CakeMcmcFit.R
blob: 0a5bf4a78f20301556ea1abd782bcf2a476354d5 (plain) (tree)
1
2
3
4
5
6
7
                                                
                                          

                                                                             
                                                           
 
                                                                                









                                                                         
                                                                          
 












                                                                                                          
                                                                                  

                                                                                                     





































































                                                                                                                                 
         




                                                                            
         



























                                                                                                                             

         




                                              
 




                                              
         

























                                                                                                                                                                                                                                                                                      

 





























                                                                                                      

                 










                                                                                                                      
                   

                                                
                                         



                                                                        
 




                                                                            
 
                                                                   

                                                                


                                                     
 

                                                               
                       


                                           


                                                  
                                           
                 

                                                  



                                                                                                                                 





                                                            








                                                             



















                                                                           
 
# Some of the CAKE R modules are based on mkin, 
# Based on mcmckinfit as modified by Bayer
# Modifications developed by Hybrid Intelligence (formerly Tessella), part of
# Capgemini Engineering, for Syngenta, Copyright (C) 2011-2022 Syngenta
# Tessella Project Reference: 6245, 7247, 8361, 7414, 10091

#    The CAKE R modules are free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
# 
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
# 
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Performs a Markov chain Monte Carlo (MCMC) least-squares fit on a given CAKE model.
#
# cake.model: The model to perform the fit on (as generated by CakeModel.R).
# observed: Observation data to fit to.
# parms.ini: Initial values for the parameters being fitted.
# state.ini: Initial state (i.e. initial values for concentration, the dependent variable being modelled).
# lower: Lower bounds to apply to parameters.
# upper: Upper bounds to apply to parameters.
# fixed_parms: A vector of names of parameters that are fixed to their initial values.
# fixed_initials: A vector of compartments with fixed initial concentrations.
# quiet: Whether the internal cost functions should execute more quietly than normal (less output).
# niter: The number of MCMC iterations to apply.
# verbose: Whether the MCMC run should report progress (passed through to CakeFit).
# seed: Random seed for the MCMC run; if NULL, a random one is generated so that there is a seed to report.
# atol: The tolerance to apply to the ODE solver.
# dfopDtMaxIter: The maximum number of iterations to apply to DFOP DT calculation.
# control: Control settings passed through to CakeFit (presumably forwarded to the optimiser).
# useExtraSolver: Whether to use the extra solver for this fit (only used for the initial first fit).
CakeMcmcFit <- function(cake.model, observed, parms.ini, state.ini, lower, upper,
                        fixed_parms = NULL, fixed_initials, quiet = FALSE,
                        niter = 1000, verbose = TRUE, seed = NULL, atol = 1e-6,
                        dfopDtMaxIter = 10000, control = list(),
                        useExtraSolver = FALSE) {
    # Delegate to the shared CakeFit driver, selecting the "MCMC" strategy.
    # The first ten arguments are forwarded positionally, the rest by name.
    CakeFit("MCMC", cake.model, observed, parms.ini, state.ini, lower, upper,
            fixed_parms, fixed_initials, quiet,
            niter = niter,
            verbose = verbose,
            seed = seed,
            atol = atol,
            dfopDtMaxIter = dfopDtMaxIter,
            control = control,
            useExtraSolver = useExtraSolver)
}

GetMcmcSpecificSetup <- function() {
    # Factory for the MCMC-specific setup step.
    #
    # The returned function builds the internal cost functions for the model,
    # wraps the cost function with progress reporting, and seeds the random
    # number generator. Extra arguments (...) may carry `seed`; when it is NULL
    # a random one is drawn so that there is something to report.
    #
    # Returns a list with costFunctions, costWithStatus, maxIter (NULL),
    # tol (NULL) and the integer seed that was applied.
    setupStep <- function(cake.model,
                          state.ini.optim,
                          state.ini.optim.boxnames,
                          state.ini.fixed,
                          parms.fixed,
                          observed,
                          mkindiff,
                          quiet,
                          atol,
                          solution,
                          ...) {
        extraArgs <- list(...)
        seed <- extraArgs$seed

        costFunctions <- CakeInternalCostFunctions(cake.model,
                                                   state.ini.optim,
                                                   state.ini.optim.boxnames,
                                                   state.ini.fixed,
                                                   parms.fixed,
                                                   observed,
                                                   mkindiff = mkindiff,
                                                   quiet,
                                                   atol = atol,
                                                   solution = solution)

        # NOTE(review): no `bestIteration` binding exists in the enclosing
        # function scopes, so `<<-` presumably writes it to the global
        # environment; other code may read it there - confirm before localising.
        bestIteration <<- -1

        # Cost wrapper for the MCMC run: announces each new best cost and,
        # while the call count is within `maxCallNo`, the current call number.
        costWithStatus <- function(P, ...) {
            outcome <- costFunctions$cost(P)
            bestSoFar <- costFunctions$get.best.cost()
            if (outcome$cost == bestSoFar) {
                bestIteration <<- costFunctions$get.calls()
                cat(' MCMC best so far: c', outcome$cost, 'it:', bestIteration, '\n')
            }

            callLimit <- list(...)$maxCallNo
            if (costFunctions$get.calls() <= callLimit) {
                cat("MCMC Call no.", costFunctions$get.calls(), '\n')
            }

            outcome
        }

        # Without a supplied seed, draw one at random so the run can be reported.
        if (is.null(seed)) {
            seed <- runif(1, 0, 2 ^ 31 - 1)
        }
        seed <- as.integer(seed)
        set.seed(seed)

        list(costFunctions = costFunctions,
             costWithStatus = costWithStatus,
             maxIter = NULL,
             tol = NULL,
             seed = seed)
    }

    setupStep
}

GetMcmcOptimisationRoutine <- function() {
    # Factory for the MCMC optimisation step.
    #
    # The returned function runs three stages:
    #   1. an unweighted pre-fit via RunFitStep (optionally with the extra solver);
    #   2. one reweighted least-squares fit (modFit), using per-variable error
    #      standard deviations estimated from the pre-fit residuals;
    #   3. an modMCMC run started from the reweighted fit's parameters.
    #
    # Named extras in `...`: cake.model, costWithStatus, observed, niter, verbose.
    # Returns list(fit, fitted_with_extra_solver, res) where res is the modMCMC result.
    return(function(costFunctions, costForExtraSolver, useExtraSolver, parms, lower, upper, control, ...) {
        mcmcArgs <- list(...)
        cake.model <- mcmcArgs$cake.model
        costWithStatus <- mcmcArgs$costWithStatus
        observed <- mcmcArgs$observed
        niter <- mcmcArgs$niter
        verbose <- mcmcArgs$verbose

        # Runs a pre-fit with no weights first, followed by a weighted step (extra solver with first, not with second)

        # Run optimiser, no weighting
        fitStepResult <- RunFitStep(costFunctions$cost, costForExtraSolver, useExtraSolver, parms, lower, upper, control)
        fit <- fitStepResult$fit

        # Process extra solver output if it was used.
        # NOTE(review): `FF` is not defined in this function or its arguments; it is
        # presumably resolved from an enclosing/global environment at call time - confirm.
        if (fitStepResult$fitted_with_extra_solver) {
            fit <- GetFitValuesAfterExtraSolver(fit, FF)
        }

        # One reweighted estimation: estimate the error standard deviation of each
        # modelled variable (box) from the residuals of the unweighted fit, then
        # give every observation the error of its variable.
        tmpres <- fit$residuals
        boxes <- names(cake.model$map)
        err <- vapply(boxes, function(box) sd(tmpres[names(tmpres) %in% box]), numeric(1))
        names(err) <- boxes
        costFunctions$set.error(err[as.character(observed$name)])

        ## At least do one iteration step to get a weighted LS
        fit <- modFit(f = costFunctions$cost, p = fit$par, lower = lower, upper = upper)

        # Run MCMC optimiser with output from weighted fit
        # Apply iterative re-weighting here (iterations fixed to 1 for now):
        # Do modMCMC as below,  we should also pass in the final priors for subsequent iterations
        # Use modMCMC average as input to next modMCMC run (as done in block 3 to get final parameters)
        # use errors from previous step as inputs to modMCMC cov0 and var0 at each iteration

        # Proposal covariance: the weighted fit's scaled covariance times 2.4^2 / (number of
        # parameters); NULL is passed when that covariance is entirely NA.
        fs <- summary(fit)
        cov0 <- if (all(is.na(fs$cov.scaled))) NULL else fs$cov.scaled * 2.4 ^ 2 / length(fit$par)
        var0 <- fit$var_ms_unweighted

        # Reset call/best-cost tracking so costWithStatus reports the MCMC run from call 1.
        costFunctions$set.calls(0)
        costFunctions$reset.best.cost()
        res <- modMCMC(f = costWithStatus, p = fit$par, maxCallNo = niter, jump = cov0, lower = lower, upper = upper,
                       prior = NULL, var0 = var0, wvar0 = 0.1, niter = niter, outputlength = niter, burninlength = 0,
                       updatecov = niter, ntrydr = 1, drscale = NULL, verbose = verbose)

        # NOTE(review): `fit` returned here is the unweighted pre-fit straight from RunFitStep
        # (not the extra-solver-processed fit nor the reweighted modFit result); callers
        # presumably rely on this - confirm before changing.
        return(list(fit = fitStepResult$fit, fitted_with_extra_solver = fitStepResult$fitted_with_extra_solver, res = res))
    })
}

GetMcmcSpecificWrapUp <- function() {
    # Factory for the MCMC-specific wrap-up step.
    #
    # The returned function takes the least-squares `fit` plus named extras
    # (res, seed, costWithStatus, observed, parms.fixed). It replaces the
    # fitted parameters with the means of the MCMC chain, attaches the MCMC
    # results and seed, and recomputes rank and residual degrees of freedom.
    #
    # Returns list(fit, parms.all, data), where data is `observed` with its
    # err column reset to NA.
    wrapUp <- function(fit, ...) {
        extras <- list(...)
        res <- extras$res
        chain <- res$pars

        # Point estimates become the per-parameter means over the MCMC chain.
        fit$par <- sapply(colnames(chain), function(parName) mean(chain[, parName]))
        fit$bestpar <- res$bestpar
        fit$costfn <- extras$costWithStatus
        parms.all <- c(fit$par, extras$parms.fixed)

        data <- extras$observed
        data$err <- rep(NA, length(data$time))

        fit$seed <- extras$seed
        fit$res <- res

        # NOTE(review): the rank counts the fixed parameters as well as the
        # fitted ones - confirm this is the intended residual degrees of freedom.
        fit$rank <- length(parms.all)
        fit$df.residual <- length(fit$residuals) - fit$rank

        # Note different class to other optimisers.
        class(fit) <- c("CakeMcmcFit", "mkinfit", "modFit")
        list(fit = fit, parms.all = parms.all, data = data)
    }

    wrapUp
}

# Summarise an MCMC fit.
# The MCMC summary is separate from the others due to the difference in the outputs of modMCMC and modFit:
# estimates are taken from object$par and standard errors are the standard deviations of the MCMC
# chains in object$res$pars.
#
# object: A fit of class CakeMcmcFit.
# data, cov: Not used by this method.
# distimes: Whether to include distimes, extraDT50, ioreRepDT and fomcRepDT from the fit.
# halflives: Whether to include halflives from the fit.
# ff: Whether to include ff from the fit.
# Returns a list of class c("summary.CakeFit", "summary.mkinfit", "summary.modFit").
summary.CakeMcmcFit <- function(object, data = TRUE, distimes = TRUE, halflives = TRUE, ff = TRUE, cov = FALSE, ...) {
    param <- object$par
    pnames <- names(param)
    p <- length(param)
    mcmc <- object$res

    # Covariance of the MCMC chains. `cov` is also an argument of this method,
    # so call stats::cov explicitly to make the intent unambiguous.
    covar <- stats::cov(mcmc$pars)
    rownames(covar) <- colnames(covar) <- pnames

    rdf <- object$df.residual

    # Standard error of each parameter = standard deviation of its chain.
    se <- vapply(colnames(mcmc$pars), function(x) sd(mcmc$pars[, x]), numeric(1))

    tval <- param / se

    if (!all(object$start$lower >= 0)) {
        message <- "Note that the one-sided t-test may not be appropriate if
        parameter values below zero are possible."
        warning(message)
    } else {
        message <- "ok"
    }

    # Filter the values for t-test, only apply t-test to k-values
    t.names <- grep("k(\\d+)|k_(.*)", names(tval), value = TRUE)
    t.rest <- setdiff(names(tval), t.names)
    t.values <- c(tval)
    t.values[t.rest] <- NA
    t.result <- pt(t.values, rdf, lower.tail = FALSE)

    # Confidence intervals are reported only for the parameters that are not
    # t-tested; the k-values get NA bounds.
    t.param <- c(param)
    t.param[t.names] <- NA
    ciLower <- function(alpha) t.param + qt(alpha / 2, rdf) * se
    ciUpper <- function(alpha) t.param + qt(1 - alpha / 2, rdf) * se

    # 90% and 95% confidence intervals
    lci90 <- ciLower(0.10)
    uci90 <- ciUpper(0.10)
    lci95 <- ciLower(0.05)
    uci95 <- ciUpper(0.05)

    param <- cbind(param, se, tval, t.result, lci90, uci90, lci95, uci95)
    dimnames(param) <- list(pnames, c("Estimate", "Std. Error",
                                    "t value", "Pr(>t)", "Lower CI (90%)", "Upper CI (90%)", "Lower CI (95%)", "Upper CI (95%)"))

    # Residuals from mean of MCMC fit, looked up once so the variance and the
    # reported residuals use the same column. `$` never matches the longer name
    # `residuals` against a column named `residual`, so the column is read with
    # partial matching explicitly allowed, which accepts either spelling.
    dataResiduals <- object$data[["residual", exact = FALSE]]
    resvar <- object$ssr / rdf
    modVariance <- object$ssr / length(dataResiduals)

    ans <- list(ssr = object$ssr,
               residuals = dataResiduals,
               residualVariance = resvar,
               sigma = sqrt(resvar),
               modVariance = modVariance,
               df = c(p, rdf), cov.unscaled = covar,
               cov.scaled = covar * resvar,
               info = object$info, niter = object$iterations,
               stopmess = message,
               par = param)

    ans$diffs <- object$diffs
    ans$data <- object$data
    ans$additionalstats <- CakeAdditionalStats(object$data)
    ans$seed <- object$seed

    ans$start <- object$start
    ans$fixed <- object$fixed
    ans$errmin <- object$errmin
    ans$penalties <- object$penalties
    if (distimes) {
        ans$distimes <- object$distimes
        ans$extraDT50 <- object$extraDT50
        ans$ioreRepDT <- object$ioreRepDT
        ans$fomcRepDT <- object$fomcRepDT
    }
    if (halflives) ans$halflives <- object$halflives
    if (ff) ans$ff <- object$ff
    class(ans) <- c("summary.CakeFit", "summary.mkinfit", "summary.modFit")
    return(ans)
}

Contact - Imprint