Add optional StochasticAD extension for differentiating jump processes#596
Add optional StochasticAD extension for differentiating jump processes#596RomanSahakyan03 wants to merge 1 commit into
Conversation
Adds JumpProcessesStochasticADExt, an optional package extension (loaded only when StochasticAD and Distributions are present) that lets StochasticAD's derivative_estimate compute gradients of expectations over variable-rate jump processes.
Differentiating through the solver directly --
derivative_estimate(g -> solve(jprob(g), Tsit5())) -- does not compose: the adaptive OrdinaryDiffEq internals require Number-interface methods that StochasticTriple intentionally omits, and VR_Direct's ContinuousCallback rootfind is a boolean predicate on a triple. The extension instead provides a triple-generic, fixed-grid PDMP simulator (generic RK4 + per-channel Bernoulli with first-fire priority + multiplicative-select update) that composes by design.
- src/JumpProcesses.jl: fixedgrid_simulate / fixedgrid_jump_observable stubs
and exports (no StochasticAD code in src/)
- ext/JumpProcessesStochasticADExt.jl: the simulator and a JumpProblem adapter
- Project.toml: StochasticAD and Distributions as weakdeps/extension + compat
- test/stochasticad_tests.jl: analytic-gradient and MCWF tests (registered)
Validated against an analytic benchmark (0.7 sigma) and the standalone Stage-2
prototype gradient (agreement within ~1 sigma on both partials).
|
I think I'll let @ChrisRackauckas take a first pass at this as I don't know the ODE AD internals that well. I do have one concern though, which is that I'm not sure StochasticAD is really being maintained or developed anymore. Maybe Chris can comment on that. |
|
In either case though thanks for the contribution! |
|
More generally, if you are interested in getting sensitivity calculations into JumpProcesses and want to talk about that sometime please let me know -- I've been wanting to build that out and had it on my TODO. |
Thanks @isaacsas, that makes sense. I'd definitely be interested in talking about sensitivity calculations for JumpProcesses. I'll wait for @ChrisRackauckas's first pass here as well, since this started from his suggestion and I want to make sure I'm aligned with the intended direction. Happy to treat this PR as a prototype and discuss what a better long-term path should look like. |
| # that `ForwardDiff.Dual` implements but `StochasticAD.StochasticTriple` omits | ||
| # (so triples can't even enter the ODE state), and (2) `VR_Direct` locates jump |
There was a problem hiding this comment.
implements but
StochasticAD.StochasticTripleomits (so triples can't even enter the ODE state)
That doesn't seem accurate? It's a direct extension so the theory works out. What happens if you try it? Make an MWE. IIRC that already works with the package as-is.
There was a problem hiding this comment.
Thanks, I reduced this further. You were right that my original wording was too broad, especially because one(γ) in StochasticAD returns Float64, so my first example accidentally created a Vector{Float64} state.
But even with a scalar triple state, Tsit5() currently fails before any JumpProcesses code is involved:
using OrdinaryDiffEq
using StochasticAD
function loss_scalar_state_triple(γ)
prob = ODEProblem(
(u, p, t) -> -u,
γ,
(0.0, 1.0),
nothing,
)
sol = solve(prob, Tsit5())
return sol[end]
end
derivative_estimate(loss_scalar_state_triple, 2.0)This errors during OrdinaryDiffEqCore.__init with:
ERROR: MethodError: no method matching StochasticAD.StochasticTriple{StochasticAD.Tag{…}, Float64, StochasticAD.PrunedFIsBackend.PrunedFIs{…}}(::Float64)
The type `StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(loss_param_triple), Float64}, Float64, StochasticAD.PrunedFIsBackend.PrunedFIs{Float64}}` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
StochasticAD.StochasticTriple{T, V, FIs}(::V, ::V, ::FIs) where {T, V, FIs<:StochasticAD.AbstractFIs{V}}
@ StochasticAD C:\Users\Roman\.julia\packages\StochasticAD\Gef5s\src\stochastic_triple.jl:21
(::Type{T})(::T) where T<:Number
@ Core boot.jl:965
(::Type{T})(::Static.StaticInteger) where T<:Real
@ Static C:\Users\Roman\.julia\packages\Static\d7YOk\src\Static.jl:442
...and the stacktrace goes through:
oneunit(x::StochasticAD.StochasticTriple{...})So I’ll revise the comment to avoid saying that triples cannot enter the ODE state or that OrdinaryDiffEq assumes the full Number interface in general. The more precise issue is that this Tsit5 initialization path currently calls oneunit/scalar construction in a way that StochasticTriple does not support. Separately, the JumpProcesses variable-rate callback/rootfinding path has its own issue with predicates on triples.
| # times with a `ContinuousCallback` rootfind, i.e. a boolean predicate on a | ||
| # triple, which StochasticAD forbids by design. Both were established |
There was a problem hiding this comment.
That is required for a general variable-rate jump. So is this scoped only for fixed rate? If it's scoped for fixed rate, it can probably be simplified to just use tstops, which then would be a lot faster.
There was a problem hiding this comment.
You're right. Exact general variable-rate jumps need the integrated-hazard rootfind, so I should not describe this as exact VariableRateJump support. I’ll fix that wording.
The scope is not fixed-rate only. The target is state-dependent rates, e.g. my MCWF case where the rate depends on |c₂(t)|². Since StochasticAD cannot differentiate through the rootfind predicate, this uses a fixed-grid / tau-leap-style approximation instead.
For constant-rate jumps, I agree that tstops is the better path. I confirmed that pre-sampling N ~ Poisson(λT) and placing jumps with tstops differentiates correctly, so I’m going to start working for the constant rate one.
|
It would be much easier to start with the jump-only case. Start with the direct method on purely constant rate jumps and SSA stepper and get that differentiating. |
| Validity: this is an O(dt²)-per-step (τ-leap-style) approximation to the exact | ||
| continuous-time PDMP, accurate when `rateₖ · dt ≪ 1`. |
There was a problem hiding this comment.
I'm not sure this accuracy actually holds? This is differentiating as a tau-leap method, of which none are order 2. So this is likely only O(dt)
There was a problem hiding this comment.
And if it's only O(dt), then just euler the other terms? In which case, this is just a tau-leap effectively done in first reaction version.
Following your suggestion, I tried the constant-rate jump-only case with It does not differentiate as-is because I see two possible first steps:
|
|
You cannot presample because the rate is determined by the previous set of values. You have to iterate that. The stochastic triples are made for handling such iteration.
Yes |
Adds JumpProcessesStochasticADExt, an optional package extension (loaded only when StochasticAD and Distributions are present) that lets StochasticAD's derivative_estimate compute gradients of expectations over variable-rate jump processes.
Differentiating through the solver directly --
derivative_estimate(g -> solve(jprob(g), Tsit5())) -- does not compose: the adaptive OrdinaryDiffEq internals require Number-interface methods that StochasticTriple intentionally omits, and VR_Direct's ContinuousCallback rootfind is a boolean predicate on a triple. The extension instead provides a triple-generic, fixed-grid PDMP simulator (generic RK4 + per-channel Bernoulli with first-fire priority + multiplicative-select update) that composes by design.