Skip to content

Add optional StochasticAD extension for differentiating jump processes#596

Open
RomanSahakyan03 wants to merge 1 commit into
SciML:masterfrom
RomanSahakyan03:stochasticad-integration
Open

Add optional StochasticAD extension for differentiating jump processes#596
RomanSahakyan03 wants to merge 1 commit into
SciML:masterfrom
RomanSahakyan03:stochasticad-integration

Conversation

@RomanSahakyan03

@RomanSahakyan03 RomanSahakyan03 commented Jun 3, 2026

Copy link
Copy Markdown

Adds JumpProcessesStochasticADExt, an optional package extension (loaded only when StochasticAD and Distributions are present) that lets StochasticAD's derivative_estimate compute gradients of expectations over variable-rate jump processes.

Differentiating through the solver directly --
derivative_estimate(g -> solve(jprob(g), Tsit5())) -- does not compose: the adaptive OrdinaryDiffEq internals require Number-interface methods that StochasticTriple intentionally omits, and VR_Direct's ContinuousCallback rootfind is a boolean predicate on a triple. The extension instead provides a triple-generic, fixed-grid PDMP simulator (generic RK4 + per-channel Bernoulli with first-fire priority + multiplicative-select update) that composes by design.

  • src/JumpProcesses.jl: fixedgrid_simulate / fixedgrid_jump_observable stubs and exports (no StochasticAD code in src/)
  • ext/JumpProcessesStochasticADExt.jl: the simulator and a JumpProblem adapter
  • Project.toml: StochasticAD and Distributions as weakdeps/extension + compat
  • test/stochasticad_tests.jl: analytic-gradient and MCWF tests (registered)

  Adds JumpProcessesStochasticADExt, an optional package extension (loaded only when StochasticAD and Distributions are present) that lets StochasticAD's derivative_estimate compute gradients of expectations over variable-rate jump processes.

  Differentiating through the solver directly --
  derivative_estimate(g -> solve(jprob(g), Tsit5())) -- does not compose: the adaptive OrdinaryDiffEq internals require Number-interface methods that StochasticTriple intentionally omits, and VR_Direct's ContinuousCallback rootfind is a boolean predicate on a triple. The extension instead provides a triple-generic, fixed-grid PDMP simulator (generic RK4 + per-channel Bernoulli with first-fire priority + multiplicative-select update) that composes by design.

  - src/JumpProcesses.jl: fixedgrid_simulate / fixedgrid_jump_observable stubs
    and exports (no StochasticAD code in src/)
  - ext/JumpProcessesStochasticADExt.jl: the simulator and a JumpProblem adapter
  - Project.toml: StochasticAD and Distributions as weakdeps/extension + compat
  - test/stochasticad_tests.jl: analytic-gradient and MCWF tests (registered)

  Validated against an analytic benchmark (0.7 sigma) and the standalone Stage-2
  prototype gradient (agreement within ~1 sigma on both partials).
@isaacsas

isaacsas commented Jun 3, 2026

Copy link
Copy Markdown
Member

I think I'll let @ChrisRackauckas take a first pass at this as I don't know the ODE AD internals that well. I do have one concern though, which is that I'm not sure StochasticAD is really being maintained or developed anymore. Maybe Chris can comment on that.

@isaacsas

isaacsas commented Jun 3, 2026

Copy link
Copy Markdown
Member

In either case though thanks for the contribution!

@isaacsas

isaacsas commented Jun 3, 2026

Copy link
Copy Markdown
Member

More generally, if you are interested in getting sensitivity calculations into JumpProcesses and want to talk about that sometime please let me know -- I've been wanting to build that out and had it on my TODO.

@RomanSahakyan03

RomanSahakyan03 commented Jun 3, 2026

Copy link
Copy Markdown
Author

More generally, if you are interested in getting sensitivity calculations into JumpProcesses and want to talk about that sometime please let me know -- I've been wanting to build that out and had it on my TODO.

Thanks @isaacsas, that makes sense.

I'd definitely be interested in talking about sensitivity calculations for JumpProcesses. I'll wait for @ChrisRackauckas's first pass here as well, since this started from his suggestion and I want to make sure I'm aligned with the intended direction.

Happy to treat this PR as a prototype and discuss what a better long-term path should look like.

Comment on lines +8 to +9
# that `ForwardDiff.Dual` implements but `StochasticAD.StochasticTriple` omits
# (so triples can't even enter the ODE state), and (2) `VR_Direct` locates jump

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

implements but StochasticAD.StochasticTriple omits (so triples can't even enter the ODE state)

That doesn't seem accurate? It's a direct extension so the theory works out. What happens if you try it? Make an MWE. IIRC that already works with the package as-is.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I reduced this further. You were right that my original wording was too broad, especially because one(γ) in StochasticAD returns Float64, so my first example accidentally created a Vector{Float64} state.

But even with a scalar triple state, Tsit5() currently fails before any JumpProcesses code is involved:

using OrdinaryDiffEq
using StochasticAD

function loss_scalar_state_triple(γ)
    prob = ODEProblem(
        (u, p, t) -> -u,
        γ,
        (0.0, 1.0),
        nothing,
    )

    sol = solve(prob, Tsit5())
    return sol[end]
end

derivative_estimate(loss_scalar_state_triple, 2.0)

This errors during OrdinaryDiffEqCore.__init with:

ERROR: MethodError: no method matching StochasticAD.StochasticTriple{StochasticAD.Tag{…}, Float64, StochasticAD.PrunedFIsBackend.PrunedFIs{…}}(::Float64)
The type `StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(loss_param_triple), Float64}, Float64, StochasticAD.PrunedFIsBackend.PrunedFIs{Float64}}` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  StochasticAD.StochasticTriple{T, V, FIs}(::V, ::V, ::FIs) where {T, V, FIs<:StochasticAD.AbstractFIs{V}}
   @ StochasticAD C:\Users\Roman\.julia\packages\StochasticAD\Gef5s\src\stochastic_triple.jl:21
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:965
  (::Type{T})(::Static.StaticInteger) where T<:Real
   @ Static C:\Users\Roman\.julia\packages\Static\d7YOk\src\Static.jl:442
  ...

and the stacktrace goes through:

oneunit(x::StochasticAD.StochasticTriple{...})

So I’ll revise the comment to avoid saying that triples cannot enter the ODE state or that OrdinaryDiffEq assumes the full Number interface in general. The more precise issue is that this Tsit5 initialization path currently calls oneunit/scalar construction in a way that StochasticTriple does not support. Separately, the JumpProcesses variable-rate callback/rootfinding path has its own issue with predicates on triples.

Comment on lines +10 to +11
# times with a `ContinuousCallback` rootfind, i.e. a boolean predicate on a
# triple, which StochasticAD forbids by design. Both were established

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is required for a general variable-rate jump. So is this scoped only for fixed rate? If it's scoped for fixed rate, it can probably be simplified to just use tstops, which then would be a lot faster.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. Exact general variable-rate jumps need the integrated-hazard rootfind, so I should not describe this as exact VariableRateJump support. I’ll fix that wording.

The scope is not fixed-rate only. The target is state-dependent rates, e.g. my MCWF case where the rate depends on |c₂(t)|². Since StochasticAD cannot differentiate through the rootfind predicate, this uses a fixed-grid / tau-leap-style approximation instead.

For constant-rate jumps, I agree that tstops is the better path. I confirmed that pre-sampling N ~ Poisson(λT) and placing jumps with tstops differentiates correctly, so I’m going to start working for the constant rate one.

@ChrisRackauckas

Copy link
Copy Markdown
Member

It would be much easier to start with the jump-only case. Start with the direct method on purely constant rate jumps and SSA stepper and get that differentiating.

Comment on lines +48 to +49
Validity: this is an O(dt²)-per-step (τ-leap-style) approximation to the exact
continuous-time PDMP, accurate when `rateₖ · dt ≪ 1`.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this accuracy actually holds? This is differentiating as a tau-leap method, of which none are order 2. So this is likely only O(dt)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if it's only O(dt), then just euler the other terms? In which case, this is just a tau-leap effectively done in first reaction version.

@RomanSahakyan03

Copy link
Copy Markdown
Author

It would be much easier to start with the jump-only case. Start with the direct method on purely constant rate jumps and SSA stepper and get that differentiating.

Following your suggestion, I tried the constant-rate jump-only case with Direct + SSAStepper on a pure-birth example.

It does not differentiate as-is because DirectJumpAggregation stores propensities in a Vector{Float64} around fill_cur_rates in direct.jl:109, so a StochasticTriple rate hits Float64(::StochasticTriple).

I see two possible first steps:

  1. Make the direct aggregator generic over the rate type, so triples can pass through the existing SSA path.

  2. For purely constant-rate jumps, skip the SSA internals and pre-sample N ~ Poisson(λT), then place jumps with tstops. This differentiates correctly in the pure-birth case and recovers d/dλ E[N] = T.

@ChrisRackauckas

Copy link
Copy Markdown
Member

You cannot presample because the rate is determined by the previous set of values. You have to iterate that. The stochastic triples are made for handling such iteration.

Make the direct aggregator generic over the rate type, so triples can pass through the existing SSA path.

Yes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants