def load_config(path: str) -> Dict[str, Any]:
    """Load a YAML config, resolving ``_extends`` chains and validating.

    ``_extends`` may be a single path or a list of paths. Each parent is
    loaded recursively (parents may themselves ``_extends``) and merged
    in order, with the current file's keys taking precedence.

    Relative paths in ``_extends`` resolve relative to the file that
    declares them.

    Returns the merged configuration mapping.

    Raises ``ConfigError`` if the inheritance chain is circular or
    malformed, or if the merged config fails ``validate``.
    """
    cfg = _load_with_inheritance(path, _seen=set())
    validate(cfg)
    return cfg


def _load_with_inheritance(path: str, *, _seen: set) -> Dict[str, Any]:
    """Load ``path`` and merge it on top of everything it ``_extends``.

    ``_seen`` holds the absolute paths on the current inheritance chain
    and is used to detect cycles.

    Raises ``ConfigError`` for a circular chain, a non-mapping YAML
    document, or a malformed ``_extends`` value.
    """
    abspath = os.path.abspath(path)
    if abspath in _seen:
        raise ConfigError(
            f"Circular `_extends` chain detected involving {abspath}"
        )
    # Rebind rather than mutate: each branch of the recursion gets its own
    # chain, so two parents sharing a common ancestor (a "diamond") is not
    # reported as a cycle.
    _seen = _seen | {abspath}

    with open(abspath, 'r') as f:
        raw = yaml.safe_load(f) or {}

    # A YAML document may legally be a list or a scalar; without this
    # check the `.pop` below would fail with an AttributeError.
    if not isinstance(raw, dict):
        raise ConfigError(
            f"{abspath} must contain a mapping at the top level, "
            f"got {type(raw).__name__}"
        )

    extends = raw.pop('_extends', None)
    if extends is None:
        return raw

    if isinstance(extends, str):
        parents: List[str] = [extends]
    elif isinstance(extends, list):
        parents = list(extends)
    else:
        raise ConfigError(
            f"`_extends` in {abspath} must be a string or list, got {type(extends).__name__}"
        )

    here = os.path.dirname(abspath)
    merged: Dict[str, Any] = {}
    for parent in parents:
        if not isinstance(parent, str):
            raise ConfigError(
                f"`_extends` entries in {abspath} must be strings, got {parent!r}"
            )
        parent_path = parent if os.path.isabs(parent) else os.path.join(here, parent)
        merged = deep_merge(merged, _load_with_inheritance(parent_path, _seen=_seen))
    # The declaring file is merged last so its keys win over every parent.
    return deep_merge(merged, raw)


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge ``override`` onto ``base``, preferring ``override``.

    Nested dicts merge key-by-key. Lists and scalars are replaced, not
    appended — this matches operator intuition ("override X" means
    "replace X," not "extend X").

    Neither argument is modified; a new top-level dict is returned.
    Values that are not merged are carried over by reference, not copied.
    """
    out = dict(base)
    for k, v in override.items():
        if k in out and isinstance(out[k], dict) and isinstance(v, dict):
            out[k] = deep_merge(out[k], v)
        else:
            out[k] = v
    return out


# ---------------------------------------------------------------------------
# validation
# ---------------------------------------------------------------------------


# Required top-level keys and the rough shape we expect. Kept as plain
# code rather than a third-party schema lib so this module has no new
# install-time dependencies; the checks below are cheap and the error
# messages are deliberately operator-friendly.
_REQUIRED_TOP: Tuple[str, ...] = (
    'rank',
    'algorithm',
    'sampling_rate',
    'files',
    'encoder',
    'decoder',
    'ripples',
    'kinematics',
)
_KNOWN_ALGORITHMS = ('clusterless_decoder', 'clusterless_classifier')
_KNOWN_DATASOURCES = ('trodes', 'synthetic')
_REQUIRED_RANK_ROLES = ('supervisor', 'decoders', 'encoders', 'ripples', 'gui')
_REQUIRED_SAMPLING = ('spikes', 'lfp', 'position')
_REQUIRED_FILES = ('output_dir', 'prefix')
_REQUIRED_ENCODER = ('mark_dim', 'bufsize', 'spk_amp', 'position')
_REQUIRED_ENCODER_POSITION = ('lower', 'upper', 'num_bins', 'arm_ids', 'arm_coords')
_REQUIRED_DECODER = ('bufsize', 'time_bin', 'cred_int_bufsize')


def _is_int(value: Any) -> bool:
    """Return True for a real integer; ``bool`` is deliberately excluded."""
    return isinstance(value, int) and not isinstance(value, bool)


def _is_positive_number(value: Any) -> bool:
    """Return True for an int/float greater than zero (``bool`` excluded)."""
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and value > 0
    )


def _mapping_section(parent: Dict[str, Any], key: str, label: str,
                     errors: List[str]) -> Any:
    """Fetch ``parent[key]`` as a mapping for further checks.

    Returns the mapping (``{}`` when the key is absent, so the caller's
    per-key "missing" checks still fire). If the value is present but is
    not a mapping, appends an error to ``errors`` and returns ``None``.
    """
    value = parent.get(key, {})
    if isinstance(value, dict):
        return value
    errors.append(f"{label} must be a mapping, got {type(value).__name__}")
    return None


def validate(cfg: Dict[str, Any]) -> None:
    """Raise ``ConfigError`` with a clear message if ``cfg`` is malformed.

    All problems are collected and reported together in a single
    ``ConfigError``. A malformed section is reported as an error and never
    aborts validation with an unrelated ``AttributeError``/``TypeError``.
    Returns ``None`` when the config passes every check.
    """
    errors: List[str] = []

    for k in _REQUIRED_TOP:
        if k not in cfg:
            errors.append(f"missing required top-level key '{k}'")

    if cfg.get('algorithm') and cfg['algorithm'] not in _KNOWN_ALGORITHMS:
        errors.append(
            f"algorithm={cfg['algorithm']!r} is not one of {_KNOWN_ALGORITHMS}"
        )

    # `datasource` is optional and defaults to the live Trodes source.
    ds = cfg.get('datasource', 'trodes')
    if ds not in _KNOWN_DATASOURCES:
        errors.append(
            f"datasource={ds!r} is not one of {_KNOWN_DATASOURCES}"
        )

    rank = _mapping_section(cfg, 'rank', 'rank', errors)
    if rank is not None:
        for role in _REQUIRED_RANK_ROLES:
            v = rank.get(role)
            if v is None:
                errors.append(f"rank.{role} is missing")
            elif not isinstance(v, list) or not v or not all(_is_int(r) for r in v):
                errors.append(f"rank.{role} must be a non-empty list of ints, got {v!r}")
        for role in ('supervisor', 'gui'):
            if isinstance(rank.get(role), list) and len(rank[role]) != 1:
                errors.append(
                    f"rank.{role} must contain exactly one rank, got {rank[role]!r}"
                )

    sr = _mapping_section(cfg, 'sampling_rate', 'sampling_rate', errors)
    if sr is not None:
        for k in _REQUIRED_SAMPLING:
            if k not in sr:
                errors.append(f"sampling_rate.{k} is missing")
            elif not _is_positive_number(sr[k]):
                errors.append(f"sampling_rate.{k} must be a positive number, got {sr[k]!r}")

    files = _mapping_section(cfg, 'files', 'files', errors)
    if files is not None:
        for k in _REQUIRED_FILES:
            if not files.get(k):
                errors.append(f"files.{k} is missing or empty")

    enc = _mapping_section(cfg, 'encoder', 'encoder', errors)
    if enc is not None:
        for k in _REQUIRED_ENCODER:
            if k not in enc:
                errors.append(f"encoder.{k} is missing")
        pos = _mapping_section(enc, 'position', 'encoder.position', errors)
        if pos is not None:
            for k in _REQUIRED_ENCODER_POSITION:
                if k not in pos:
                    errors.append(f"encoder.position.{k} is missing")

    dec = _mapping_section(cfg, 'decoder', 'decoder', errors)
    if dec is not None:
        for k in _REQUIRED_DECODER:
            if k not in dec:
                errors.append(f"decoder.{k} is missing")
        tb = _mapping_section(dec, 'time_bin', 'decoder.time_bin', errors)
        if tb is not None:
            for k in ('samples', 'delay_samples'):
                if k not in tb:
                    errors.append(f"decoder.time_bin.{k} is missing")

    # Cross-field: each decoder rank must be a key in decoder_assignment.
    # Only well-formed rank entries are checked here; malformed ones were
    # already reported above and would otherwise break the lookup.
    dec_ranks = rank.get('decoders') if rank is not None else None
    assignment = cfg.get('decoder_assignment') or {}
    if isinstance(dec_ranks, list) and isinstance(assignment, dict):
        for r in dec_ranks:
            if _is_int(r) and r not in assignment:
                errors.append(
                    f"decoder_assignment is missing an entry for rank {r}"
                )

    # Cross-field: encoder.mark_dim must match across encoder and any
    # synthetic-source override.
    syn = cfg.get('synthetic') or {}
    if enc is not None and isinstance(syn, dict):
        if 'mark_dim' in syn and 'mark_dim' in enc and syn['mark_dim'] != enc['mark_dim']:
            errors.append(
                f"synthetic.mark_dim ({syn['mark_dim']}) "
                f"!= encoder.mark_dim ({enc['mark_dim']})"
            )

    if errors:
        bullet = '\n  - '.join(errors)
        raise ConfigError(f"config validation failed:\n  - {bullet}")
"""Synthetic data source for the realtime_decoder.

Lets you install the package and run the full MPI pipeline end-to-end
without any acquisition hardware (Trodes, SpikeGLX, Open Ephys, etc.).

This is intended for:
  * smoke-testing a fresh install
  * developing/debugging the decoder loop on a laptop
  * regression testing in CI

It is NOT intended to be biologically realistic. The synthetic spikes are
Poisson with a fixed rate, marks are gaussian, and the synthetic position
walks back and forth on a simple linear track. The goal is to exercise the
data path and message plumbing, not to validate decoding accuracy.

Wiring is parallel to ``trodesnet.py``: a ``SyntheticDataReceiver`` that
mirrors ``TrodesDataReceiver`` and a ``SyntheticClient`` that mirrors
``TrodesClient``. The dispatch happens in ``runscript.py`` based on
``config['datasource']``.

Config block expected under the top-level ``synthetic`` key (all optional).
The values below are an example; the values used when a key is omitted are
those in ``_DEFAULTS``, except ``mark_dim``, which falls back to
``encoder.mark_dim`` when that is set:

    synthetic:
      spike_rate_hz: 20         # per ntrode, Poisson
      mark_dim: 8               # must match encoder.mark_dim
      mark_amplitude_uv: 80     # mean spike amplitude
      track_length_cm: 200      # walk distance
      walk_speed_cm_s: 15       # synthetic animal speed
      startup_delay_s: 1.0      # delay before firing the startup callback
      run_duration_s: 60        # auto-terminate after this long
      voltage_scaling_factor: 0.195
"""

import time

import numpy as np

from realtime_decoder import utils
from realtime_decoder.base import DataSourceReceiver
from realtime_decoder.datatypes import (
    Datatypes,
    LFPPoint,
    SpikePoint,
    CameraModulePoint,
)


# Fallback values for keys omitted from the ``synthetic`` config block.
_DEFAULTS = {
    'spike_rate_hz': 20.0,
    'mark_dim': 4,
    'mark_amplitude_uv': 80.0,
    'track_length_cm': 200.0,
    'walk_speed_cm_s': 15.0,
    'startup_delay_s': 1.0,
    'run_duration_s': 60.0,
    'voltage_scaling_factor': 1.0,
}


def _params(config):
    """Read the ``synthetic`` config block, applying defaults.

    Returns a new dict: ``_DEFAULTS`` overlaid with whatever the
    ``synthetic`` block provides. When the block does not set ``mark_dim``,
    a positive integer ``encoder.mark_dim`` is used instead of the built-in
    default, so the generated marks cannot silently disagree with the
    encoder's configured dimensionality.
    """
    p = dict(_DEFAULTS)
    overrides = config.get('synthetic') or {}
    if 'mark_dim' not in overrides:
        enc = config.get('encoder')
        enc_dim = enc.get('mark_dim') if isinstance(enc, dict) else None
        if isinstance(enc_dim, int) and not isinstance(enc_dim, bool) and enc_dim > 0:
            p['mark_dim'] = enc_dim
    p.update(overrides)
    return p


class SyntheticDataReceiver(DataSourceReceiver):
    """Drop-in synthetic replacement for ``trodesnet.TrodesDataReceiver``.

    Generates LFP / spike / position samples on demand at clock-driven
    rates. ``__next__`` returns None when no sample is due yet, matching
    the non-blocking semantics of the Trodes receiver — the polling main
    loops do not need to know they are reading synthetic data.

    NOTE(review): ``self.datatype``, ``self.config`` and ``self.class_log``
    are read below but never assigned in this class; presumably the
    ``DataSourceReceiver`` base sets them — confirm against ``base.py``.
    """

    def __init__(self, comm, rank, config, datatype):
        """Set up pacing state and the per-receiver random generator.

        Raises ``TypeError`` if ``datatype`` is not LFP, SPIKES or
        LINEAR_POSITION.
        """
        if datatype not in (
            Datatypes.LFP,
            Datatypes.SPIKES,
            Datatypes.LINEAR_POSITION,
        ):
            raise TypeError(f"Invalid datatype {datatype}")
        super().__init__(comm, rank, config, datatype)

        self._p = _params(config)
        self._started = False
        self._stopped = False

        self.ntrode_ids = []

        # Per-stream pacing: we advance a deterministic virtual clock
        # (sample index) from t=0 at activate(), and emit samples as
        # wall-clock catches up. This gives roughly the same delivery
        # cadence as a live acquisition system at the configured rates.
        self._t0_wall = None
        self._next_sample_idx = 0
        if datatype == Datatypes.LFP:
            self._fs = config['sampling_rate']['lfp']
        elif datatype == Datatypes.SPIKES:
            self._fs = config['sampling_rate']['spikes']
        else:  # LINEAR_POSITION
            self._fs = config['sampling_rate']['position']

        self._spike_clock = config['sampling_rate']['spikes']

        # Spike-stream specific: independent Poisson process per ntrode.
        # ``_next_spike_sample[ntid]`` stores the spike-clock sample index
        # at which that ntrode's next spike will fire.
        # The seed depends on rank and datatype, so each receiver draws a
        # different but repeatable random sequence.
        self._rng = np.random.default_rng(seed=rank * 1009 + int(datatype))
        self._next_spike_sample = {}
        self._mark_dim = self._p['mark_dim']
        self._amp = self._p['mark_amplitude_uv']

    # ------------------------------------------------------------------
    # DataSourceReceiver contract
    # ------------------------------------------------------------------

    def register_datatype_channel(self, channel):
        """Record ``channel`` as an ntrode id for LFP/spike streams.

        Duplicate registrations are ignored; position has no channels.
        """
        ntrode_id = int(channel)
        if self.datatype in (Datatypes.LFP, Datatypes.SPIKES):
            if ntrode_id not in self.ntrode_ids:
                self.ntrode_ids.append(ntrode_id)
        # position has no channels

    def activate(self):
        """Start the stream: reset the clock and schedule first spikes."""
        self._t0_wall = time.time()
        self._next_sample_idx = 0
        if self.datatype == Datatypes.SPIKES:
            for ntid in self.ntrode_ids:
                self._schedule_next_spike(ntid, sample_now=0)
        self._started = True
        self.class_log.debug(
            f"Synthetic {self.datatype.name} datastream activated "
            f"({len(self.ntrode_ids)} ntrodes)"
        )

    def deactivate(self):
        """Pause the stream; ``__next__`` returns None until reactivated."""
        self._started = False

    def stop_iterator(self):
        """Signal end of iteration by raising ``StopIteration``."""
        raise StopIteration()

    def __next__(self):
        """Return the next due sample for this stream, or None.

        None is returned when the stream is not active or no sample is due
        yet. Samples keep being produced after ``run_duration_s`` has
        elapsed; this method only logs that fact once.
        """
        if not self._started:
            return None

        elapsed = time.time() - self._t0_wall
        if (
            self._p['run_duration_s'] > 0
            and elapsed > self._p['run_duration_s']
            and not self._stopped
        ):
            # one-time log; the supervisor's termination is wired
            # through SyntheticClient.receive() below.
            self._stopped = True
            self.class_log.debug(
                f"Synthetic {self.datatype.name} datastream passed "
                f"run_duration_s={self._p['run_duration_s']}"
            )

        if self.datatype == Datatypes.LFP:
            return self._next_lfp(elapsed)
        elif self.datatype == Datatypes.SPIKES:
            return self._next_spike(elapsed)
        else:
            return self._next_position(elapsed)

    # ------------------------------------------------------------------
    # Per-datatype generators
    # ------------------------------------------------------------------

    def _next_lfp(self, elapsed):
        """Return one ``LFPPoint`` if the LFP clock has a sample due.

        At most one sample is emitted per call; a backlog is drained by
        repeated calls.
        """
        target_idx = int(elapsed * self._fs)
        if target_idx < self._next_sample_idx:
            return None
        idx = self._next_sample_idx
        self._next_sample_idx += 1
        # white-ish noise sized to (num_channels,), scaled the same way
        # TrodesDataReceiver does (raw * voltage_scaling_factor)
        n = max(1, len(self.ntrode_ids))
        raw = self._rng.standard_normal(n) * 200.0  # ~uV range pre-scale
        data = raw * self._p['voltage_scaling_factor']
        local_ts = idx  # LFP uses spike-clock timestamps in real Trodes;
        # at fs_lfp=1500, fs_spike=30000 the ratio is 20, but downstream
        # only cares about monotonicity within a stream, so use idx.
        system_ts = time.time_ns()
        return LFPPoint(
            local_ts,
            list(self.ntrode_ids),
            data,
            system_ts,
            time.time_ns(),
        )

    def _next_spike(self, elapsed):
        """Return a ``SpikePoint`` for the first ntrode with a spike due.

        Returns None when no ntrodes are registered or no spike is due.
        """
        if not self.ntrode_ids:
            return None
        spike_sample_now = int(elapsed * self._spike_clock)
        # Find any ntrode whose next-spike sample has arrived.
        for ntid in self.ntrode_ids:
            due = self._next_spike_sample.get(ntid)
            if due is None:
                # Registered after activate(): it has no schedule yet, so
                # start its Poisson process now instead of raising KeyError.
                self._schedule_next_spike(ntid, sample_now=spike_sample_now)
                continue
            if due <= spike_sample_now:
                self._schedule_next_spike(ntid, sample_now=spike_sample_now)
                # mark vector: gaussian around _amp, all channels positive.
                # Generated directly at post-scaling magnitude: dividing by
                # voltage_scaling_factor and multiplying it straight back
                # only added rounding error, and gave NaN/inf marks for a
                # factor of 0. The encoder reads `max(mark_vec)`, so we just
                # need these magnitudes to clear `encoder.spk_amp`.
                marks = self._rng.standard_normal(self._mark_dim) * 8.0 + self._amp
                return SpikePoint(
                    due,
                    ntid,
                    marks,
                    time.time_ns(),
                    time.time_ns(),
                )
        return None

    def _next_position(self, elapsed):
        """Return one ``CameraModulePoint`` if a position sample is due.

        The position follows a triangle wave between 0 and
        ``track_length_cm`` at ``walk_speed_cm_s``.
        """
        target_idx = int(elapsed * self._fs)
        if target_idx < self._next_sample_idx:
            return None
        idx = self._next_sample_idx
        self._next_sample_idx += 1

        # triangle-wave walk along a single linear segment between 0 and
        # track_length_cm
        track_len = self._p['track_length_cm']
        speed = self._p['walk_speed_cm_s']
        period = 2.0 * track_len / max(speed, 1e-6)
        if period != 0.0:
            phase = (elapsed % period) / period  # 0..1
            pos_cm = track_len * (1.0 - abs(2.0 * phase - 1.0))
        else:
            # Zero-length track: stay at the origin rather than taking a
            # modulo by zero.
            pos_cm = 0.0
        # x/y/x2/y2 in "pixel" units — kinematics.scale_factor converts back
        sf = self.config['kinematics']['scale_factor']
        x = pos_cm / sf
        y = 100.0  # constant
        x2 = x + 5.0
        y2 = y
        return CameraModulePoint(
            idx,
            segment=0,
            position=pos_cm,
            x=x,
            y=y,
            x2=x2,
            y2=y2,
            t_recv_data=time.time_ns(),
        )

    # ------------------------------------------------------------------
    # internals
    # ------------------------------------------------------------------

    def _schedule_next_spike(self, ntid, *, sample_now):
        """Draw the next spike time for ``ntid`` relative to ``sample_now``.

        The gap is exponentially distributed with mean
        ``1 / spike_rate_hz`` seconds and is at least one spike-clock
        sample.
        """
        rate = max(self._p['spike_rate_hz'], 1e-6)
        # exponential inter-arrival in seconds → samples
        gap_s = self._rng.exponential(1.0 / rate)
        gap_samples = max(1, int(gap_s * self._spike_clock))
        self._next_spike_sample[ntid] = sample_now + gap_samples


class SyntheticClient(object):
    """Drop-in synthetic replacement for ``trodesnet.TrodesClient``.

    Exposes the same surface used by the supervisor and stim decider:
      * ``set_startup_callback`` / ``set_termination_callback``
      * ``receive`` (called from the supervisor main loop)
      * ``send_statescript_shortcut_message`` (called from stim_decider)

    ``receive`` fires the startup callback once after ``startup_delay_s``
    of wall clock has elapsed, and fires termination once ``run_duration_s``
    has elapsed.
    """

    def __init__(self, config):
        """Read the synthetic parameters and start the wall-clock timer."""
        self._startup_callback = utils.nop
        self._termination_callback = utils.nop
        self._p = _params(config)
        self._t0_wall = time.time()
        self._started = False
        self._terminated = False
        # log-only buffer of "shortcut messages" the stim decider would
        # have sent to ECU; useful for asserting in tests later.
        self.sent_shortcuts = []

    def receive(self):
        """Fire the startup and termination callbacks when they fall due.

        Each callback fires at most once. Termination is measured from
        the end of the startup delay, and is disabled when
        ``run_duration_s`` is not positive.
        """
        elapsed = time.time() - self._t0_wall
        if not self._started and elapsed >= self._p['startup_delay_s']:
            self._started = True
            self._startup_callback()
        if (
            self._started
            and not self._terminated
            and self._p['run_duration_s'] > 0
            and elapsed >= self._p['run_duration_s'] + self._p['startup_delay_s']
        ):
            self._terminated = True
            self._termination_callback()

    def send_statescript_shortcut_message(self, val):
        """Record ``val`` with a nanosecond timestamp instead of sending it."""
        self.sent_shortcuts.append((time.time_ns(), int(val)))

    def set_startup_callback(self, callback):
        """Set the callable that ``receive`` invokes once at startup."""
        self._startup_callback = callback

    def set_termination_callback(self, callback):
        """Set the callable that ``receive`` invokes once at termination."""
        self._termination_callback = callback
+ """ + ds = config.get('datasource', 'trodes') + if ds == 'trodes': + return trodesnet.TrodesDataReceiver, trodesnet.TrodesClient + if ds == 'synthetic': + return synthetic.SyntheticDataReceiver, synthetic.SyntheticClient + raise ValueError( + f"Unknown datasource {ds!r}; expected 'trodes' or 'synthetic'" + ) + # from line_profiler import LineProfiler class GuiProcessStub(base.RealtimeProcess, base.MessageHandler): @@ -101,8 +120,17 @@ def setup(config_path, numprocs): num_digits = len(str(comm.Get_size())) - with open(config_path, 'r') as f: - config = yaml.safe_load(f) + # Load via the resolver: handles `_extends` inheritance and runs + # validation up front so missing required keys produce one clear + # error before the MPI run starts, instead of an IndexError / + # KeyError deep inside a worker. + try: + config = config_loader.load_config(config_path) + except config_loader.ConfigError as exc: + if rank == 0: + print(f"[config] {exc}", file=sys.stderr, flush=True) + comm.Barrier() + sys.exit(2) os.makedirs(os.path.dirname(config['files']['output_dir']), exist_ok=True) prefix = config['files']['prefix'] @@ -169,21 +197,23 @@ def setup(config_path, numprocs): regloop = True ################################################# + DataReceiver, Client = _data_source_factory(config) + if rank in config['rank']['supervisor']: - trodes_client = trodesnet.TrodesClient(config) + net_client = Client(config) stim_decider = stimulation.TwoArmTrodesStimDecider( - comm, rank, config, trodes_client + comm, rank, config, net_client ) process = main_process.MainProcess( - comm, rank, config, stim_decider, trodes_client + comm, rank, config, stim_decider, net_client ) - trodes_client.set_startup_callback(process.startup) - trodes_client.set_termination_callback(process.trigger_termination) + net_client.set_startup_callback(process.startup) + net_client.set_termination_callback(process.trigger_termination) elif rank in config['rank']['ripples']: - lfp_interface = 
trodesnet.TrodesDataReceiver( + lfp_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LFP ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) process = ripple_process.RippleProcess( @@ -196,10 +226,10 @@ def setup(config_path, numprocs): # prof.print_stats() # regloop = False elif rank in config['rank']['encoders']: - spikes_interface = trodesnet.TrodesDataReceiver( + spikes_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.SPIKES ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( @@ -211,7 +241,7 @@ def setup(config_path, numprocs): pos_mapper ) elif rank in config['rank']['decoders']: - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper(