
dpohanlon/ssa



GPU-accelerated stochastic simulation in JAX

Usage

The public simulation interface is exposed by the ssa package. A simulation is defined by:

  • a stoichiometry matrix S with shape (n_species, n_reactions);
  • a consumption matrix nu with shape (n_species, n_reactions);
  • a propensity function propensity(x, params) returning an array of shape (n_reactions, batch_size);
  • a parameter PyTree params;
  • an initial state x0 with shape (n_species, batch_size).

S[:, r] is the net state change caused by reaction r.

nu[:, r] gives the number of molecules of each species consumed by reaction r; it is used to cap tau-leap reaction counts so that species counts do not become negative.

For example, a one-species birth-death process,

∅ → X      rate k_birth
X → ∅      rate k_death X

has two reactions and one species:

import jax.numpy as jnp
from jax import random

from ssa import make_problem, simulate, batched_state_from_single

def propensity(x, params):
    k_birth = params["k_birth"]
    k_death = params["k_death"]

    birth = k_birth * jnp.ones_like(x[0])
    death = k_death * x[0].astype(jnp.float32)

    return jnp.stack([birth, death], axis=0)


S = jnp.array([
    [1, -1],
], dtype=jnp.int32)

nu = jnp.array([
    [0, 1],
], dtype=jnp.int32)

params = {
    "k_birth": jnp.float32(5.0),
    "k_death": jnp.float32(0.2),
}

problem = make_problem(S=S, nu=nu, propensity=propensity, params=params)

The state is batched. A single initial condition can be replicated across many independent trajectories with batched_state_from_single:

x0 = batched_state_from_single(
    x0_single=jnp.array([0], dtype=jnp.int32),
    batch_size=1024,
)

key = random.PRNGKey(0)

t, X = simulate(
    problem,
    x0,
    dt=1e-2,
    T=10.0,
    K=8,
    key=key,
)

The returned array X has shape:

(n_saved_times, n_species, batch_size)

Only chunk endpoints are stored. K is the number of tau-leap steps per saved point, so saved times are spaced by K * dt.
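With the settings above there are T / dt = 1000 leaps, grouped into 1000 / 8 = 125 chunks whose endpoints are 8 × 0.01 = 0.08 apart. The sketch below computes that grid together with the exact mean of the birth-death process, which gives a convenient reference curve for X. saved_time_grid and birth_death_mean are hypothetical helpers, not part of ssa; the grid assumes n_chunks = ceil(n_steps / K) (consistent with the benchmark configuration in the Examples section) and that t = 0 is not stored, so compare against the returned t for the actual convention.

```python
import math

import jax.numpy as jnp


def saved_time_grid(T, dt, K):
    """Hypothetical helper: times of the stored chunk endpoints.

    Assumes n_steps = round(T / dt), n_chunks = ceil(n_steps / K), and that
    the initial state at t = 0 is not among the saved points.
    """
    n_steps = int(round(T / dt))
    n_chunks = math.ceil(n_steps / K)
    return K * dt * jnp.arange(1, n_chunks + 1)


def birth_death_mean(t, k_birth, k_death, x0=0.0):
    """Exact mean of the birth-death process: solves dm/dt = k_birth - k_death * m."""
    m_inf = k_birth / k_death
    return m_inf + (x0 - m_inf) * jnp.exp(-k_death * t)


t_saved = saved_time_grid(T=10.0, dt=1e-2, K=8)  # 1000 steps -> 125 chunks
m = birth_death_mean(t_saved, k_birth=5.0, k_death=0.2)
print(t_saved.shape, float(t_saved[0]), float(m[-1]))
```

If those assumptions hold, X[:, 0, :].mean(axis=-1) from the simulate call above should follow m up to Monte Carlo and tau-leap error. Starting from zero molecules, the exact distribution at every time is Poisson with mean m(t), so the batch variance should also be close to m.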

A slightly larger example with two species and three reactions,

∅ → X        rate k_x
X → Y        rate k_xy X
Y → ∅        rate k_y Y

can be specified as:

def propensity(x, params):
    X = x[0].astype(jnp.float32)
    Y = x[1].astype(jnp.float32)

    return jnp.stack(
        [
            params["k_x"] * jnp.ones_like(X),
            params["k_xy"] * X,
            params["k_y"] * Y,
        ],
        axis=0,
    )


S = jnp.array(
    [
        [1, -1, 0],
        [0, 1, -1],
    ],
    dtype=jnp.int32,
)

nu = jnp.array(
    [
        [0, 1, 0],
        [0, 0, 1],
    ],
    dtype=jnp.int32,
)

params = {
    "k_x": jnp.float32(10.0),
    "k_xy": jnp.float32(0.3),
    "k_y": jnp.float32(0.1),
}

problem = make_problem(S, nu, propensity, params)

Here the second reaction consumes one molecule of X and produces one molecule of Y, so its stoichiometry column is [-1, +1] and its consumption column is [1, 0].
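Both matrices can be read off from a reactants → products description of each reaction: nu is the reactant (consumption) matrix and S is products minus reactants. The sketch below uses a hypothetical helper, stoichiometry_from_reactions (not part of ssa), to recover the S and nu used above for the three-reaction model.

```python
import jax.numpy as jnp


def stoichiometry_from_reactions(reactants, products):
    """Hypothetical helper: returns (S, nu) from reactant/product count matrices.

    Both inputs have shape (n_species, n_reactions). nu is the reactant matrix,
    and S is the net change, products - reactants.
    """
    R = jnp.asarray(reactants, dtype=jnp.int32)
    P = jnp.asarray(products, dtype=jnp.int32)
    return P - R, R


# Rows: X, Y. Columns: 0 -> X, X -> Y, Y -> 0.
reactants = [[0, 1, 0],
             [0, 0, 1]]
products = [[1, 0, 0],
            [0, 1, 0]]

S, nu = stoichiometry_from_reactions(reactants, products)
print(S)
print(nu)
```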

Built-in example models follow the same pattern. For instance, build_problem in ssa.models.repressilator constructs an SSAProblem from a parameter dictionary, which is then passed to the same simulate call:

import jax.numpy as jnp
from jax import random

from ssa import batched_state_from_single, simulate
from ssa.models.repressilator import build_problem


params = {
    "km0": jnp.float32(0.2),
    "km": jnp.float32(20.0),
    "K": jnp.float32(40.0),
    "n": jnp.float32(2.0),
    "gm": jnp.float32(1.0),
    "kp": jnp.float32(5.0),
    "gp": jnp.float32(0.2),
    "Omega": jnp.float32(20.0),
}

problem = build_problem(params)

x0 = batched_state_from_single(
    jnp.array([2, 1, 3, 10, 4, 7], dtype=jnp.int32),
    batch_size=256,
)

t, X = simulate(
    problem,
    x0,
    dt=1e-2,
    T=100.0,
    K=4,
    key=random.PRNGKey(1),
)

Examples

Several examples are provided, including simulation of a repressilator-style synthetic oscillator and of generic, parametrically defined gene-regulatory networks (GRNs). GPU acceleration is particularly useful in the single-cell case, because the expression of many networks (i.e., cells) at the same time points can be evaluated in parallel.

The gene-regulatory network example first constructs a directed GRN with transcription factors, housekeeping genes, and target genes. Each gene is assigned basal expression and degradation rates, and each regulatory edge is assigned an activation or repression sign, a strength, a Hill coefficient, and a threshold. Each gene is represented by two reactions: mRNA production and degradation. Production propensities are modulated by Hill-type regulatory terms from upstream transcription factors, while degradation is linear in the current molecule count. The branching variant generates cells along a latent bifurcating pseudotime trajectory: branch-specific transcription factors are activated according to each cell’s branch and pseudotime, producing structured, branch-dependent expression programs.
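A GRN of this kind maps directly onto the S / nu / propensity interface from the Usage section. The sketch below is one plausible way to write it, under stated assumptions: gene i's production rate is its basal rate plus a strength-weighted sum of Hill activation x^n / (K^n + x^n) or repression K^n / (K^n + x^n) terms from its regulators, and degradation is linear. The additive combination rule, the helper make_grn_problem_arrays, the parameter names (basal, degradation, sign, strength, threshold, hill) and the dense edge layout are all hypothetical; the package's GRN models may combine regulatory inputs differently (for example multiplicatively).

```python
import jax.numpy as jnp


def hill_activation(x, threshold, n):
    # Increasing Hill function x^n / (K^n + x^n), between 0 and 1.
    xn = x ** n
    return xn / (threshold ** n + xn)


def hill_repression(x, threshold, n):
    # Decreasing Hill function K^n / (K^n + x^n) = 1 - activation.
    return 1.0 - hill_activation(x, threshold, n)


def make_grn_problem_arrays(n_genes):
    """Hypothetical helper: S, nu and a propensity with two reactions per gene.

    Reactions 0..n_genes-1 produce mRNA of gene i; reactions n_genes..2*n_genes-1
    degrade it. Edge arrays are dense (n_genes, n_genes), indexed [regulator, target];
    sign is +1 (activation), -1 (repression) or 0 (no edge).
    """
    eye = jnp.eye(n_genes, dtype=jnp.int32)
    S = jnp.concatenate([eye, -eye], axis=1)                  # +1 production, -1 degradation
    nu = jnp.concatenate([jnp.zeros_like(eye), eye], axis=1)  # only degradation consumes mRNA

    def propensity(x, params):
        x = x.astype(jnp.float32)                             # (n_genes, batch)
        xr = x[:, None, :]                                    # (regulator, 1, batch)
        K = params["threshold"][:, :, None]                   # (regulator, target, 1)
        n = params["hill"][:, :, None]
        sign = params["sign"][:, :, None]
        term = jnp.where(sign > 0, hill_activation(xr, K, n),
                         jnp.where(sign < 0, hill_repression(xr, K, n), 0.0))
        # Assumed rule: basal rate plus strength-weighted sum over regulators.
        regulation = jnp.sum(params["strength"][:, :, None] * term, axis=0)
        production = params["basal"][:, None] + regulation    # (n_genes, batch)
        degradation = params["degradation"][:, None] * x       # linear in copy number
        return jnp.concatenate([production, degradation], axis=0)

    return S, nu, propensity


# Two genes: gene 0 activates gene 1.
S, nu, propensity = make_grn_problem_arrays(2)
params = {
    "basal": jnp.array([1.0, 0.5]),
    "degradation": jnp.array([0.1, 0.2]),
    "sign": jnp.array([[0, 1], [0, 0]]),
    "strength": jnp.array([[0.0, 4.0], [0.0, 0.0]]),
    "threshold": jnp.full((2, 2), 10.0),
    "hill": jnp.full((2, 2), 2.0),
}
x = jnp.array([[10, 0], [5, 5]], dtype=jnp.int32)  # two trajectories in the batch
print(propensity(x, params))
```

Because the edge terms are computed for the whole (regulator, target, batch) array at once, the propensity of every cell in the batch is evaluated in a single vectorised call, which is what the batched simulator relies on.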

Also included is a comparison between pure JAX and Pallas (Triton) kernel code; for mass-action models the Pallas version achieves a speed-up of roughly 1.4x in the configuration below:

Config: B=65536, K=32, B_TILE=256, dt=0.001, T=0.5, n_steps=500, n_chunks=16
jax   : 0.1330s   246.40 M tau-steps/s
pallas: 0.0961s   340.86 M tau-steps/s
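The reported throughputs are consistent with tau-steps/s = B × n_steps / wall time, i.e. one tau-step is one trajectory advanced by one leap; the small differences in the last digits presumably come from the printed timings being rounded. The arithmetic below is an interpretation of the printed figures, not code from the benchmark script. It also shows that n_chunks = ceil(n_steps / K) and that the measured speed-up in this configuration is about 1.38x.

```python
import math

# Values taken from the benchmark output above.
B, K, dt, T = 65536, 32, 1e-3, 0.5
n_steps = round(T / dt)            # 500 tau-leap steps per trajectory
n_chunks = math.ceil(n_steps / K)  # 500 / 32 = 15.6, rounded up to 16


def tau_steps_per_second(seconds):
    # Every one of the B trajectories performs n_steps leaps.
    return B * n_steps / seconds


jax_rate = tau_steps_per_second(0.1330)
pallas_rate = tau_steps_per_second(0.0961)
speedup = pallas_rate / jax_rate
print(f"{jax_rate / 1e6:.1f} M/s, {pallas_rate / 1e6:.1f} M/s, speed-up {speedup:.2f}x")
```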

Implementation

The package uses the tau-leaping variant of the Gillespie stochastic simulation algorithm. In the exact SSA, the time to the next reaction is sampled from an exponential distribution with rate equal to the total propensity, after which a reaction channel is selected with probability proportional to its propensity. In tau-leaping, one instead chooses a time step $\tau$ small enough that the propensities are approximately constant over the interval $[t,t+\tau]$. The number of firings of each reaction channel during this interval is then sampled from a Poisson distribution with mean equal to that reaction’s propensity multiplied by $\tau$, and the system state is updated using the corresponding stoichiometric changes.
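Written out, with $a_r(x)$ the r-th propensity returned by propensity(x, params), $a_0(x)$ their sum, and S the stoichiometry matrix from the Usage section ($a_0$ and $k_r$ are notation introduced here), the two update rules are:

```latex
% Exact SSA: one reaction event per iteration
\tau \sim \mathrm{Exp}\bigl(a_0(x)\bigr), \qquad
\Pr(r) = \frac{a_r(x)}{a_0(x)}, \qquad
x(t + \tau) = x(t) + S_{:,r}

% Tau-leaping: every channel fires independently over a fixed step \tau
k_r \sim \mathrm{Poisson}\bigl(a_r(x(t))\,\tau\bigr), \qquad
x(t + \tau) = x(t) + \sum_r S_{:,r}\,k_r = x(t) + S\,k
```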

This can substantially accelerate simulation when high-propensity reactions fire many times over intervals in which the propensities vary only slowly, since the algorithm does not need to resolve each individual reaction event. The choice of $\tau$ is therefore problem-dependent, and is often estimated adaptively to balance accuracy and computational efficiency. If $\tau$ is too large, the approximation may become inaccurate, and in discrete population systems it can also produce invalid negative copy numbers unless additional safeguards are used.
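To make the last point concrete: for the death reaction X → ∅ of the birth-death example, an uncapped leap removes Poisson(k_death · x · τ) molecules, so the probability of driving X negative in one step is at most P(N > x) (births in the same leap can only offset deaths). The sketch below evaluates that bound with jax.scipy.stats.poisson.cdf for a few step sizes; the values of x and τ are illustrative. For x = 10 and k_death = 0.2 the risk is negligible at τ = 0.01 but almost certain at τ = 10.

```python
import jax.numpy as jnp
from jax.scipy.stats import poisson

# Death reaction X -> 0 with propensity k_death * x, as in the birth-death example.
k_death = 0.2
x = 10                                    # current copy number of X
taus = jnp.array([0.01, 0.1, 1.0, 10.0])  # candidate leap sizes

# Deaths in one uncapped leap ~ Poisson(k_death * x * tau); ignoring births,
# X goes negative when more than x molecules are removed.
mean_deaths = k_death * x * taus
p_overdraw = 1.0 - poisson.cdf(x, mean_deaths)  # P(N > x) = 1 - P(N <= x)
print(p_overdraw)
```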

An advantage of this formulation is that it admits a straightforward vectorised implementation. Reaction counts can be sampled for all reaction channels at once, and the species updates can be applied through the stoichiometry matrix. Consequently, multiple independent systems with a common $\tau$ and a common number of time steps can be batched and executed in parallel on a GPU. This makes the state update more regular than in the exact SSA, since all trajectories perform the same broad sequence of operations at each step: propensity evaluation, Poisson sampling, and stoichiometric update. A shared $\tau$ also reduces trajectory-specific control flow and branch divergence, although this advantage may be weakened by adaptive step-size selection, rejection steps, or safeguards against negative copy numbers.
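This structure can be sketched in a few lines of JAX. It is a minimal, hypothetical sketch, not the ssa implementation: tau_leap_step and simulate_chunks are illustrative names, the capping rule (reaction r fires at most min_i ⌊x_i / nu[i, r]⌋ times over the species it consumes) is an assumed reading of how nu is used, and the final clip at zero is an extra safeguard. Propensity evaluation, Poisson sampling for every channel and trajectory, and the stoichiometric update are each one batched array operation; an inner lax.scan performs K leaps per chunk, and an outer lax.scan keeps only the chunk endpoints, mirroring the (n_saved_times, n_species, batch_size) layout described under Usage.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, random


def tau_leap_step(x, key, S, nu, propensity, params, dt):
    """One tau-leap step for all trajectories; x has shape (n_species, batch)."""
    a = propensity(x, params)                          # (n_reactions, batch)
    k = random.poisson(key, a * dt).astype(jnp.int32)  # every channel, every trajectory
    # Assumed capping rule: reaction r fires at most min_i floor(x_i / nu[i, r])
    # times, taken over the species it consumes (nu[i, r] > 0).
    nu_b = nu[:, :, None]                              # (n_species, n_reactions, 1)
    big = jnp.iinfo(jnp.int32).max
    cap = jnp.where(nu_b > 0, x[:, None, :] // jnp.maximum(nu_b, 1), big).min(axis=0)
    k = jnp.minimum(k, cap)
    # A per-reaction cap cannot see several reactions draining the same species,
    # so clip at zero as a final safeguard.
    return jnp.maximum(x + S @ k, 0)


@partial(jax.jit, static_argnames=("propensity", "n_chunks", "K"))
def simulate_chunks(x0, key, S, nu, propensity, params, dt, n_chunks, K):
    """n_chunks * K leaps with a shared dt; only chunk endpoints are returned."""

    def step(x, step_key):
        return tau_leap_step(x, step_key, S, nu, propensity, params, dt), None

    def chunk(x, chunk_key):
        x, _ = lax.scan(step, x, random.split(chunk_key, K))  # K leaps, nothing stored
        return x, x                                           # new carry, saved endpoint

    _, X = lax.scan(chunk, x0, random.split(key, n_chunks))
    return X                                                  # (n_chunks, n_species, batch)


# Birth-death model from the Usage section.
def propensity(x, params):
    xf = x[0].astype(jnp.float32)
    return jnp.stack([params["k_birth"] * jnp.ones_like(xf), params["k_death"] * xf])


S = jnp.array([[1, -1]], dtype=jnp.int32)
nu = jnp.array([[0, 1]], dtype=jnp.int32)
params = {"k_birth": jnp.float32(5.0), "k_death": jnp.float32(0.2)}
x0 = jnp.zeros((1, 4096), dtype=jnp.int32)

X = simulate_chunks(x0, random.PRNGKey(0), S, nu, propensity, params, 1e-2, 125, 8)
print(X.shape, float(X[-1, 0].mean()))
```

Every trajectory executes the same sequence of array operations, so there is no per-trajectory branching; the capping safeguard costs only a few extra elementwise operations per step. With these parameters the batch mean of X[-1, 0] should land close to the analytic value 25(1 - e^(-2)) ≈ 21.6.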

Another important consideration is that the Poisson variates can be generated in large batched operations once the propensities for a time step have been computed, rather than being interleaved with trajectory-specific event selection and state updates. This is particularly useful when sampling many independent trajectories from the same system, for example to estimate uncertainty or approximate distributions over possible dynamical outcomes.
