From 2d611d74241ce97c122bc647f7149c85f195f508 Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 09:27:34 -0700 Subject: [PATCH 1/3] add synthetic data source for running without trodes lets you run the pipeline end to end without an acquisition rig. useful for smoke testing a fresh install, working on a laptop, and CI. - realtime_decoder/synthetic.py: SyntheticDataReceiver and SyntheticClient that match the surface of TrodesDataReceiver and TrodesClient. drop in replacements. non-blocking __next__ semantics so the polling loops in encoder/decoder/ripple work unchanged. - runscript.py picks the data source via config['datasource']. defaults to 'trodes' so every existing config keeps working. - config/demo_synthetic.yml is a 5 rank demo wired to the synthetic source. - README adds a 'running without acquisition hardware' section. spikes are poisson, marks gaussian, position walks a triangle wave along a single segment. not biologically realistic, the point is to exercise the data path and message plumbing. to try it: mpiexec -np 5 python -u runscript.py config/demo_synthetic.yml --- README.md | 17 ++ config/demo_synthetic.yml | 175 +++++++++++++++++++ realtime_decoder/synthetic.py | 305 ++++++++++++++++++++++++++++++++++ runscript.py | 42 +++-- 4 files changed, 528 insertions(+), 11 deletions(-) create mode 100644 config/demo_synthetic.yml create mode 100644 realtime_decoder/synthetic.py diff --git a/README.md b/README.md index db0fc9a..803ee71 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,23 @@ mpiexec -np -bind-to hwthread python -u runscript.py 0 + and elapsed > self._p['run_duration_s'] + and not self._stopped + ): + # one-time log; the supervisor's termination is wired + # through SyntheticClient.receive() below. + self._stopped = True + + if self.datatype == Datatypes.LFP: + return self._next_lfp(elapsed) + elif self.datatype == Datatypes.SPIKES: + return self._next_spike(elapsed) + else: + return self._next_position(elapsed) + + # ------------------------------------------------------------------ + # Per-datatype generators + # ------------------------------------------------------------------ + + def _next_lfp(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + # white-ish noise sized to (num_channels,), scaled the same way + # TrodesDataReceiver does (raw * voltage_scaling_factor) + n = max(1, len(self.ntrode_ids)) + raw = self._rng.standard_normal(n) * 200.0 # ~uV range pre-scale + data = raw * self._p['voltage_scaling_factor'] + local_ts = idx # LFP uses spike-clock timestamps in real Trodes; + # at fs_lfp=1500, fs_spike=30000 the ratio is 20, but downstream + # only cares about monotonicity within a stream, so use idx. + system_ts = time.time_ns() + return LFPPoint( + local_ts, + list(self.ntrode_ids), + data, + system_ts, + time.time_ns(), + ) + + def _next_spike(self, elapsed): + if not self.ntrode_ids: + return None + spike_sample_now = int(elapsed * self._spike_clock) + # Find any ntrode whose next-spike sample has arrived. + for ntid in self.ntrode_ids: + if self._next_spike_sample[ntid] <= spike_sample_now: + ts = self._next_spike_sample[ntid] + self._schedule_next_spike(ntid, sample_now=spike_sample_now) + # mark vector: gaussian around _amp, all channels positive + samples = ( + self._rng.standard_normal(self._mark_dim) * 8.0 + self._amp + ) / self._p['voltage_scaling_factor'] + # SpikePoint.data is later multiplied by voltage_scaling_factor + # in real Trodes; the encoder reads `max(mark_vec)` so we just + # need the post-scaling magnitudes to clear `encoder.spk_amp`. + return SpikePoint( + ts, + ntid, + samples * self._p['voltage_scaling_factor'], + time.time_ns(), + time.time_ns(), + ) + return None + + def _next_position(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + + # triangle-wave walk along a single linear segment between 0 and + # track_length_cm + L = self._p['track_length_cm'] + v = self._p['walk_speed_cm_s'] + t = elapsed + period = 2.0 * L / max(v, 1e-6) + phase = (t % period) / period # 0..1 + pos_cm = L * (1.0 - abs(2.0 * phase - 1.0)) + # x/y/x2/y2 in "pixel" units — kinematics.scale_factor converts back + sf = self.config['kinematics']['scale_factor'] + x = pos_cm / sf + y = 100.0 # constant + x2 = x + 5.0 + y2 = y + return CameraModulePoint( + idx, + segment=0, + position=pos_cm, + x=x, + y=y, + x2=x2, + y2=y2, + t_recv_data=time.time_ns(), + ) + + # ------------------------------------------------------------------ + # internals + # ------------------------------------------------------------------ + + def _schedule_next_spike(self, ntid, *, sample_now): + rate = max(self._p['spike_rate_hz'], 1e-6) + # exponential inter-arrival in seconds → samples + gap_s = self._rng.exponential(1.0 / rate) + gap_samples = max(1, int(gap_s * self._spike_clock)) + self._next_spike_sample[ntid] = sample_now + gap_samples + + +class SyntheticClient(object): + """Drop-in synthetic replacement for ``trodesnet.TrodesClient``. + + Exposes the same surface used by the supervisor and stim decider: + * ``set_startup_callback`` / ``set_termination_callback`` + * ``receive`` (called from the supervisor main loop) + * ``send_statescript_shortcut_message`` (called from stim_decider) + + ``receive`` fires the startup callback once after ``startup_delay_s`` + of wall clock has elapsed, and fires termination once ``run_duration_s`` + has elapsed. + """ + + def __init__(self, config): + self._startup_callback = utils.nop + self._termination_callback = utils.nop + self._p = _params(config) + self._t0_wall = time.time() + self._started = False + self._terminated = False + # log-only buffer of "shortcut messages" the stim decider would + # have sent to ECU; useful for asserting in tests later. + self.sent_shortcuts = [] + + def receive(self): + elapsed = time.time() - self._t0_wall + if not self._started and elapsed >= self._p['startup_delay_s']: + self._started = True + self._startup_callback() + if ( + self._started + and not self._terminated + and self._p['run_duration_s'] > 0 + and elapsed >= self._p['run_duration_s'] + self._p['startup_delay_s'] + ): + self._terminated = True + self._termination_callback() + + def send_statescript_shortcut_message(self, val): + self.sent_shortcuts.append((time.time_ns(), int(val))) + + def set_startup_callback(self, callback): + self._startup_callback = callback + + def set_termination_callback(self, callback): + self._termination_callback = callback diff --git a/runscript.py b/runscript.py index 3e920c8..5b93548 100644 --- a/runscript.py +++ b/runscript.py @@ -11,12 +11,30 @@ from mpi4py import MPI from realtime_decoder import ( - datatypes, position, trodesnet, stimulation, + datatypes, position, trodesnet, synthetic, stimulation, main_process, ripple_process, encoder_process, decoder_process, gui_process, base, messages, merge_rec ) + +def _data_source_factory(config): + """Pick the (receiver_class, client_class) pair for the configured + data source. + + `datasource: trodes` (default) uses the live Trodes streams. + `datasource: synthetic` uses the in-process generator from + `realtime_decoder.synthetic` — install-and-run with no hardware. + """ + ds = config.get('datasource', 'trodes') + if ds == 'trodes': + return trodesnet.TrodesDataReceiver, trodesnet.TrodesClient + if ds == 'synthetic': + return synthetic.SyntheticDataReceiver, synthetic.SyntheticClient + raise ValueError( + f"Unknown datasource {ds!r}; expected 'trodes' or 'synthetic'" + ) + # from line_profiler import LineProfiler class GuiProcessStub(base.RealtimeProcess, base.MessageHandler): @@ -169,21 +187,23 @@ def setup(config_path, numprocs): regloop = True ################################################# + DataReceiver, Client = _data_source_factory(config) + if rank in config['rank']['supervisor']: - trodes_client = trodesnet.TrodesClient(config) + net_client = Client(config) stim_decider = stimulation.TwoArmTrodesStimDecider( - comm, rank, config, trodes_client + comm, rank, config, net_client ) process = main_process.MainProcess( - comm, rank, config, stim_decider, trodes_client + comm, rank, config, stim_decider, net_client ) - trodes_client.set_startup_callback(process.startup) - trodes_client.set_termination_callback(process.trigger_termination) + net_client.set_startup_callback(process.startup) + net_client.set_termination_callback(process.trigger_termination) elif rank in config['rank']['ripples']: - lfp_interface = trodesnet.TrodesDataReceiver( + lfp_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LFP ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) process = ripple_process.RippleProcess( @@ -196,10 +216,10 @@ def setup(config_path, numprocs): # prof.print_stats() # regloop = False elif rank in config['rank']['encoders']: - spikes_interface = trodesnet.TrodesDataReceiver( + spikes_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.SPIKES ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( @@ -211,7 +231,7 @@ def setup(config_path, numprocs): pos_mapper ) elif rank in config['rank']['decoders']: - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( From 2bc4351225923b73d4f950e527d097acaced9efc Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 09:27:43 -0700 Subject: [PATCH 2/3] add config defaults and startup validation most of the per-animal yamls duplicate the same sampling rates, ripple filter, gui colors, kinematics smoothing filter, mua, etc. drift is real and 'what is the canonical value of X' is currently unanswerable. - configs can declare `_extends: defaults.yml` and only override what is actually different. nested dicts deep merge key by key, lists and scalars replace. - config/defaults.yml holds the values stable across the SC*/fred/ginny configs. - runscript routes config loading through a new config_loader module that runs a small validator on the resolved dict. catches missing rank.supervisor, unknown algorithm, missing encoder.mark_dim, decoder ranks with no assignment, encoder.mark_dim != synthetic.mark_dim, etc. errors print one readable message and exit 2 before MPI workers spawn, instead of surfacing as IndexError deep in a rank. - config/demo_synthetic.yml now uses _extends, drops about 80 lines. - loader uses stdlib pyyaml instead of oyaml since python 3.7+ dicts preserve order. one fewer dep. backward compatible. verified all 16 existing per-animal configs still load and validate unchanged. follow up not in this PR: the SC*/fred/ginny configs can each be rewritten as _extends + overrides which would shrink them by roughly half. left as a separate diff so per-config behavior changes are auditable on their own. --- README.md | 34 +++++ config/defaults.yml | 101 +++++++++++++ config/demo_synthetic.yml | 98 ++++--------- realtime_decoder/config_loader.py | 229 ++++++++++++++++++++++++++++++ runscript.py | 16 ++- 5 files changed, 403 insertions(+), 75 deletions(-) create mode 100644 config/defaults.yml create mode 100644 realtime_decoder/config_loader.py diff --git a/README.md b/README.md index 803ee71..654473d 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,40 @@ auto-terminates after `synthetic.run_duration_s` seconds. See Please see the example configuration file in the `example_config` folder. Options are described in more detail below. +## Defaults and inheritance + +Configs can extend a shared base file by declaring `_extends:` at the top: + +```yaml +_extends: defaults.yml + +# only the keys that actually differ from defaults.yml go here +rank: + supervisor: [0] + ... +``` + +`config/defaults.yml` ships with values shared across the SC* / fred / ginny +configs (sampling rates, ripple filter, GUI, MUA, kinematics smoothing +filter, display intervals, process monitor). Per-animal configs only need +to specify what's actually different — typically `rank`, `trode_selection`, +`decoder_assignment`, `files`, `encoder.position`, `kinematics.scale_factor`, +and `stimulation`. Existing configs without `_extends` continue to work +unchanged. + +Relative paths in `_extends` resolve next to the file that declares them. +You may pass a single path or a list (parents merged in order). + +## Validation + +At startup the loader checks the resolved config against a minimal schema +(required ranks, known algorithm, known datasource, encoder dimensions +match the synthetic source, every decoder rank has an assignment, etc.). +Missing or malformed keys produce a single readable error before the MPI +processes spawn workers, instead of an `IndexError` deep inside a rank. + + + ## `rank` Describes which MPI rank should be assigned to each process type. diff --git a/config/defaults.yml b/config/defaults.yml new file mode 100644 index 0000000..7b04ce0 --- /dev/null +++ b/config/defaults.yml @@ -0,0 +1,101 @@ +--- +# Shared defaults for realtime_decoder configs. +# +# Per-animal / per-experiment YAMLs should declare: +# _extends: defaults.yml +# at the top, then override only what's actually different. Anything not +# overridden inherits from here. Lists and scalars are replaced as a +# whole; nested dicts merge key-by-key. +# +# This file is intentionally conservative: it only contains values that +# have been stable across the SC* / fred / ginny configs in this repo, +# plus a few sane defaults for new fields. Per-animal kinematics, +# stimulation, trode selection, and file paths must still live in the +# child config. + +algorithm: "clusterless_decoder" +datasource: "trodes" +num_setup_messages: 100 +preloaded_model: false +frozen_model: false + +sampling_rate: + spikes: 30000 + lfp: 1500 + position: 30 + +ripples: + max_ripple_samples: 450 + vel_thresh: 10 + freeze_stats: false + timings_bufsize: 1000000 + filter: + type: 'iir' + order: 2 + crit_freqs: [150, 250] + kwargs: + btype: 'bandpass' + ftype: 'butter' + smoothing_filter: + num_taps: 15 + band_edges: [50, 55] + desired: [1, 0] + threshold: + standard: 3.5 + conditioning: 3.75 + content: 4 + end: 0 + +decoder: + cred_int_bufsize: 10 + time_bin: + samples: 180 # 6 ms at 30 kHz spike clock + delay_samples: 180 + +clusterless_decoder: + state_labels: ['state'] + transmat_bias: 1 + +gui: + colormap: 'rocket' + send_interval: 0 + refresh_rate: 25 + trace_length: 2 + state_colors: ['#4c72b0', '#dd8452', '#55a868'] + num_xticks: 5 + +mua: + threshold: + trigger: 4 + end: 0 + freeze_stats: false + moving_avg_window: 5 + +cred_interval: + val: 0.5 + max_num: 5 + +kinematics: + smooth_x: true + smooth_y: true + smooth_speed: false + smoothing_filter: [0.31, 0.29, 0.25, 0.15] + +display: + stim_decider: + position: 150 + decoding_bins: 2000 + ripples: + lfp: 100000 + encoder: + encoding_spikes: 5000 + total_spikes: 50000 + occupancy: 5000 + position: 5000 + decoder: + total_spikes: 50000 + occupancy: 100 + +process_monitor: + interval: 15 + timeout: 3 diff --git a/config/demo_synthetic.yml b/config/demo_synthetic.yml index 666a7a3..fe62728 100644 --- a/config/demo_synthetic.yml +++ b/config/demo_synthetic.yml @@ -1,12 +1,16 @@ --- # Demo config: runs the full MPI pipeline against the in-process -# synthetic data source (realtime_decoder.synthetic). No Trodes, no -# acquisition hardware, no driver setup required. +# synthetic data source. No Trodes / acquisition hardware required. # # Run: # mpiexec -np 5 python -u runscript.py config/demo_synthetic.yml # -# Ranks: 0=supervisor, 1=decoder, 2=gui, 3=ripples, 4=encoder +# Everything not set here inherits from defaults.yml via `_extends`. + +_extends: defaults.yml + +datasource: "synthetic" + rank: supervisor: [0] ripples: [3] @@ -14,57 +18,35 @@ rank: encoders: [4] gui: [2] rank_settings: - enable_rec: [0,1,3,4] + enable_rec: [0, 1, 3, 4] trode_selection: ripples: [1] decoding: [1] decoder_assignment: 1: [1] -algorithm: "clusterless_decoder" -datasource: "synthetic" -num_setup_messages: 100 -preloaded_model: false -frozen_model: false + files: output_dir: '/tmp/realtime_decoder_demo' prefix: 'demo' rec_postfix: 'bin_rec' timing_postfix: 'timing' -# --- synthetic-source parameters (all optional, defaults shown) ----------- + +# --- synthetic-source parameters (all optional, defaults documented in +# realtime_decoder/synthetic.py) synthetic: spike_rate_hz: 30 - mark_dim: 4 - mark_amplitude_uv: 120 # well above encoder.spk_amp below - track_length_cm: 40 # fits into the 0..41 bins + mark_dim: 4 # must equal encoder.mark_dim below + mark_amplitude_uv: 120 + track_length_cm: 40 walk_speed_cm_s: 20 - startup_delay_s: 1.0 # supervisor waits this long, then fires play - run_duration_s: 30 # auto-terminate after this many seconds + startup_delay_s: 1.0 + run_duration_s: 30 voltage_scaling_factor: 0.195 -sampling_rate: - spikes: 30000 - lfp: 1500 - position: 30 + +# Lower-volume timings so the demo doesn't waste memory ripples: - max_ripple_samples: 450 - vel_thresh: 10 - freeze_stats: false timings_bufsize: 100000 - filter: - type: 'iir' - order: 2 - crit_freqs: [150, 250] - kwargs: - btype: 'bandpass' - ftype: 'butter' - smoothing_filter: - num_taps: 15 - band_edges: [50, 55] - desired: [1, 0] - threshold: - standard: 3.5 - conditioning: 3.75 - content: 4 - end: 0 + encoder: spk_amp: 60 use_channel_dist_from_max_amp: 2 @@ -78,43 +60,25 @@ encoder: upper: 41 num_bins: 41 arm_ids: [0] - arm_coords: [[0,40]] + arm_coords: [[0, 40]] mark_kernel: mean: 0 std: 20 use_filter: false n_std: 1 n_marks_min: 10 + decoder: decoder_to_message: 1 bufsize: 2000 timings_bufsize: 10000 - cred_int_bufsize: 10 starting_arm1_bin: 10 starting_arm2_bin: 30 num_pos_points: 30 - time_bin: - samples: 180 - delay_samples: 180 -clusterless_decoder: - state_labels: ['state'] - transmat_bias: 1 -gui: - colormap: 'rocket' - send_interval: 0 - refresh_rate: 25 - trace_length: 2 - state_colors: ['#4c72b0','#dd8452', '#55a868'] - num_xticks: 5 -mua: - threshold: - trigger: 4 - end: 0 - freeze_stats: false - moving_avg_window: 5 + stimulation: instructive: false - shortcut_msg_on: false # no ECU in demo mode + shortcut_msg_on: false automatic_threshold_update: false num_each_arm_per_minute: 1.1 num_pos_points: 30 @@ -147,18 +111,12 @@ stimulation: well_angle_range: 6 within_angle_range: 6 well_loc: [[100, 100], [200, 200]] + kinematics: - smooth_x: true - smooth_y: true - smooth_speed: false - smoothing_filter: [0.31, 0.29, 0.25, 0.15] scale_factor: 0.2644 -cred_interval: - val: 0.5 - max_num: 5 + display: stim_decider: - position: 150 decoding_bins: 200 ripples: lfp: 10000 @@ -169,7 +127,3 @@ display: position: 500 decoder: total_spikes: 5000 - occupancy: 100 -process_monitor: - interval: 15 - timeout: 3 diff --git a/realtime_decoder/config_loader.py b/realtime_decoder/config_loader.py new file mode 100644 index 0000000..9f26aff --- /dev/null +++ b/realtime_decoder/config_loader.py @@ -0,0 +1,229 @@ +"""Config loader: YAML defaults inheritance + startup validation. + +The historical pattern in this repo is one ~200-line YAML per animal, +near-duplicated across the colony. That breeds drift: a parameter +correctly tuned in `SC79_nTrode16.yml` quietly differs from the same +parameter in `SC80_nTrode16.yml`, and there is no single source of +truth for "what's the standard value of X." + +This module provides two small affordances: + +1. Optional ``_extends`` key that loads a parent YAML and deep-merges it + under the current file. Per-animal files become *overrides* on top + of a shared ``defaults.yml`` instead of full standalone configs. + +2. Startup validation. Today, common operator mistakes (missing + ``rank.supervisor``, unknown ``algorithm``, missing ``encoder.mark_dim``) + surface as ``IndexError``/``KeyError``/``NotImplementedError`` deep + inside a worker process — easy to lose in MPI log noise. ``validate`` + raises a single clear ``ConfigError`` *before* the MPI run starts. + +The loader is backward compatible: existing configs without ``_extends`` +load identically to ``yaml.safe_load`` (modulo validation). +""" + +from __future__ import annotations + +import os +from typing import Any, Dict, List, Tuple + +# PyYAML — stdlib-only dependency. (Python 3.7+ preserves insertion order +# in plain dicts, so no need for oyaml just to read configs.) +import yaml + + +class ConfigError(ValueError): + """Raised when a config fails validation or cannot be loaded.""" + + +# --------------------------------------------------------------------------- +# loading +# --------------------------------------------------------------------------- + + +def load_config(path: str) -> Dict[str, Any]: + """Load a YAML config, resolving ``_extends`` chains and validating. + + ``_extends`` may be a single path or a list of paths. Each parent is + loaded recursively (parents may themselves ``_extends``) and merged + in order, with the current file's keys taking precedence. + + Relative paths in ``_extends`` resolve relative to the file that + declares them. + """ + cfg = _load_with_inheritance(path, _seen=set()) + validate(cfg) + return cfg + + +def _load_with_inheritance(path: str, *, _seen: set) -> Dict[str, Any]: + abspath = os.path.abspath(path) + if abspath in _seen: + raise ConfigError( + f"Circular `_extends` chain detected involving {abspath}" + ) + _seen = _seen | {abspath} + + with open(abspath, 'r') as f: + raw = yaml.safe_load(f) or {} + + extends = raw.pop('_extends', None) + if extends is None: + return raw + + if isinstance(extends, str): + parents: List[str] = [extends] + elif isinstance(extends, list): + parents = list(extends) + else: + raise ConfigError( + f"`_extends` in {abspath} must be a string or list, got {type(extends).__name__}" + ) + + here = os.path.dirname(abspath) + merged: Dict[str, Any] = {} + for parent in parents: + parent_path = parent if os.path.isabs(parent) else os.path.join(here, parent) + merged = deep_merge(merged, _load_with_inheritance(parent_path, _seen=_seen)) + return deep_merge(merged, raw) + + +def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + """Recursively merge ``override`` onto ``base``, preferring ``override``. + + Nested dicts merge key-by-key. Lists and scalars are replaced, not + appended — this matches operator intuition ("override X" means + "replace X," not "extend X"). + """ + out = dict(base) + for k, v in override.items(): + if k in out and isinstance(out[k], dict) and isinstance(v, dict): + out[k] = deep_merge(out[k], v) + else: + out[k] = v + return out + + +# --------------------------------------------------------------------------- +# validation +# --------------------------------------------------------------------------- + + +# Required top-level keys and the rough shape we expect. Kept as plain +# code rather than a third-party schema lib so this module has no new +# install-time dependencies; the checks below are cheap and the error +# messages are deliberately operator-friendly. +_REQUIRED_TOP: Tuple[str, ...] = ( + 'rank', + 'algorithm', + 'sampling_rate', + 'files', + 'encoder', + 'decoder', + 'ripples', + 'kinematics', +) +_KNOWN_ALGORITHMS = ('clusterless_decoder', 'clusterless_classifier') +_KNOWN_DATASOURCES = ('trodes', 'synthetic') +_REQUIRED_RANK_ROLES = ('supervisor', 'decoders', 'encoders', 'ripples', 'gui') +_REQUIRED_SAMPLING = ('spikes', 'lfp', 'position') +_REQUIRED_FILES = ('output_dir', 'prefix') +_REQUIRED_ENCODER = ('mark_dim', 'bufsize', 'spk_amp', 'position') +_REQUIRED_ENCODER_POSITION = ('lower', 'upper', 'num_bins', 'arm_ids', 'arm_coords') +_REQUIRED_DECODER = ('bufsize', 'time_bin', 'cred_int_bufsize') + + +def validate(cfg: Dict[str, Any]) -> None: + """Raise ``ConfigError`` with a clear message if ``cfg`` is malformed.""" + errors: List[str] = [] + + for k in _REQUIRED_TOP: + if k not in cfg: + errors.append(f"missing required top-level key '{k}'") + + if cfg.get('algorithm') and cfg['algorithm'] not in _KNOWN_ALGORITHMS: + errors.append( + f"algorithm={cfg['algorithm']!r} is not one of {_KNOWN_ALGORITHMS}" + ) + + ds = cfg.get('datasource', 'trodes') + if ds not in _KNOWN_DATASOURCES: + errors.append( + f"datasource={ds!r} is not one of {_KNOWN_DATASOURCES}" + ) + + rank = cfg.get('rank', {}) + if isinstance(rank, dict): + for role in _REQUIRED_RANK_ROLES: + v = rank.get(role) + if v is None: + errors.append(f"rank.{role} is missing") + elif not isinstance(v, list) or not v: + errors.append(f"rank.{role} must be a non-empty list of ints, got {v!r}") + for role in ('supervisor', 'gui'): + if isinstance(rank.get(role), list) and len(rank[role]) != 1: + errors.append( + f"rank.{role} must contain exactly one rank, got {rank[role]!r}" + ) + else: + errors.append(f"rank must be a mapping, got {type(rank).__name__}") + + sr = cfg.get('sampling_rate', {}) + if isinstance(sr, dict): + for k in _REQUIRED_SAMPLING: + if k not in sr: + errors.append(f"sampling_rate.{k} is missing") + elif not isinstance(sr[k], (int, float)) or sr[k] <= 0: + errors.append(f"sampling_rate.{k} must be a positive number, got {sr[k]!r}") + + files = cfg.get('files', {}) + if isinstance(files, dict): + for k in _REQUIRED_FILES: + if not files.get(k): + errors.append(f"files.{k} is missing or empty") + + enc = cfg.get('encoder', {}) + if isinstance(enc, dict): + for k in _REQUIRED_ENCODER: + if k not in enc: + errors.append(f"encoder.{k} is missing") + pos = enc.get('position', {}) + if isinstance(pos, dict): + for k in _REQUIRED_ENCODER_POSITION: + if k not in pos: + errors.append(f"encoder.position.{k} is missing") + + dec = cfg.get('decoder', {}) + if isinstance(dec, dict): + for k in _REQUIRED_DECODER: + if k not in dec: + errors.append(f"decoder.{k} is missing") + tb = dec.get('time_bin', {}) + if isinstance(tb, dict): + for k in ('samples', 'delay_samples'): + if k not in tb: + errors.append(f"decoder.time_bin.{k} is missing") + + # Cross-field: each decoder rank must be a key in decoder_assignment. + dec_ranks = (cfg.get('rank') or {}).get('decoders') or [] + assignment = cfg.get('decoder_assignment') or {} + if isinstance(assignment, dict): + for r in dec_ranks: + if r not in assignment: + errors.append( + f"decoder_assignment is missing an entry for rank {r}" + ) + + # Cross-field: encoder.mark_dim must match across encoder and any + # synthetic-source override. + syn = cfg.get('synthetic') or {} + if isinstance(enc, dict) and isinstance(syn, dict): + if 'mark_dim' in syn and 'mark_dim' in enc and syn['mark_dim'] != enc['mark_dim']: + errors.append( + f"synthetic.mark_dim ({syn['mark_dim']}) " + f"!= encoder.mark_dim ({enc['mark_dim']})" + ) + + if errors: + bullet = '\n - '.join(errors) + raise ConfigError(f"config validation failed:\n - {bullet}") diff --git a/runscript.py b/runscript.py index 5b93548..2f374f5 100644 --- a/runscript.py +++ b/runscript.py @@ -1,5 +1,6 @@ import os import argparse +import sys import time import datetime import logging @@ -14,7 +15,7 @@ datatypes, position, trodesnet, synthetic, stimulation, main_process, ripple_process, encoder_process, decoder_process, gui_process, base, messages, - merge_rec + merge_rec, config_loader ) @@ -119,8 +120,17 @@ def setup(config_path, numprocs): num_digits = len(str(comm.Get_size())) - with open(config_path, 'r') as f: - config = yaml.safe_load(f) + # Load via the resolver: handles `_extends` inheritance and runs + # validation up front so missing required keys produce one clear + # error before the MPI run starts, instead of an IndexError / + # KeyError deep inside a worker. + try: + config = config_loader.load_config(config_path) + except config_loader.ConfigError as exc: + if rank == 0: + print(f"[config] {exc}", file=sys.stderr, flush=True) + comm.Barrier() + sys.exit(2) os.makedirs(os.path.dirname(config['files']['output_dir']), exist_ok=True) prefix = config['files']['prefix'] From ff809887bad9fc556acee15074f0c5af94c4a5a0 Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 14:38:29 -0700 Subject: [PATCH 3/3] add unit tests for deterministic logic 36 tests covering the pure-python pieces of the codebase. these run on a plain `pip install pytest` without MPI or any acquisition hardware. coverage: - utils: normalize_to_probability (incl. nan handling), estimate_new_stats (vs numpy), apply_no_anim_boundary (1d + 2d) - position: PositionBinStruct edges/centers/get_bin, TrodesPositionMapper (basic + clamp-above-1 edge case), KinematicsEstimator (first-sample, unsmoothed speed, FIR smoothing) - transitions: sungod_transition_matrix shape, row sums equal 1 or 0, gap rows/cols zeroed, no NaNs in output - config_loader: deep_merge (nested, list-replace, scalar-overrides-dict), load_config with and without _extends, circular extends detection, validate catches missing required keys / unknown algorithm / decoder rank without assignment / mark_dim mismatch with synthetic - synthetic: receiver rejects bad datatype, lfp/spike/position emission shapes, returns None before activate, SyntheticClient fires startup callback after delay, records would-be ECU shortcut messages tests/conftest.py installs a minimal mpi4py stub if the real mpi4py is not importable. several modules (base, position, synthetic) import mpi4py at module top, and mpi4py needs a system MPI build to install. the stub lets the deterministic tests run anywhere; if real mpi4py is present it's used instead. setup.py: `extras_require={'test': ['pytest']}` and exclude tests from the installed package. also added pyyaml to install_requires since the config_loader uses it. README: short Tests section. to run: pip install -e .[test] pytest -q --- README.md | 11 +++ setup.py | 8 +- tests/__init__.py | 0 tests/conftest.py | 67 ++++++++++++++++ tests/test_config_loader.py | 151 ++++++++++++++++++++++++++++++++++++ tests/test_position.py | 114 +++++++++++++++++++++++++++ tests/test_synthetic.py | 127 ++++++++++++++++++++++++++++++ tests/test_transitions.py | 45 +++++++++++ tests/test_utils.py | 62 +++++++++++++++ 9 files changed, 583 insertions(+), 2 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_config_loader.py create mode 100644 tests/test_position.py create mode 100644 tests/test_synthetic.py create mode 100644 tests/test_transitions.py create mode 100644 tests/test_utils.py diff --git a/README.md b/README.md index 654473d..b04bab9 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,17 @@ mpiexec -np -bind-to hwthread python -u runscript.py arm 0 (bins 0..3), segment 1 -> arm 1 (bins 5..8) + mapper = position.TrodesPositionMapper( + arm_ids=[0, 1], + arm_coords=[[0, 3], [5, 8]], + ) + # arm 0 has 4 bins; normalized edges [0, .25, .5, .75, 1] + assert mapper.map_position(_camera_point(segment=0, position_on_segment=0.0)) == 0 + assert mapper.map_position(_camera_point(segment=0, position_on_segment=0.5)) == 2 + # exact upper edge clamps to the last bin per the inclusive-upper rule + assert mapper.map_position(_camera_point(segment=0, position_on_segment=1.0)) == 3 + # arm 1 starts at bin 5 + assert mapper.map_position(_camera_point(segment=1, position_on_segment=0.0)) == 5 + assert mapper.map_position(_camera_point(segment=1, position_on_segment=1.0)) == 8 + + +def test_position_mapper_above_one_clamps(): + # numerical noise that pushes segment position slightly above 1.0 + # should still land in the last bin rather than crashing. + mapper = position.TrodesPositionMapper( + arm_ids=[0], + arm_coords=[[0, 4]], + ) + assert mapper.map_position(_camera_point(segment=0, position_on_segment=1.0001)) == 4 + + +# --------------------------------------------------------------------------- +# KinematicsEstimator +# --------------------------------------------------------------------------- + + +def test_kinematics_first_sample_returns_zero_speed(): + est = position.KinematicsEstimator( + scale_factor=1.0, dt=1.0, + xfilter=[1.0], yfilter=[1.0], speedfilter=[1.0], + ) + x, y, s = est.compute_kinematics(10.0, 20.0) + assert (x, y, s) == (10.0, 20.0, 0) + + +def test_kinematics_speed_unsmoothed_matches_euclid_distance(): + est = position.KinematicsEstimator( + scale_factor=1.0, dt=0.5, + xfilter=[1.0], yfilter=[1.0], speedfilter=[1.0], + ) + est.compute_kinematics(0.0, 0.0) # prime + x, y, s = est.compute_kinematics(3.0, 4.0) + # 5 units over 0.5s -> 10 units/sec + assert (x, y) == (3.0, 4.0) + assert np.isclose(s, 10.0) + + +def test_kinematics_smoothing_applies_fir(): + # 3-tap moving average: smoothed value of the third sample should + # equal the average of the last three inputs (* scale). + est = position.KinematicsEstimator( + scale_factor=1.0, dt=1.0, + xfilter=[1/3, 1/3, 1/3], + yfilter=[1/3, 1/3, 1/3], + speedfilter=[1.0], + ) + est.compute_kinematics(0.0, 0.0) # prime, returned raw + est.compute_kinematics(6.0, 0.0, smooth_x=True) # buf=[6,0,0] + x, _, _ = est.compute_kinematics(9.0, 0.0, smooth_x=True) # buf=[9,6,0] + assert np.isclose(x, 5.0) # (9+6+0)/3 diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py new file mode 100644 index 0000000..c228468 --- /dev/null +++ b/tests/test_synthetic.py @@ -0,0 +1,127 @@ +"""Unit tests for realtime_decoder.synthetic. + +These tests don't spin up MPI; they construct receivers directly with a +fake comm and verify the per-datatype generators produce the right +shapes and types. +""" + +import time + +import numpy as np +import pytest + +from realtime_decoder import synthetic, datatypes + + +@pytest.fixture +def base_config(): + return { + 'sampling_rate': {'spikes': 30000, 'lfp': 1500, 'position': 30}, + 'kinematics': {'scale_factor': 0.25}, + 'synthetic': { + 'spike_rate_hz': 1000.0, # high so we see spikes fast + 'mark_dim': 4, + 'mark_amplitude_uv': 120.0, + 'track_length_cm': 40.0, + 'walk_speed_cm_s': 20.0, + 'startup_delay_s': 0.0, + 'run_duration_s': 5.0, + 'voltage_scaling_factor': 0.195, + }, + } + + +class _FakeComm: + """Minimal stub: receivers only need .Get_rank if anything; not used in __next__.""" + pass + + +def test_receiver_rejects_unknown_datatype(base_config): + with pytest.raises(TypeError): + synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatype=999) + + +def test_lfp_receiver_emits_correct_shape(base_config): + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.LFP) + rx.register_datatype_channel(1) + rx.register_datatype_channel(2) + rx.activate() + # poll until we get a sample (LFP at 1500hz, so first sample ~immediate) + deadline = time.time() + 1.0 + sample = None + while time.time() < deadline: + sample = rx.__next__() + if sample is not None: + break + assert sample is not None, "no LFP sample within 1s" + assert isinstance(sample, datatypes.LFPPoint) + assert sample.data.shape == (2,) + + +def test_lfp_receiver_returns_none_before_activate(base_config): + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.LFP) + rx.register_datatype_channel(1) + assert rx.__next__() is None + + +def test_spike_receiver_emits_marks_above_amp_threshold(base_config): + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.SPIKES) + rx.register_datatype_channel(7) + rx.activate() + deadline = time.time() + 1.0 + spike = None + while time.time() < deadline: + spike = rx.__next__() + if spike is not None: + break + assert spike is not None, "no spike within 1s at 1khz rate" + assert isinstance(spike, datatypes.SpikePoint) + assert spike.elec_grp_id == 7 + assert spike.data.shape == (base_config['synthetic']['mark_dim'],) + # marks should clear a reasonable amplitude threshold post-scaling + assert float(np.max(spike.data)) > 50.0 + + +def test_position_receiver_walks_within_bounds(base_config): + base_config['synthetic']['walk_speed_cm_s'] = 200 # fast walk so we cover range quickly + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.LINEAR_POSITION) + rx.activate() + L = base_config['synthetic']['track_length_cm'] + deadline = time.time() + 2.0 + positions = [] + while time.time() < deadline and len(positions) < 30: + p = rx.__next__() + if p is not None: + positions.append(p.position) + assert isinstance(p, datatypes.CameraModulePoint) + assert positions, "no position samples emitted" + assert min(positions) >= 0.0 + assert max(positions) <= L + 1e-6 + + +def test_synthetic_client_fires_startup_callback(base_config): + base_config['synthetic']['startup_delay_s'] = 0.05 + base_config['synthetic']['run_duration_s'] = 0 # disable auto-term + client = synthetic.SyntheticClient(base_config) + + calls = {'startup': 0, 'term': 0} + client.set_startup_callback(lambda: calls.__setitem__('startup', calls['startup'] + 1)) + client.set_termination_callback(lambda: calls.__setitem__('term', calls['term'] + 1)) + + # before delay elapses, no callback + client.receive() + assert calls['startup'] == 0 + + time.sleep(0.1) + client.receive() + assert calls['startup'] == 1 + # subsequent calls should not refire startup + client.receive() + assert calls['startup'] == 1 + + +def test_synthetic_client_records_shortcut_messages(base_config): + client = synthetic.SyntheticClient(base_config) + client.send_statescript_shortcut_message(22) + client.send_statescript_shortcut_message(14) + assert [v for _, v in client.sent_shortcuts] == [22, 14] diff --git a/tests/test_transitions.py b/tests/test_transitions.py new file mode 100644 index 0000000..77affc5 --- /dev/null +++ b/tests/test_transitions.py @@ -0,0 +1,45 @@ +"""Unit tests for the transition model builders.""" + +import numpy as np + +from realtime_decoder import transitions + + +def test_sungod_transition_matrix_shape_and_row_sums(): + pos_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + bias = 1 + T = transitions.sungod_transition_matrix(pos_bins, arm_coords, bias) + + # square, sized to the number of position bins + assert T.shape == (len(pos_bins), len(pos_bins)) + + # rows that are not entirely zero should sum to 1 (within float + # tolerance). The gap rows between arms are masked to NaN by + # apply_no_anim_boundary and then zeroed by the function, so they + # legitimately sum to 0. + row_sums = T.sum(axis=1) + for s in row_sums: + assert np.isclose(s, 0.0) or np.isclose(s, 1.0) + + # at least the in-arm rows must sum to 1 + in_arm_rows = [r for arm in arm_coords for r in range(arm[0], arm[1] + 1)] + for r in in_arm_rows: + assert np.isclose(row_sums[r], 1.0), f"row {r} sums to {row_sums[r]}" + + +def test_sungod_transition_matrix_gap_rows_are_zero(): + pos_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + T = transitions.sungod_transition_matrix(pos_bins, arm_coords, bias=1) + # bins 4 and 5 are gaps; the corresponding rows and columns should + # be all zero so transition mass cannot flow through them. + assert np.all(T[4:6, :] == 0) + assert np.all(T[:, 4:6] == 0) + + +def test_sungod_transition_matrix_no_nans(): + pos_bins = np.arange(8) + arm_coords = [[0, 7]] + T = transitions.sungod_transition_matrix(pos_bins, arm_coords, bias=1) + assert not np.any(np.isnan(T)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..6206c97 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,62 @@ +"""Unit tests for realtime_decoder.utils helpers.""" + +import numpy as np +import pytest + +from realtime_decoder import utils + + +def test_normalize_to_probability_simple(): + dist = np.array([1.0, 2.0, 3.0, 4.0]) + out = utils.normalize_to_probability(dist) + assert np.isclose(out.sum(), 1.0) + np.testing.assert_allclose(out, dist / dist.sum()) + + +def test_normalize_to_probability_ignores_nan(): + # np.nansum is used internally, so NaN entries should not bias + # the normalization of the other entries. + dist = np.array([1.0, np.nan, 3.0]) + out = utils.normalize_to_probability(dist) + # the two finite entries should sum to 1 between them (NaN propagates + # to its own bin but the divisor was sum(1+3)=4) + finite = out[np.isfinite(out)] + assert np.isclose(finite.sum(), 1.0) + + +def test_estimate_new_stats_matches_numpy(): + # Welford's online stats should converge to numpy's batch stats. + rng = np.random.default_rng(0) + values = rng.standard_normal(500) + mean = 0.0 + M2 = 0.0 + count = 0 + for v in values: + mean, M2, count = utils.estimate_new_stats(v, mean, M2, count) + # variance from M2; compare to numpy population variance (ddof=0) + var = M2 / count + assert np.isclose(mean, values.mean(), atol=1e-12) + assert np.isclose(var, values.var(), atol=1e-12) + + +def test_apply_no_anim_boundary_2d_fills_gaps(): + # arm_coords [[0,3],[6,9]] => bins 4 and 5 are "no animal" gaps. + x_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + image = np.ones((10, 10)) + out = utils.apply_no_anim_boundary(x_bins, arm_coords, image, fill=0) + # gap rows and columns should be zeroed + assert np.all(out[4:6, :] == 0) + assert np.all(out[:, 4:6] == 0) + # non-gap interior should still be 1 + assert out[1, 1] == 1 + assert out[7, 7] == 1 + + +def test_apply_no_anim_boundary_1d_fills_gaps(): + x_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + image = np.ones(10) + out = utils.apply_no_anim_boundary(x_bins, arm_coords, image, fill=-1) + assert np.all(out[4:6] == -1) + assert out[0] == 1 and out[9] == 1