diff options
author | Johannes Ranke <jranke@uni-bremen.de> | 2019-06-05 17:03:19 +0200 |
---|---|---|
committer | Johannes Ranke <jranke@uni-bremen.de> | 2019-06-05 17:03:19 +0200 |
commit | 5243ac0ebe3223e5803b4ebee1cb619008638785 (patch) | |
tree | 6be4e8e47f8735f4bbce23d8abf328a700e028d8 /R | |
parent | b061a00ba4f9410a4b77e58a96fb7baa1247252b (diff) | |
parent | 4b323789265213bd65165873e7efe5e45a579275 (diff) |
Merge branch 'algorithm'
Diffstat (limited to 'R')
-rw-r--r-- | R/mkinfit.R | 207 |
1 files changed, 157 insertions, 50 deletions
diff --git a/R/mkinfit.R b/R/mkinfit.R index bc8b9d11..2af4e493 100644 --- a/R/mkinfit.R +++ b/R/mkinfit.R @@ -34,6 +34,8 @@ mkinfit <- function(mkinmod, observed, quiet = FALSE,
atol = 1e-8, rtol = 1e-10, n.outtimes = 100,
error_model = c("const", "obs", "tc"),
+ error_model_algorithm = c("d_3", "direct", "twostep", "threestep", "fourstep", "IRLS"),
+ reweight.tol = 1e-8, reweight.max.iter = 10,
trace_parms = FALSE,
...)
{
@@ -91,7 +93,7 @@ mkinfit <- function(mkinmod, observed, if (length(wrongpar.names) > 0) {
warning("Initial parameter(s) ", paste(wrongpar.names, collapse = ", "),
" not used in the model")
- parms.ini <- parms.ini[setdiff(names(parms.ini), wrongpar.names)]
+ parms.ini <- parms.ini[setdiff(names(parms.ini), wrongpar.names)]
}
# Warn that the sum of formation fractions may exceed one if they are not
@@ -244,6 +246,7 @@ mkinfit <- function(mkinmod, observed, # Get the error model
err_mod <- match.arg(error_model)
+ error_model_algorithm = match.arg(error_model_algorithm)
errparm_names <- switch(err_mod,
"const" = "sigma",
"obs" = paste0("sigma_", obs_vars),
@@ -276,34 +279,48 @@ mkinfit <- function(mkinmod, observed, length.out = n.outtimes))))
# Define log-likelihood function for optimisation, including (back)transformations
- nlogLik <- function(P, trans = TRUE, OLS = FALSE, local = FALSE, update_data = TRUE, ...)
+ nlogLik <- function(P, trans = TRUE, OLS = FALSE, fixed_degparms = FALSE, fixed_errparms = FALSE, update_data = TRUE, ...)
{
assign("calls", calls + 1, inherits = TRUE) # Increase the model solution counter
- P.orig <- P
# Trace parameter values if requested and if we are actually optimising
if(trace_parms & update_data) cat(P, "\n")
- # If we do a local optimisation of the error model, the initials
- # for the state variabels and the parameters are given as 'local'
- if (local[1] != FALSE) {
- P <- local
+ if (is.numeric(fixed_degparms)) {
+ degparms <- fixed_degparms
+ errparms <- P # This version of errparms is local to the function
+ degparms_fixed = TRUE
+ } else {
+ degparms_fixed = FALSE
+ }
+
+ if (is.numeric(fixed_errparms)) {
+ degparms <- P
+ errparms <- fixed_errparms # Local to the function
+ errparms_fixed = TRUE
+ } else {
+ errparms_fixed = FALSE
+ }
+
+ if (OLS) {
+ degparms <- P
+ }
+
+ if (!OLS & !degparms_fixed & !errparms_fixed) {
+ degparms <- P[1:(length(P) - length(errparms))]
+ errparms <- P[(length(degparms) + 1):length(P)]
}
# Initial states for t0
if(length(state.ini.optim) > 0) {
- odeini <- c(P[1:length(state.ini.optim)], state.ini.fixed)
+ odeini <- c(degparms[1:length(state.ini.optim)], state.ini.fixed)
names(odeini) <- c(state.ini.optim.boxnames, state.ini.fixed.boxnames)
} else {
odeini <- state.ini.fixed
names(odeini) <- state.ini.fixed.boxnames
}
- if (OLS | identical(P, local)) {
- odeparms.optim <- P[(length(state.ini.optim) + 1):length(P)]
- } else {
- odeparms.optim <- P[(length(state.ini.optim) + 1):(length(P) - length(errparms))]
- }
+ odeparms.optim <- degparms[(length(state.ini.optim) + 1):length(degparms)]
if (trans == TRUE) {
odeparms <- c(odeparms.optim, transparms.fixed)
@@ -322,23 +339,20 @@ mkinfit <- function(mkinmod, observed, method.ode = method.ode,
atol = atol, rtol = rtol, ...)
- # Get back the original parameter vector to get the error model parameters
- P <- P.orig
-
out_long <- mkin_wide_to_long(out, time = "time")
if (err_mod == "const") {
- observed$std <- P["sigma"]
+ observed$std <- errparms["sigma"]
}
if (err_mod == "obs") {
std_names <- paste0("sigma_", observed$name)
- observed$std <- P[std_names]
+ observed$std <- errparms[std_names]
}
if (err_mod == "tc") {
tmp <- merge(observed, out_long, by = c("time", "name"))
tmp$name <- ordered(tmp$name, levels = obs_vars)
tmp <- tmp[order(tmp$name, tmp$time), ]
- observed$std <- sqrt(P["sigma_low"]^2 + tmp$value.y^2 * P["rsd_high"]^2)
+ observed$std <- sqrt(errparms["sigma_low"]^2 + tmp$value.y^2 * errparms["rsd_high"]^2)
}
data_log_lik <- merge(observed[c("name", "time", "value", "std")], out_long,
@@ -351,7 +365,8 @@ mkinfit <- function(mkinmod, observed, sum(dnorm(x = value.observed, mean = value.predicted, sd = std, log = TRUE)))
}
- # We update the current likelihood and data during the optimisation, not during hessian calculations
+ # We update the current likelihood and data during the optimisation, not
+ # during hessian calculations
if (update_data) {
assign("out_predicted", out_long, inherits = TRUE)
@@ -421,42 +436,129 @@ mkinfit <- function(mkinmod, observed, # Show parameter names if tracing is requested
if(trace_parms) cat(names_optim, "\n")
+ # browser()
+
# Do the fit and take the time until the hessians are calculated
fit_time <- system.time({
- # In the inital run, we assume a constant standard deviation and do
- # not optimise it, as this is unnecessary and leads to instability of the
- # fit (most obvious when fitting the IORE model)
- if (!quiet) message("Ordinary least squares optimisation")
- parms.start <- c(state.ini.optim, transparms.optim)
- fit.ols <- nlminb(parms.start, nlogLik, control = control,
- lower = lower[names(parms.start)],
- upper = upper[names(parms.start)], OLS = TRUE, ...)
+ degparms <- c(state.ini.optim, transparms.optim)
if (err_mod == "const") {
+ if (!quiet) message("Ordinary least squares optimisation")
+ fit <- nlminb(degparms, nlogLik, control = control,
+ lower = lower[names(degparms)],
+ upper = upper[names(degparms)], OLS = TRUE, ...)
+ degparms <- fit$par
+
# Get the maximum likelihood estimate for sigma at the optimum parameter values
data_errmod$residual <- data_errmod$value.observed - data_errmod$value.predicted
- sigma_mle = sqrt(sum(data_errmod$residual^2)/nrow(data_errmod))
+ sigma_mle <- sqrt(sum(data_errmod$residual^2)/nrow(data_errmod))
- errparms = c(sigma = sigma_mle)
- fit <- fit.ols
- fit$logLik <- - nlogLik(c(fit$par, sigma = sigma_mle), OLS = FALSE)
- } else {
- # After the OLS stop we fit the error model with fixed degradation model
- # parameters
- if (!quiet) message("Optimising the error model")
- fit.err <- nlminb(errparms, nlogLik, control = control,
- lower = lower[names(errparms)],
- upper = upper[names(errparms)],
- local = fit.ols$par, ...)
- errparms.tmp <- fit.err$par
- if (!quiet) message("Optimising the complete model")
- parms.start <- c(fit.ols$par, errparms.tmp)
- fit <- nlminb(parms.start, nlogLik,
- lower = lower[names(parms.start)],
- upper = upper[names(parms.start)],
- control = control, ...)
+ errparms <- c(sigma = sigma_mle)
+ nlogLik.current <- nlogLik(c(degparms, errparms), OLS = FALSE)
fit$logLik <- - nlogLik.current
+ } else {
+ if (error_model_algorithm == "d_3") {
+ if (!quiet) message("Directly optimising the complete model")
+ parms.start <- c(degparms, errparms)
+ fit_direct <- nlminb(parms.start, nlogLik,
+ lower = lower[names(parms.start)],
+ upper = upper[names(parms.start)],
+ control = control, ...)
+ fit_direct$logLik <- - nlogLik.current
+ nlogLik.current <- Inf # reset to avoid conflict with the OLS step
+ }
+ if (error_model_algorithm != "direct") {
+ if (!quiet) message("Ordinary least squares optimisation")
+ fit <- nlminb(degparms, nlogLik, control = control,
+ lower = lower[names(degparms)],
+ upper = upper[names(degparms)], OLS = TRUE, ...)
+ degparms <- fit$par
+ # Get the maximum likelihood estimate for sigma at the optimum parameter values
+ data_errmod$residual <- data_errmod$value.observed - data_errmod$value.predicted
+ sigma_mle <- sqrt(sum(data_errmod$residual^2)/nrow(data_errmod))
+
+ nlogLik.current <- nlogLik(c(degparms, errparms), OLS = FALSE)
+ fit$logLik <- - nlogLik.current
+ }
+ if (error_model_algorithm %in% c("threestep", "fourstep", "d_3")) {
+ if (!quiet) message("Optimising the error model")
+ fit <- nlminb(errparms, nlogLik, control = control,
+ lower = lower[names(errparms)],
+ upper = upper[names(errparms)],
+ fixed_degparms = degparms, ...)
+ errparms <- fit$par
+ }
+ if (error_model_algorithm == "fourstep") {
+ if (!quiet) message("Optimising the degradation model")
+ fit <- nlminb(degparms, nlogLik, control = control,
+ lower = lower[names(degparms)],
+ upper = upper[names(degparms)],
+ fixed_errparms = errparms, ...)
+ degparms <- fit$par
+ }
+ if (error_model_algorithm %in% c("direct", "twostep", "threestep",
+ "fourstep", "d_3")) {
+ if (!quiet) message("Optimising the complete model")
+ parms.start <- c(degparms, errparms)
+ fit <- nlminb(parms.start, nlogLik,
+ lower = lower[names(parms.start)],
+ upper = upper[names(parms.start)],
+ control = control, ...)
+ fit$logLik <- - nlogLik.current
+
+ d_3_messages = c(
+ same = "Direct fitting and three-step fitting yield approximately the same likelihood",
+ threestep = "Three-step fitting yielded a higher likelihood than direct fitting",
+ direct = "Direct fitting yielded a higher likelihood than three-step fitting")
+ if (error_model_algorithm == "d_3") {
+ rel_diff <- abs((fit_direct$logLik - fit$logLik))/-mean(c(fit_direct$logLik, fit$logLik))
+ if (rel_diff < 0.0001) {
+ if (!quiet) message(d_3_messages["same"])
+ fit$d_3_message <- d_3_messages["same"]
+ } else {
+ if (fit$logLik > fit_direct$logLik) {
+ if (!quiet) message(d_3_messages["threestep"])
+ fit$d_3_message <- d_3_messages["threestep"]
+ } else {
+ if (!quiet) message(d_3_messages["direct"])
+ fit <- fit_direct
+ fit$d_3_message <- d_3_messages["direct"]
+ }
+ }
+ }
+ }
+ if (err_mod != "const" & error_model_algorithm == "IRLS") {
+ reweight.diff <- 1
+ n.iter <- 0
+ errparms_last <- errparms
+
+ while (reweight.diff > reweight.tol &
+ n.iter < reweight.max.iter) {
+
+ if (!quiet) message("Optimising the error model")
+ fit <- nlminb(errparms, nlogLik, control = control,
+ lower = lower[names(errparms)],
+ upper = upper[names(errparms)],
+ fixed_degparms = degparms, ...)
+ errparms <- fit$par
+
+ if (!quiet) message("Optimising the degradation model")
+ fit <- nlminb(degparms, nlogLik, control = control,
+ lower = lower[names(degparms)],
+ upper = upper[names(degparms)],
+ fixed_errparms = errparms, ...)
+ degparms <- fit$par
+
+ reweight.diff <- dist(rbind(errparms, errparms_last))
+ errparms_last <- errparms
+
+ fit$par <- c(fit$par, errparms)
+ nlogLik.current <- nlogLik(c(degparms, errparms), OLS = FALSE)
+ fit$logLik <- - nlogLik.current
+ }
+ }
}
+ fit$error_model_algorithm <- error_model_algorithm
# We include the error model in the parameter uncertainty analysis, also
# for constant variance, to get a confidence interval for it
@@ -629,6 +731,7 @@ summary.mkinfit <- function(object, data = TRUE, distimes = TRUE, alpha = 0.05, solution_type = object$solution_type,
warning = object$warning,
use_of_ff = object$mkinmod$use_of_ff,
+ error_model_algorithm = object$error_model_algorithm,
df = c(p, rdf),
cov.unscaled = covar,
err_mod = object$err_mod,
@@ -667,8 +770,9 @@ summary.mkinfit <- function(object, data = TRUE, distimes = TRUE, alpha = 0.05, ep <- endpoints(object)
if (length(ep$ff) != 0)
ans$ff <- ep$ff
- if(distimes) ans$distimes <- ep$distimes
- if(length(ep$SFORB) != 0) ans$SFORB <- ep$SFORB
+ if (distimes) ans$distimes <- ep$distimes
+ if (length(ep$SFORB) != 0) ans$SFORB <- ep$SFORB
+ if (!is.null(object$d_3_message)) ans$d_3_message <- object$d_3_message
class(ans) <- c("summary.mkinfit", "summary.modFit")
return(ans)
}
@@ -698,12 +802,15 @@ print.summary.mkinfit <- function(x, digits = max(3, getOption("digits") - 3), . cat("\nFitted using", x$calls, "model solutions performed in", x$time[["elapsed"]], "s\n")
- cat("\nError model:\n")
+ cat("\nError model: ")
cat(switch(x$err_mod,
const = "Constant variance",
obs = "Variance unique to each observed variable",
tc = "Two-component variance function"), "\n")
+ cat("\nError model algorithm:", x$error_model_algorithm, "\n")
+ if (!is.null(x$d_3_message)) cat(x$d_3_message, "\n")
+
cat("\nStarting values for parameters to be optimised:\n")
print(x$start)
|