GPU accelerated stochastic simulation in JAX
The public simulation interface is exposed from ssa. A simulation is defined by:
- a stoichiometry matrix S with shape (n_species, n_reactions);
- a consumption matrix nu with shape (n_species, n_reactions);
- a propensity function propensity(x, params);
- a parameter PyTree params;
- an initial state x0 with shape (n_species, batch_size).
S[:, r] is the net state change caused by reaction r.
nu[:, r] gives the number of molecules consumed by reaction r; it is used to cap tau-leap reaction counts so that species counts do not become negative.
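To make the role of nu concrete, here is a minimal sketch of one way such a cap could be computed. The helper name cap_counts and the per-reaction capping rule are assumptions for illustration, not the library's actual implementation. In particular, this simple rule limits each reaction on its own and does not resolve several reactions competing for the same molecules.

```python
import jax.numpy as jnp


def cap_counts(k, x, nu):
    """Cap sampled reaction counts so no single reaction overdraws a species.

    Hypothetical helper (not part of ssa).
    k:  (n_reactions, batch) sampled reaction counts
    x:  (n_species, batch) current state
    nu: (n_species, n_reactions) molecules consumed per firing
    """
    big = jnp.iinfo(jnp.int32).max
    nu_b = nu[:, :, None]  # (n_species, n_reactions, 1)
    x_b = x[:, None, :]    # (n_species, 1, batch)
    # allowed[s, r, b]: how often reaction r may fire given species s alone;
    # species that reaction r does not consume impose no limit.
    allowed = jnp.where(nu_b > 0, x_b // jnp.maximum(nu_b, 1), big)
    k_max = allowed.min(axis=0)  # (n_reactions, batch)
    return jnp.minimum(k, k_max)


# Birth-death layout: reaction 0 consumes nothing, reaction 1 consumes one X.
nu = jnp.array([[0, 1]], dtype=jnp.int32)
x = jnp.array([[3, 0]], dtype=jnp.int32)          # two trajectories
k = jnp.array([[5, 5], [4, 2]], dtype=jnp.int32)  # proposed counts
capped = cap_counts(k, x, nu)
print(capped)
```

In this example the death counts are reduced to the available molecules (3 and 0), while the birth counts are left untouched because birth consumes nothing.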
For example, a one-species birth-death process,
∅ → X rate k_birth
X → ∅ rate k_death X
has two reactions and one species:
import jax.numpy as jnp
from jax import random
from ssa import make_problem, simulate, batched_state_from_single

def propensity(x, params):
    # x has shape (n_species, batch_size); the result has shape (n_reactions, batch_size)
    k_birth = params["k_birth"]
    k_death = params["k_death"]
    birth = k_birth * jnp.ones_like(x[0])
    death = k_death * x[0].astype(jnp.float32)
    return jnp.stack([birth, death], axis=0)

# Columns are reactions (birth, death); the single row is species X.
S = jnp.array([
    [1, -1],
], dtype=jnp.int32)

nu = jnp.array([
    [0, 1],
], dtype=jnp.int32)

params = {
    "k_birth": jnp.float32(5.0),
    "k_death": jnp.float32(0.2),
}

problem = make_problem(S=S, nu=nu, propensity=propensity, params=params)

The state is batched. A single initial condition can be replicated across many independent trajectories with batched_state_from_single:
x0 = batched_state_from_single(
    x0_single=jnp.array([0], dtype=jnp.int32),
    batch_size=1024,
)

key = random.PRNGKey(0)
t, X = simulate(
    problem,
    x0,
    dt=1e-2,
    T=10.0,
    K=8,
    key=key,
)

The returned array X has shape:
(n_saved_times, n_species, batch_size)
Only chunk endpoints are stored. K is the number of tau-leap steps per saved point, so saved times are spaced by K * dt.
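The relationship between dt, T, and K can be checked with a few lines of arithmetic. This is a sketch under two assumptions that are not confirmed by the library: the number of steps is T / dt rounded to the nearest integer, and a trailing partial chunk still produces a saved point. These assumptions reproduce the benchmark configuration quoted later (n_steps=500, K=32, n_chunks=16); whether the initial state is also stored is not specified here. The function name chunk_layout is hypothetical.

```python
import math


def chunk_layout(dt, T, K):
    """Return (n_steps, n_chunks, save_spacing) for a run of length T.

    Hypothetical helper; assumes n_steps = round(T / dt) and that a
    trailing partial chunk still yields a saved point.
    """
    n_steps = round(T / dt)            # total tau-leap steps
    n_chunks = math.ceil(n_steps / K)  # chunk endpoints, excluding any initial state
    save_spacing = K * dt              # time between consecutive saved points
    return n_steps, n_chunks, save_spacing


# Settings from the birth-death example above.
n_steps, n_chunks, spacing = chunk_layout(dt=1e-2, T=10.0, K=8)
print(n_steps, n_chunks, spacing)
```

A larger K reduces the memory needed for the output at the cost of coarser time resolution in X.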
A slightly larger example with two species and three reactions,
∅ → X rate k_x
X → Y rate k_xy X
Y → ∅ rate k_y Y
can be specified as:
def propensity(x, params):
    X = x[0].astype(jnp.float32)
    Y = x[1].astype(jnp.float32)
    return jnp.stack(
        [
            params["k_x"] * jnp.ones_like(X),
            params["k_xy"] * X,
            params["k_y"] * Y,
        ],
        axis=0,
    )

S = jnp.array(
    [
        [1, -1, 0],
        [0, 1, -1],
    ],
    dtype=jnp.int32,
)

nu = jnp.array(
    [
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=jnp.int32,
)

params = {
    "k_x": jnp.float32(10.0),
    "k_xy": jnp.float32(0.3),
    "k_y": jnp.float32(0.1),
}

problem = make_problem(S, nu, propensity, params)

Here the second reaction consumes one molecule of X and produces one molecule of Y, so its stoichiometry column is [-1, +1] and its consumption column is [1, 0].
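For networks like this one, where no species appears on both sides of the same reaction, the consumption matrix is simply the negative part of S. The short check below illustrates that relationship; it is a convenience sketch, not something the library is known to do. It would not hold for a catalytic reaction such as X + E → Y + E, where the net change of E is zero and the reactant is hidden from S, so nu would have to be written by hand.

```python
import jax.numpy as jnp

# Stoichiometry of the two-species example: rows are (X, Y),
# columns are (production of X, conversion X -> Y, degradation of Y).
S = jnp.array(
    [
        [1, -1, 0],
        [0, 1, -1],
    ],
    dtype=jnp.int32,
)

# Keep only the molecules removed by each reaction.
nu_from_S = jnp.maximum(-S, 0)
print(nu_from_S)
```

The result matches the nu written out by hand in the example, including the column [1, 0] for the conversion reaction.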
Built-in example models follow the same pattern. For instance, the repressilator model constructs an SSAProblem and then uses the same simulation call:
import jax.numpy as jnp
from jax import random
from ssa import batched_state_from_single, simulate
from ssa.models.repressilator import build_problem

params = {
    "km0": jnp.float32(0.2),
    "km": jnp.float32(20.0),
    "K": jnp.float32(40.0),
    "n": jnp.float32(2.0),
    "gm": jnp.float32(1.0),
    "kp": jnp.float32(5.0),
    "gp": jnp.float32(0.2),
    "Omega": jnp.float32(20.0),
}

problem = build_problem(params)

x0 = batched_state_from_single(
    jnp.array([2, 1, 3, 10, 4, 7], dtype=jnp.int32),
    batch_size=256,
)

t, X = simulate(
    problem,
    x0,
    dt=1e-2,
    T=100.0,
    K=4,
    key=random.PRNGKey(1),
)

Several examples are provided, including a repressilator-style synthetic oscillator and generic, parametrically defined gene-regulatory networks (GRNs). For single-cell simulations the GPU acceleration is particularly useful, because the expression of many networks (i.e., cells) at the same time points can be evaluated in parallel.
The gene-regulatory network (GRN) example first constructs a directed GRN with transcription factors, housekeeping genes, and target genes. Each gene is assigned basal expression and degradation rates, and each regulatory edge is assigned an activation or repression sign, a strength, a Hill coefficient, and a threshold. Each gene is represented by two reactions: mRNA production and degradation. Production propensities are modulated by Hill-type regulatory terms from upstream transcription factors, while degradation is linear in the current molecule count.
The branching version generates cells along a latent bifurcating pseudotime trajectory. Branch-specific transcription factors are activated according to each cell’s branch and pseudotime, producing structured, branch-dependent expression programs.
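The Hill-type regulatory terms described above can be sketched as follows. The function names, the single-regulator setup, and the additive combination of basal rate and regulated term are all assumptions made for illustration; how the repository actually combines multiple regulators is not stated here.

```python
import jax.numpy as jnp


def hill_activation(x, threshold, n):
    """Standard activating Hill term: 0 at x = 0, 0.5 at x = threshold."""
    xn = x ** n
    return xn / (threshold ** n + xn)


def hill_repression(x, threshold, n):
    """Repressing counterpart: 1 at x = 0, 0.5 at x = threshold."""
    return 1.0 - hill_activation(x, threshold, n)


def gene_propensities(m, tf, basal, strength, threshold, n, sign, degradation):
    """Production and degradation propensities for one gene, one regulator.

    Hypothetical helper. m and tf are float arrays of shape (batch,).
    Assumption: production = basal + strength * regulatory term.
    """
    reg = jnp.where(
        sign > 0,
        hill_activation(tf, threshold, n),
        hill_repression(tf, threshold, n),
    )
    production = basal + strength * reg
    decay = degradation * m  # linear in the current molecule count
    return jnp.stack([production, decay], axis=0)


a = gene_propensities(
    m=jnp.array([10.0]),
    tf=jnp.array([40.0]),  # regulator exactly at its threshold
    basal=1.0,
    strength=4.0,
    threshold=40.0,
    n=2.0,
    sign=1,
    degradation=0.2,
)
print(a)
```

With the regulator at its threshold, the Hill term is one half, so production here is the basal rate plus half the edge strength.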
Also included is a comparison between a pure JAX implementation and Pallas/Triton-specific code for mass-action models. In the run below, the Pallas version reaches roughly 1.4x the throughput of pure JAX:
Config: B=65536, K=32, B_TILE=256, dt=0.001, T=0.5, n_steps=500, n_chunks=16
jax   : 0.1330s   246.40 M tau-steps/s
pallas: 0.0961s   340.86 M tau-steps/s

This uses the tau-leaping variant of the Gillespie stochastic simulation algorithm (SSA). In the exact SSA, the time to the next reaction is sampled from an exponential distribution with rate equal to the total propensity, after which a reaction channel is selected with probability proportional to its propensity. In tau-leaping, one instead chooses a time step τ (dt in the simulate call) and, treating the propensities as constant over that step, samples the number of firings of each reaction channel from a Poisson distribution whose mean is the channel's propensity multiplied by τ.
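This can substantially accelerate simulation when high-propensity reactions fire many times over intervals in which the propensities vary only slowly, since the algorithm does not need to resolve each individual reaction event. The choice of the time step τ trades accuracy against speed: larger steps require fewer iterations, but the approximation degrades once the propensities change appreciably within a single step.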
An advantage of this formulation is that it admits a straightforward vectorised implementation. Reaction counts can be sampled for all reaction channels at once, and the species updates can be applied through the stoichiometry matrix. Consequently, multiple independent systems with a common time step τ can be advanced together in a single batched update.
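The vectorised update described above can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the library's kernel: the function name tau_leap_step is hypothetical, and negative counts are handled by crude clipping at zero rather than by the nu-based cap the library uses.

```python
import jax.numpy as jnp
from jax import random


def propensity(x, params):
    # Birth-death propensities, as in the first example.
    birth = params["k_birth"] * jnp.ones_like(x[0])
    death = params["k_death"] * x[0].astype(jnp.float32)
    return jnp.stack([birth, death], axis=0)


def tau_leap_step(key, x, S, propensity, params, dt):
    """One tau-leap step for a whole batch (hypothetical helper).

    x: (n_species, batch) integer state; S: (n_species, n_reactions).
    """
    a = propensity(x, params)        # (n_reactions, batch)
    k = random.poisson(key, a * dt)  # reaction counts for every channel at once
    x_new = x + S @ k                # apply all updates through S
    return jnp.maximum(x_new, 0)     # crude guard; the library caps with nu instead


S = jnp.array([[1, -1]], dtype=jnp.int32)
params = {"k_birth": jnp.float32(5.0), "k_death": jnp.float32(0.2)}

key = random.PRNGKey(0)
x = jnp.zeros((1, 4), dtype=jnp.int32)  # four independent trajectories
for _ in range(50):
    key, sub = random.split(key)
    x = tau_leap_step(sub, x, S, propensity, params, dt=1e-2)
print(x)
```

Every trajectory in the batch shares the same dt, so one Poisson draw and one matrix product advance all of them at once. A production loop would typically use jax.lax.scan or jax.lax.fori_loop under jit rather than a Python loop.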
Another important consideration is that the Poisson variates can be generated in large batched operations once the propensities for a time step have been computed, rather than being interleaved with trajectory-specific event selection and state updates. This is particularly useful when sampling many independent trajectories from the same system, for example to estimate uncertainty or approximate distributions over possible dynamical outcomes.
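Once a batch of trajectories is available, ensemble statistics reduce to reductions over the batch axis of X. The sketch below uses a stand-in array with the documented layout (n_saved_times, n_species, batch_size) instead of real simulate output, so the numbers are illustrative only. The stand-in is drawn from Poisson(25), the known stationary distribution of the birth-death example (k_birth / k_death = 25). The function name ensemble_summary is hypothetical.

```python
import jax.numpy as jnp
from jax import random


def ensemble_summary(X):
    """Per-time, per-species mean and variance across trajectories.

    X: (n_saved_times, n_species, batch_size) integer counts.
    """
    Xf = X.astype(jnp.float32)
    mean = Xf.mean(axis=-1)  # (n_saved_times, n_species)
    var = Xf.var(axis=-1)    # (n_saved_times, n_species)
    return mean, var


# Stand-in for simulate output, NOT real simulation results.
X = random.poisson(random.PRNGKey(0), jnp.float32(25.0), shape=(5, 1, 4096))
mean, var = ensemble_summary(X)
fano = var / mean  # close to 1 for Poisson-distributed counts
print(mean[:, 0], fano[:, 0])
```

With real output from simulate, the same reductions give time-resolved estimates of the mean and spread, and the full batch axis can be histogrammed to approximate the distribution at each saved time.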