summaryrefslogblamecommitdiff
path: root/CakeCost.R
blob: 110275ed2c18da0017339340d90a14b6115bc4cd (plain) (tree)
1
2
3
4
5
6
7
8
9


                                                                                
                                               


                                                                      

                                                                             
                                                           
 
                                               










                                                                               
                                                           


                                                                           



                                                                               




































































































                                                                                                           
                       















































                                                                                               
    













                                                                                                                





                                          
                                                     



                                             







                                                       



                                           










































                                                                                                       

                                                                                          

                                                                                       


                       



















































































                                                                                                           
                                                                                                                     










                                                                     
 
## -----------------------------------------------------------------------------
## The model cost and residuals
## -----------------------------------------------------------------------------
# Some of the CAKE R modules are based on mkin.
# Call to approx is only performed if there are multiple non NA values
# which should prevent most of the crashes - Rob Nelson (Tessella)
#
# Modifications developed by Hybrid Intelligence (formerly Tessella), part of
# Capgemini Engineering, for Syngenta, Copyright (C) 2011-2022 Syngenta
# Tessella Project Reference: 6245, 7247, 8361, 7414, 10091
#
# The CAKE R modules are free software: you can
# redistribute them and/or modify them under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/> 

## Compute the (weighted) cost and residuals of a model run against observed
## data.
##
## Arguments:
##   model    - model output: a matrix or data.frame (or a named vector for a
##              single row) with the column named by 'x' and one column per
##              observed variable.
##   obs      - observed data. Long format when 'y' is given (first column
##              holds the variable names, column 'y' the values); otherwise
##              wide format with one column per observed variable.
##   x        - name of the independent-variable column, or NULL.
##   y        - name of the value column for long-format 'obs', or NULL.
##   err      - name of the column in 'obs' holding error estimates, or NULL.
##   weight   - weighting used when 'err' is NULL: "none", "std" or "mean".
##   scaleVar - if TRUE, each variable's residuals are scaled by 1/N.
##   cost     - optional earlier "modCost" result to be added to this one.
##   ...      - ignored.
##
## Returns an object of class "modCost": a list with 'model' and 'cost' (both
## the total cost), 'minlogp' (minus the log-likelihood of the residuals under
## normal errors), 'var' (per-variable summary) and 'residuals' (per point).
## If the model output does not contain every observed 'x' value, residuals
## and cost are set to Inf so that this parameter set is never selected.
CakeCost <- function (model, obs, x = "time", y = NULL, err = NULL,
                     weight = "none", scaleVar = FALSE, cost = NULL, ...) {
  ## convert vector to matrix
  if (is.vector(obs)) {
    cn <- names(obs)
    obs   <- matrix(data = obs, nrow = 1)
    colnames(obs) <- cn
  }
  if (is.vector(model)) {
    cn <- names(model)
    model <- matrix(data=model, nrow = 1)
    colnames(model) <- cn
  }

  ##=============================================================================
  ## Observations
  ##=============================================================================

  ## The position of independent variable(s)
  ix <- 0
  if (! is.null(x))  {   # mapping required...
    ## For now multiple independent variables are not supported...
    if (length(x) > 1)
      stop ("multiple independent variables in 'obs' are not yet supported")

    if (! is.character(x))
      stop ("'x' should be the *name* of the column with the independent variable in 'obs' or NULL")
    ix  <- which(colnames(obs) %in% x)
    if (length(ix) != length(x))
      stop(paste("Independent variable column not found in observations", x))
  } else ix <- NULL

  ## The position of weighing values
  ierr <- 0
  if (! is.null(err)) {
    if (! is.character(err))
      stop ("'err' should be the *name* of the column with the error estimates in obs or NULL")
    ierr <- which(colnames(obs) == err)    # only one
    if (length(ierr) == 0)
      stop(paste("Column with error estimates not found in observations", err))
  }

  ## The dependent variables
  type <- 1           # data input type: type 2 is table format, type 1 is long format...

  if (!is.null(y)) {   # it is in table format; first column are names of observed data...

    Names    <- as.character(unique(obs[, 1]))  # Names of data sets, all data should be model variables...
    Ndat     <- length(Names)                   # Number of data sets
    ilist    <- 1:Ndat
    if (! is.character(y))
      stop ("'y' should be the *name* of the column with the values of the dependent variable in obs")
    iy  <- which(colnames(obs) == y)
    if (length(iy) == 0)
      stop(paste("Column with value of dependent variable not found in observations", y))
    type <- 2

  } else  {             # it is a matrix, variable names are column names
    Ndat     <- NCOL(obs)-1
    Names    <- colnames(obs)
    ilist    <- (1:NCOL(obs))        # column positions of the (dependent) observed variables
    exclude  <- ix                   # exclude columns that are not
    if (ierr > 0)
      exclude <- c(ix, ierr)          # exclude columns that are not
    if (length(exclude) > 0)
      ilist <- ilist[-exclude]
  }

  #================================
  # The model results
  #================================

  ModNames <- colnames(model)  # Names of model variables
  if (length(ix) > 1) {
    ixMod <- NULL

    for ( i in 1:length(ix)) {
      ix2 <- which(colnames(model) == x[i])
      if (length(ix2) == 0)
        stop(paste("Cannot calculate cost: independent variable not found in model output", x[i]))
      ixMod <- c(ixMod, ix2)
    }

    xMod     <- model[,ixMod]    # Independent variable, model
  } else if (length(ix) == 1) {
    ixMod    <- which(colnames(model) == x)
    if (length(ixMod) == 0)
      stop(paste("Cannot calculate cost: independent variable not found in model output", x))
    xMod     <- model[,ixMod]    # Independent variable, model
  }

  ## Sometimes a fit is encountered for which the model is unable to calculate
  ## values on the full range of observed values. In this case, we will return
  ## an infinite cost to ensure this value is not selected.
  ## Columns are selected with [, x] so that matrices and data.frames behave
  ## the same (obs[x] on a matrix does not select a column). The 'x' column is
  ## known to exist in both 'obs' and 'model' at this point.
  modelCalculatedFully <- is.null(x) || all(obs[, x] %in% model[, x])

  Residual <- NULL
  CostVar  <- NULL

  #================================
  # Compare model and data...
  #================================
  xDat <- 0
  iDat <- 1:nrow(obs)

  for (i in ilist) {   # for each observed variable ...
    ii <- which(ModNames == Names[i])
    if (length(ii) == 0) stop(paste("observed variable not found in model output", Names[i]))
    yMod <- model[, ii]
    if (type == 2)  {  # table format
      iDat <- which (obs[,1] == Names[i])
      if (length(ix) > 0) xDat <- obs[iDat, ix]
      obsdat <- obs[iDat, iy]
    } else {
      ## Wide format: the independent variable is column 'ix', which is not
      ## necessarily the first column.
      if (length(ix) > 0) xDat <- obs[, ix]
      obsdat <- obs[,i]
    }
    ## Error estimates, row-aligned with obsdat so that they are filtered
    ## together with the observations below.
    if (ierr > 0) errDat <- obs[iDat, ierr]

    ## Drop missing observations, with their x values and error estimates.
    ii <- which(is.na(obsdat))
    if (length(ii) > 0) {
      xDat   <- xDat[-ii]
      obsdat <- obsdat[-ii]
      if (ierr > 0) errDat <- errDat[-ii]
    }

    # CAKE version - Added tests for multiple non-NA values
    if (length(ix) > 0 && length(unique(xMod[!is.na(xMod)])) > 1 && length(yMod[!is.na(yMod)]) > 1) {
      ModVar <- approx(xMod, yMod, xout = xDat)$y
    } else {
      cat("CakeCost Warning: Only one valid point - using mean (yMod was", yMod, ")\n")
      ModVar <- mean(yMod[!is.na(yMod)])
      obsdat <- mean(obsdat)
    }
    ## Keep only the points at which the model could be evaluated.
    iex <- which(!is.na(ModVar))
    ModVar <- ModVar[iex]
    obsdat <- obsdat[iex]
    xDat   <- xDat[iex]
    if (ierr > 0) {
      Err <- errDat[iex]
    } else {
      if (weight == "std")
        Err <- sd(obsdat)
      else if (weight == "mean")
        Err <- mean(abs(obsdat))
      else if (weight == "none")
        Err <- 1
      else
        stop("error: do not recognize 'weight'; should be one of 'none', 'std', 'mean'")
    }
    if (any(is.na(Err)))
      stop(paste("error: cannot estimate weighing for observed variable: ", Names[i]))
    if (min(Err) <= 0)
      stop(paste("error: weighing for observed variable is 0 or negative:", Names[i]))
    if (scaleVar)
      Scale <- 1/length(obsdat)
    else Scale <- 1

    if (!modelCalculatedFully) { # In this case, the model is unable to predict on the full range, set cost to Inf
      xDat <- 0
      obsdat <- 0
      ModVar <- Inf
      Res <- Inf
      res <- Inf
      weight_for_residual <- Inf
    } else {
      Res <- (ModVar- obsdat)
      res <- Res / Err
      weight_for_residual <- 1 / Err
    }

    resScaled <- res * Scale
    Residual <- rbind(Residual,
                      data.frame(
                        name   = Names[i],
                        x      = xDat,
                        obs    = obsdat,
                        mod    = ModVar,
                        weight = weight_for_residual,
                        res.unweighted = Res,
                        res    = res))

    CostVar <- rbind(CostVar,
                  data.frame(
                    name           = Names[i],
                    scale          = Scale,
                    N              = length(Res),
                    SSR.unweighted = sum(Res^2),
                    SSR.unscaled   = sum(res^2),
                    SSR            = sum(resScaled^2)))

  }  # end loop over all observed variables

  ## SSR
  ## NOTE(review): SSR already contains Scale^2 (via resScaled), so multiplying
  ## by scale again gives Scale^3 per variable when scaleVar = TRUE - confirm
  ## this is intended.
  Cost  <- sum(CostVar$SSR * CostVar$scale)
  ## Minus log-likelihood assuming normally distributed residuals. Each row
  ## uses its own standard deviation Err, recovered from its weight (1/Err).
  ## dnorm() is never negative; a zero density gives log(0) = -Inf, i.e. an
  ## infinite minlogp (as for the Inf rows above).
  Lprob <- -sum(log(pmax(0, dnorm(Residual$mod, Residual$obs, 1 / Residual$weight)))) # avoid log of negative values
  #Lprob <- -sum(log(pmax(.Machine$double.xmin, dnorm(Residual$mod, Residual$obs, Err)))) #avoid log(0)

  if (! is.null(cost)) {
    Cost     <- Cost + cost$model
    CostVar  <- rbind(CostVar, cost$var)
    Residual <- rbind(Residual, cost$residuals)
    Lprob    <- Lprob + cost$minlogp
  }
  out <- list(model = Cost, cost = Cost, minlogp = Lprob, var = CostVar, residuals = Residual)
  class(out) <- "modCost"
  return(out)
}

## -----------------------------------------------------------------------------
## S3 methods of modCost
## -----------------------------------------------------------------------------

## Plot the weighted residuals of a "modCost" object against the independent
## variable, using one symbol and colour per observed variable, with an
## optional legend.
##
## Arguments:
##   x      - a "modCost" object (uses x$var$name and x$residuals$name/x/res).
##   legpos - legend position passed to legend(); NA suppresses the legend.
##   ...    - further graphical arguments passed to plot(); 'pch' and 'col',
##            if given, are indexed per observed variable (order of x$var).
plot.modCost <- function(x, legpos = "topleft", ...) {
  nvar <- nrow(x$var)

  dots <- list(...)

  dots$xlab <- if (is.null(dots$xlab)) "x" else dots$xlab
  dots$ylab <- if (is.null(dots$ylab)) "weighted residuals" else dots$ylab

  ## Position of each residual's variable within x$var, used to pick its
  ## symbol and colour. Indexing directly by the name column only works when
  ## it is a factor; since R 4.0 data.frame() keeps strings as character, and
  ## indexing a vector by an unknown character name yields NA (nothing drawn).
  varIndex <- match(as.character(x$residuals$name), as.character(x$var$name))

  DotsPch  <- if (is.null(dots$pch)) (16:24) else dots$pch
  DotsCol  <- if (is.null(dots$col)) (1:nvar) else dots$col
  dots$pch <- DotsPch[varIndex]
  dots$col <- DotsCol[varIndex]

  do.call("plot", c(alist(x$residuals$x, x$residuals$res), dots))

  if (! is.na(legpos))
    legend(legpos, legend = x$var$name, col = DotsCol, pch = DotsPch)
}

## -----------------------------------------------------------------------------
## Internal cost function for optimisers
## -----------------------------------------------------------------------------
# Cost function. The returned structure must have $model
# We need to preserve state between calls so make a closure
## Build the cost function used by the optimisers, together with accessors
## for the state it keeps between calls (number of calls, best cost so far,
## last prediction). The cost function returns the "modCost" object from
## CakeCost with $penalties added and $model = cost + penalties.
##
## Arguments (as used below):
##   mkinmod                  - kinetic model definition (its $map and $diffs
##                              are read for the analytical solution).
##   state.ini.optim          - initial values being optimised; they occupy the
##                              first length(state.ini.optim) entries of P.
##   state.ini.optim.boxnames - names given to those optimised initial values.
##   state.ini.fixed          - fixed initial values.
##   parms.fixed              - fixed kinetic parameters.
##   observed                 - observed data in long format ('time' and
##                              'value' columns are used).
##   mkindiff                 - derivative function passed to ode() when
##                              solution == "deSolve".
##   quiet                    - if FALSE, report each improvement in cost.
##   atol                     - absolute tolerance for ode() and post-processing.
##   solution                 - "analytical" (closed form for the parent only)
##                              or "deSolve".
##   err                      - name of the error column passed to CakeCost.
# Cost function. The returned structure must have $model
# We need to preserve state between calls so make a closure
CakeInternalCostFunctions <- function(mkinmod, state.ini.optim, state.ini.optim.boxnames, 
                                    state.ini.fixed, parms.fixed, observed, mkindiff,  
                                    quiet, atol=1e-6, solution="deSolve", err="err"){
    cost.old <- 1e+100
    calls <- 0
    out_predicted <- NA

    get.predicted <- function(){ out_predicted }

    get.best.cost <- function(){ cost.old }
    reset.best.cost <- function() { cost.old <<- 1e+100 }

    get.calls <- function(){ calls }
    set.calls <- function(newcalls){ calls <<- newcalls }

    # Replaces the 'err' column of the observations held by this closure.
    set.error <- function(err) { observed$err <<- err }

    # Number of leading entries of P that are optimised initial values.
    nIni <- length(state.ini.optim)

    # The called cost function
    cost <- function(P) {
        assign("calls", calls + 1, inherits = TRUE)
        print(P)

        if (nIni > 0) {
            odeini <- c(P[1:nIni], state.ini.fixed)
            names(odeini) <- c(state.ini.optim.boxnames, names(state.ini.fixed))
        } else {
            odeini <- state.ini.fixed
        }

        # The remaining entries of P are kinetic parameters. Selecting with
        # seq_along() keeps this empty when every entry of P is an initial
        # value; (nIni + 1):length(P) would count down and pick up NA and
        # a duplicate of the last initial value.
        odeparms <- c(P[seq_along(P) > nIni], parms.fixed)

        # Ensure initial state is at time 0
        outtimes <- unique(c(0, observed$time))

        odeini <- AdjustOdeInitialValues(odeini, mkinmod, odeparms)

        if (solution == "analytical") {
            # Look up a parameter or initial value by name.
            evalparse <- function(string) {
                eval(parse(text = string), as.list(c(odeparms, odeini)))
            }

            parent.type <- names(mkinmod$map[[1]])[1]
            parent.name <- names(mkinmod$diffs)[[1]]
            o <- switch(parent.type,
                        SFO = SFO.solution(outtimes, 
                                           evalparse(parent.name),
                                           evalparse(paste("k", parent.name, sep="_"))),
                        FOMC = FOMC.solution(outtimes,
                                             evalparse(parent.name),
                                             evalparse("alpha"), evalparse("beta")),
                        DFOP = DFOP.solution(outtimes,
                                             evalparse(parent.name),
                                             evalparse(paste("k1", parent.name, sep="_")), 
                                             evalparse(paste("k2", parent.name, sep="_")),
                                             evalparse(paste("g", parent.name, sep="_"))),
                        HS = HS.solution(outtimes,
                                         evalparse(parent.name),
                                         evalparse("k1"), evalparse("k2"),
                                         evalparse("tb")),
                        IORE = IORE.solution(outtimes,
                                             evalparse(parent.name),
                                             evalparse(paste("k", parent.name, sep="_")),
                                             evalparse("N")))
            if (is.null(o))
                stop(paste("No analytical solution available for parent model type", parent.type))

            out <- cbind(outtimes, o)
            dimnames(out) <- list(outtimes, c("time", sub("_free", "", parent.name)))
        } else if (solution == "deSolve") {
            out <- ode(y = odeini, times = outtimes, func = mkindiff, parms = odeparms, atol = atol)
        } else {
            stop(paste("Unknown solution method:", solution))
        }

        out_transformed <- PostProcessOdeOutput(out, mkinmod, atol)

        assign("out_predicted", out_transformed, inherits = TRUE)
        mC <- CakeCost(out_transformed, observed, y = "value",  err = err, scaleVar = FALSE)
        mC$penalties <- CakePenalties(odeparms, out_transformed, observed)
        mC$model <- mC$cost + mC$penalties

        if (mC$model < cost.old) {
            if (!quiet) {
                cat("Model cost at call ", calls, ": m", mC$cost, 'p:', mC$penalties, 'o:', mC$model, "\n")
            }

            assign("cost.old", mC$model, inherits = TRUE)
        }

        # HACK to make nls.lm respect the penalty, as it just uses residuals and ignores the cost
        if (mC$penalties > 0) {
            mC$residuals$res <- mC$residuals$res + (sign(mC$residuals$res) * mC$penalties / length(mC$residuals$res))
        }

        return(mC)
    }

    list(cost=cost, 
        get.predicted=get.predicted,
        get.calls=get.calls, set.calls=set.calls,
        get.best.cost=get.best.cost, reset.best.cost=reset.best.cost,
        set.error=set.error
    )
}

Contact - Imprint