aboutsummaryrefslogtreecommitdiff
path: root/custom_lsoda_call_edited.patch
blob: a79cbbcd9bceb350705654bd02471630394e08c9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
--- a/R/mkinfit.R
+++ b/R/mkinfit.R
 
+  # Get native symbol before iterations info for speed
+  call_lsoda <- getNativeSymbolInfo("call_lsoda", PACKAGE = "deSolve")
+  if (solution_type == "deSolve" & use_compiled[1] != FALSE) {
+    mkinmod$diffs_address <- getNativeSymbolInfo("diffs",
+      PACKAGE = mkinmod$dll_info[["name"]])$address
+    mkinmod$initpar_address <- getNativeSymbolInfo("initpar",
+      PACKAGE = mkinmod$dll_info[["name"]])$address
+  }
+

@@ -610,7 +619,8 @@ mkinfit <- function(mkinmod, observed,
                          solution_type = solution_type,
                          use_compiled = use_compiled,
                          method.ode = method.ode,
-                         atol = atol, rtol = rtol, ...)
+                         atol = atol, rtol = rtol,
+                         call_lsoda = call_lsoda, ...)
 
--- a/R/mkinpredict.R
+++ b/R/mkinpredict.R

@@ -116,9 +114,10 @@ mkinpredict.mkinmod <- function(x,
   outtimes = seq(0, 120, by = 0.1),
   solution_type = "deSolve",
   use_compiled = "auto",
-  method.ode = "lsoda", atol = 1e-8, rtol = 1e-10, maxsteps = 20000,
+  method.ode = "lsoda", atol = 1e-8, rtol = 1e-10, maxsteps = 20000L,
   map_output = TRUE,
   na_stop = TRUE,
+  call_lsoda = NULL,
   ...)
 {
 
@@ -173,20 +172,80 @@ mkinpredict.mkinmod <- function(x,
   if (solution_type == "deSolve") {
     if (!is.null(x$cf) & use_compiled[1] != FALSE) {
 
-      out <- deSolve::ode(
-        y = odeini,
-        times = outtimes,
-        func = "diffs",
-        initfunc = "initpar",
-        dllname = if (is.null(x$dll_info)) inline::getDynLib(x$cf)[["name"]]
-          else x$dll_info[["name"]],
-        parms = odeparms[x$parms], # Order matters when using compiled models
-        method = method.ode,
-        atol = atol,
-        rtol = rtol,
-        maxsteps = maxsteps,
-        ...
+      # Prepare call to "call_lsoda"
+      # Simplified code from deSolve::lsoda() adapted to our use case
+      if (is.null(call_lsoda)) {
+        call_lsoda <- getNativeSymbolInfo("call_lsoda", PACKAGE = "deSolve")
+      }
+      if (is.null(x$diffs_address)) {
+        x$diffs_address <- getNativeSymbolInfo("diffs",
+          PACKAGE = x$dll_info[["name"]])$address
+        x$initpar_address <- getNativeSymbolInfo("initpar",
+          PACKAGE = x$dll_info[["name"]])$address
+      }
+      rwork <- vector("double", 20)
+      rwork[] <- 0.
+      rwork[6] <- max(abs(diff(outtimes)))
+
+      iwork <- vector("integer", 20)
+      iwork[] <- 0
+      iwork[6] <- maxsteps
+
+      n <- length(odeini)
+      lmat <- n^2 + 2     # from deSolve::lsoda(), for Jacobian type full, internal (2)
+      # hard-coded default values of lsoda():
+      maxordn <- 12L
+      maxords <- 5L
+      lrn <- 20 + n * (maxordn + 1) + 3 * n  # length in case non-stiff method
+      lrs <- 20 + n * (maxords + 1) + 3 * n + lmat        # length in case stiff method
+      lrw <- max(lrn, lrs)                       # actual length: max of both
+      liw <- 20 + n
+
+      on.exit(.C("unlock_solver", PACKAGE = "deSolve"))
+
+      out_raw <- .Call(call_lsoda,
+        as.double(odeini),   # y
+        as.double(outtimes), # times
+        x$diffs_address,     # derivfunc
+        as.double(odeparms[x$parms]), # parms
+        rtol, atol,
+        NULL, NULL, # rho, tcrit
+        NULL, # jacfunc
+        x$initpar_address, # initfunc
+        NULL, # eventfunc
+        0L,   # verbose
+        1L,   # iTask
+        as.double(rwork), # rWork
+        as.integer(iwork), # iWork
+        2L,     # jT full Jacobian calculated internally
+        0L,    # nOut
+        as.integer(lrw), # lRw
+        as.integer(liw), # lIw
+        1L, # Solver
+        NULL, # rootfunc
+        0L, as.double(0), 0L, # nRoot, Rpar, Ipar
+        0L, # Type
+        list(fmat = 0L, tmat = 0L, imat = 0L, ModelForc = NULL), # flist
+        list(), # elist
+        list(islag = 0L) # elag
       )
+
+      out <- t(out_raw)
+      colnames(out) <- c("time", mod_vars)
     } else {
       mkindiff <- function(t, state, parms) {
 

Contact - Imprint