summaryrefslogblamecommitdiff
path: root/CakeIrlsFit.R
blob: 6a50621d2ce328074e58f485bf8431af5cd1d83a (plain) (tree)
1
2
3
4
5
6
7
8
     
 
                                              

                                                                             
                                                           
 
                                                                                









                                                                         
                                                                          
 
 












                                                                                                          
                                                                                  

                                                               
















































                                                                                                                                 
 






                                                               

             
 

























                                                                                                                           
         
 
























                                                                                                                               

         


                                                                                                                         
 





                                                                                                                                     
         
 


                                                
 

                                                                                                                                   
 
















                                                                   
#$Id$
#
# Some of the CAKE R modules are based on mkin
# Modifications developed by Hybrid Intelligence (formerly Tessella), part of
# Capgemini Engineering, for Syngenta, Copyright (C) 2011-2022 Syngenta
# Tessella Project Reference: 6245, 7247, 8361, 7414, 10091
 
#    The CAKE R modules are free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
# 
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
# 
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

# CakeIrlsFit: fits a CAKE model by iteratively re-weighted least squares (IRLS).
# Historical note: this entry point was originally adapted from the "mkinfit"
# function of mkin, version 0.1.
#
# This is a thin entry point. All arguments are handed to CakeFit together with
# the algorithm name "IRLS".
#
# Arguments:
#   cake.model:     the model to fit, as generated by CakeModel.R.
#   observed:       observation data to fit against.
#   parms.ini:      initial values for the parameters being fitted.
#   state.ini:      initial state, i.e. initial values for the concentrations
#                   (the dependent variables being modelled).
#   lower:          lower bounds applied to the parameters.
#   upper:          upper bounds applied to the parameters.
#   fixed_parms:    names of parameters held fixed at their initial values.
#   fixed_initials: compartments whose initial concentrations are held fixed
#                   (default: every name in cake.model$diffs except the first).
#   quiet:          whether the internal cost functions should produce less output.
#   atol:           tolerance applied to the ODE solver.
#   dfopDtMaxIter:  maximum number of iterations for the DFOP DT calculation.
#   control:        list of control settings. Its `irls.control` element (a list
#                   with `maxIter` and `tol`) drives the IRLS loop; when that
#                   element is absent, maxIter = 10 and tol = 1e-05 are used.
#                   NOTE(review): the list also appears to be forwarded to the
#                   individual fit steps (RunFitStep) - confirm in CakeFit.R.
#   useExtraSolver: whether to use the extra solver for this fit.
#
# Returns whatever CakeFit returns for the "IRLS" algorithm.
CakeIrlsFit <- function(cake.model,
                         observed,
                         parms.ini,
                         state.ini,
                         lower = 0,
                         upper = Inf,
                         fixed_parms = NULL,
                         fixed_initials = names(cake.model$diffs)[-1],
                         quiet = FALSE,
                         atol = 1e-6,
                         dfopDtMaxIter = 10000,
                         control = list(),
                         useExtraSolver = FALSE) {
    # The leading arguments are passed positionally (this relies on CakeFit's
    # parameter order); the remaining options are passed by name.
    irlsResult <- CakeFit(
        "IRLS", cake.model, observed, parms.ini, state.ini,
        lower, upper, fixed_parms, fixed_initials, quiet,
        atol = atol, dfopDtMaxIter = dfopDtMaxIter,
        control = control, useExtraSolver = useExtraSolver
    )
    irlsResult
}

# Returns the IRLS-specific setup step used by the generic CAKE fitting code.
#
# The returned function builds the CAKE internal cost functions for this fit and
# resolves the IRLS loop settings from `control$irls.control`:
#   maxIter: maximum number of re-weighting iterations (default 10)
#   tol:     convergence threshold for the re-weighting loop (default 1e-05)
# Any IRLS setting the caller omits falls back to its default. All other entries
# of `control` are preserved.
#
# The returned function gives back a list with:
#   costFunctions:  result of CakeInternalCostFunctions for this model and data
#   control:        `control` with a fully populated `irls.control` element
#   maxIter, tol:   the resolved IRLS settings
#   costWithStatus: always NULL for IRLS
GetIrlsSpecificSetup <- function() {
    return(function(
           cake.model,
           state.ini.optim,
           state.ini.optim.boxnames,
           state.ini.fixed,
           parms.fixed,
           observed,
           mkindiff,
           quiet,
           atol,
           solution,
           ...) {
        control <- list(...)$control
        # Get the CAKE cost functions
        costFunctions <- CakeInternalCostFunctions(cake.model, state.ini.optim, state.ini.optim.boxnames, state.ini.fixed,
                                             parms.fixed, observed, mkindiff = mkindiff, quiet, atol = atol, solution = solution)

        # Fill in the IRLS settings without discarding the caller's other control
        # entries. Any missing IRLS setting is defaulted, so maxIter and tol are
        # never NULL when the optimisation loop tests them.
        if (length(control) == 0) {
            control <- list()
        }
        irls.control <- as.list(control[["irls.control"]])
        if (is.null(irls.control$maxIter)) {
            irls.control$maxIter <- 10
        }
        if (is.null(irls.control$tol)) {
            irls.control$tol <- 1e-05
        }
        control[["irls.control"]] <- irls.control

        return(list(costFunctions = costFunctions, control = control,
                    maxIter = irls.control$maxIter, tol = irls.control$tol, costWithStatus = NULL))
    })
}

# Returns the IRLS optimisation routine used by the generic CAKE fitting code.
#
# The returned function first runs an initial fit. It then loops for up to
# `maxIter` iterations. Each iteration re-weights the cost function by each
# observed variable's error (sqrt of fit$var_ms_unweighted from the latest fit)
# and re-fits. The loop stops once the sum of squared changes in that error
# between iterations is <= `tol`; the re-fit for that final iteration still runs.
# No re-weighting is done when cake.model$map has a single entry.
#
# Arguments of the returned function:
#   costFunctions: list providing `cost` (the objective given to each fit step)
#                  and `set.error` (sets the per-observation error used for weighting).
#   costForExtraSolver, useExtraSolver, parms, lower, upper, control:
#                  forwarded to RunFitStep (`parms` is used for the first fit only).
#   ...:           must contain the named values cake.model, tol, maxIter and observed.
#
# The returned function gives back a list with:
#   fit:                      the final fit
#   fitted_with_extra_solver: whether the last fit step used the extra solver
#   observed:                 `observed` with an `err` column taken from the final fit
#   res:                      always NULL
GetIrlsOptimisationRoutine <- function() {
    return(function(costFunctions, costForExtraSolver, useExtraSolver, parms, lower, upper, control, ...) {
        irlsArgs <- list(...)
        cake.model <- irlsArgs$cake.model
        tol <- irlsArgs$tol
        observed <- irlsArgs$observed
        maxIter <- irlsArgs$maxIter

        # Initial fit, before this routine applies any re-weighting.
        fitStepResult <- RunFitStep(costFunctions$cost, costForExtraSolver, useExtraSolver, parms, lower, upper, control)
        fit <- fitStepResult$fit
        fitted_with_extra_solver <- fitStepResult$fitted_with_extra_solver

        if (length(cake.model$map) == 1) {
            ## Only one entry in the model map, so there is nothing to re-weight
            ## against: skip the IRLS loop.
            maxIter <- 0
        }

        niter <- 1
        ## Start with a large difference so that at least one IRLS iteration runs
        ## (provided maxIter >= 1).
        diffsigma <- 100
        olderr <- rep(1, length(cake.model$map))
        while (diffsigma > tol && niter <= maxIter) {
            # If the extra solver was used, read its results from FF into fit.
            # NOTE(review): FF is not defined in this function; it is resolved from
            # an enclosing (presumably global) environment set elsewhere - confirm.
            if (fitted_with_extra_solver) {
                fit <- GetFitValuesAfterExtraSolver(fit, FF)
            }
            # Weight each observation by the error of its variable from the latest fit.
            err <- sqrt(fit$var_ms_unweighted)
            costFunctions$set.error(err[as.character(observed$name)])

            diffsigma <- sum((err - olderr) ^ 2)
            cat("IRLS iteration at", niter, "; Diff in error variance ", diffsigma, "\n")
            olderr <- err

            # Re-fit with the new weights, starting from the current estimate.
            fitStepResult <- RunFitStep(costFunctions$cost, costForExtraSolver, useExtraSolver, fit$par, lower, upper, control)
            fit <- fitStepResult$fit
            fitted_with_extra_solver <- fitStepResult$fitted_with_extra_solver

            niter <- niter + 1
        }

        if (fitted_with_extra_solver) {
            # Read the extra solver's results (from FF) into fit.
            fit <- GetFitValuesAfterExtraSolver(fit, FF)
            # solnp can return an incorrect Hessian, so fit again with modFit
            # (L-BFGS-B) from the optimised point to obtain one. The modFit
            # result replaces `fit` entirely, so nothing needs to be patched onto
            # `fit` beforehand.
            fit <- modFit(costFunctions$cost, fit$par, lower = lower, upper = upper, method = 'L-BFGS-B', control = list())
        }

        # Record the final per-observation error on the observed data.
        finalErr <- sqrt(fit$var_ms_unweighted)
        observed$err <- finalErr[as.character(observed$name)]

        return(list(fit = fit, fitted_with_extra_solver = fitted_with_extra_solver, observed = observed, res = NULL))
    })
}

# Returns the IRLS-specific wrap-up step used by the generic CAKE fitting code.
#
# The returned function takes the final fit plus the named extras parms.fixed,
# state.ini and observed. It returns a list with:
#   fit:       the fit, classed as c("CakeFit", "mkinfit", "modFit")
#   parms.all: fitted parameters, then fixed parameters, then initial state values
#   data:      the observed data, unchanged
GetIrlsSpecificWrapUp <- function() {
    wrapUp <- function(fit, ...) {
        extras <- list(...)
        # Gather every parameter value in the order: fitted, fixed, initial state.
        # This is done before re-classing so that `fit$par` is read from the plain object.
        allParameters <- c(fit$par, extras$parms.fixed, extras$state.ini)
        class(fit) <- c("CakeFit", "mkinfit", "modFit")
        list(fit = fit, parms.all = allParameters, data = extras$observed)
    }
    wrapUp
}

Contact - Imprint