From 2d611d74241ce97c122bc647f7149c85f195f508 Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 09:27:34 -0700 Subject: [PATCH] add synthetic data source for running without trodes lets you run the pipeline end to end without an acquisition rig. useful for smoke testing a fresh install, working on a laptop, and CI. - realtime_decoder/synthetic.py: SyntheticDataReceiver and SyntheticClient that match the surface of TrodesDataReceiver and TrodesClient. drop in replacements. non-blocking __next__ semantics so the polling loops in encoder/decoder/ripple work unchanged. - runscript.py picks the data source via config['datasource']. defaults to 'trodes' so every existing config keeps working. - config/demo_synthetic.yml is a 5 rank demo wired to the synthetic source. - README adds a 'running without acquisition hardware' section. spikes are poisson, marks gaussian, position walks a triangle wave along a single segment. not biologically realistic, the point is to exercise the data path and message plumbing. to try it: mpiexec -np 5 python -u runscript.py config/demo_synthetic.yml --- README.md | 17 ++ config/demo_synthetic.yml | 175 +++++++++++++++++++ realtime_decoder/synthetic.py | 305 ++++++++++++++++++++++++++++++++++ runscript.py | 42 +++-- 4 files changed, 528 insertions(+), 11 deletions(-) create mode 100644 config/demo_synthetic.yml create mode 100644 realtime_decoder/synthetic.py diff --git a/README.md b/README.md index db0fc9a..803ee71 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,23 @@ mpiexec -np -bind-to hwthread python -u runscript.py 0 + and elapsed > self._p['run_duration_s'] + and not self._stopped + ): + # one-time log; the supervisor's termination is wired + # through SyntheticClient.receive() below. + self._stopped = True + + if self.datatype == Datatypes.LFP: + return self._next_lfp(elapsed) + elif self.datatype == Datatypes.SPIKES: + return self._next_spike(elapsed) + else: + return self._next_position(elapsed) + + # ------------------------------------------------------------------ + # Per-datatype generators + # ------------------------------------------------------------------ + + def _next_lfp(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + # white-ish noise sized to (num_channels,), scaled the same way + # TrodesDataReceiver does (raw * voltage_scaling_factor) + n = max(1, len(self.ntrode_ids)) + raw = self._rng.standard_normal(n) * 200.0 # ~uV range pre-scale + data = raw * self._p['voltage_scaling_factor'] + local_ts = idx # LFP uses spike-clock timestamps in real Trodes; + # at fs_lfp=1500, fs_spike=30000 the ratio is 20, but downstream + # only cares about monotonicity within a stream, so use idx. + system_ts = time.time_ns() + return LFPPoint( + local_ts, + list(self.ntrode_ids), + data, + system_ts, + time.time_ns(), + ) + + def _next_spike(self, elapsed): + if not self.ntrode_ids: + return None + spike_sample_now = int(elapsed * self._spike_clock) + # Find any ntrode whose next-spike sample has arrived. + for ntid in self.ntrode_ids: + if self._next_spike_sample[ntid] <= spike_sample_now: + ts = self._next_spike_sample[ntid] + self._schedule_next_spike(ntid, sample_now=spike_sample_now) + # mark vector: gaussian around _amp, all channels positive + samples = ( + self._rng.standard_normal(self._mark_dim) * 8.0 + self._amp + ) / self._p['voltage_scaling_factor'] + # SpikePoint.data is later multiplied by voltage_scaling_factor + # in real Trodes; the encoder reads `max(mark_vec)` so we just + # need the post-scaling magnitudes to clear `encoder.spk_amp`. + return SpikePoint( + ts, + ntid, + samples * self._p['voltage_scaling_factor'], + time.time_ns(), + time.time_ns(), + ) + return None + + def _next_position(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + + # triangle-wave walk along a single linear segment between 0 and + # track_length_cm + L = self._p['track_length_cm'] + v = self._p['walk_speed_cm_s'] + t = elapsed + period = 2.0 * L / max(v, 1e-6) + phase = (t % period) / period # 0..1 + pos_cm = L * (1.0 - abs(2.0 * phase - 1.0)) + # x/y/x2/y2 in "pixel" units — kinematics.scale_factor converts back + sf = self.config['kinematics']['scale_factor'] + x = pos_cm / sf + y = 100.0 # constant + x2 = x + 5.0 + y2 = y + return CameraModulePoint( + idx, + segment=0, + position=pos_cm, + x=x, + y=y, + x2=x2, + y2=y2, + t_recv_data=time.time_ns(), + ) + + # ------------------------------------------------------------------ + # internals + # ------------------------------------------------------------------ + + def _schedule_next_spike(self, ntid, *, sample_now): + rate = max(self._p['spike_rate_hz'], 1e-6) + # exponential inter-arrival in seconds → samples + gap_s = self._rng.exponential(1.0 / rate) + gap_samples = max(1, int(gap_s * self._spike_clock)) + self._next_spike_sample[ntid] = sample_now + gap_samples + + +class SyntheticClient(object): + """Drop-in synthetic replacement for ``trodesnet.TrodesClient``. + + Exposes the same surface used by the supervisor and stim decider: + * ``set_startup_callback`` / ``set_termination_callback`` + * ``receive`` (called from the supervisor main loop) + * ``send_statescript_shortcut_message`` (called from stim_decider) + + ``receive`` fires the startup callback once after ``startup_delay_s`` + of wall clock has elapsed, and fires termination once ``run_duration_s`` + has elapsed. + """ + + def __init__(self, config): + self._startup_callback = utils.nop + self._termination_callback = utils.nop + self._p = _params(config) + self._t0_wall = time.time() + self._started = False + self._terminated = False + # log-only buffer of "shortcut messages" the stim decider would + # have sent to ECU; useful for asserting in tests later. + self.sent_shortcuts = [] + + def receive(self): + elapsed = time.time() - self._t0_wall + if not self._started and elapsed >= self._p['startup_delay_s']: + self._started = True + self._startup_callback() + if ( + self._started + and not self._terminated + and self._p['run_duration_s'] > 0 + and elapsed >= self._p['run_duration_s'] + self._p['startup_delay_s'] + ): + self._terminated = True + self._termination_callback() + + def send_statescript_shortcut_message(self, val): + self.sent_shortcuts.append((time.time_ns(), int(val))) + + def set_startup_callback(self, callback): + self._startup_callback = callback + + def set_termination_callback(self, callback): + self._termination_callback = callback diff --git a/runscript.py b/runscript.py index 3e920c8..5b93548 100644 --- a/runscript.py +++ b/runscript.py @@ -11,12 +11,30 @@ from mpi4py import MPI from realtime_decoder import ( - datatypes, position, trodesnet, stimulation, + datatypes, position, trodesnet, synthetic, stimulation, main_process, ripple_process, encoder_process, decoder_process, gui_process, base, messages, merge_rec ) + +def _data_source_factory(config): + """Pick the (receiver_class, client_class) pair for the configured + data source. + + `datasource: trodes` (default) uses the live Trodes streams. + `datasource: synthetic` uses the in-process generator from + `realtime_decoder.synthetic` — install-and-run with no hardware. + """ + ds = config.get('datasource', 'trodes') + if ds == 'trodes': + return trodesnet.TrodesDataReceiver, trodesnet.TrodesClient + if ds == 'synthetic': + return synthetic.SyntheticDataReceiver, synthetic.SyntheticClient + raise ValueError( + f"Unknown datasource {ds!r}; expected 'trodes' or 'synthetic'" + ) + # from line_profiler import LineProfiler class GuiProcessStub(base.RealtimeProcess, base.MessageHandler): @@ -169,21 +187,23 @@ def setup(config_path, numprocs): regloop = True ################################################# + DataReceiver, Client = _data_source_factory(config) + if rank in config['rank']['supervisor']: - trodes_client = trodesnet.TrodesClient(config) + net_client = Client(config) stim_decider = stimulation.TwoArmTrodesStimDecider( - comm, rank, config, trodes_client + comm, rank, config, net_client ) process = main_process.MainProcess( - comm, rank, config, stim_decider, trodes_client + comm, rank, config, stim_decider, net_client ) - trodes_client.set_startup_callback(process.startup) - trodes_client.set_termination_callback(process.trigger_termination) + net_client.set_startup_callback(process.startup) + net_client.set_termination_callback(process.trigger_termination) elif rank in config['rank']['ripples']: - lfp_interface = trodesnet.TrodesDataReceiver( + lfp_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LFP ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) process = ripple_process.RippleProcess( @@ -196,10 +216,10 @@ def setup(config_path, numprocs): # prof.print_stats() # regloop = False elif rank in config['rank']['encoders']: - spikes_interface = trodesnet.TrodesDataReceiver( + spikes_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.SPIKES ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( @@ -211,7 +231,7 @@ def setup(config_path, numprocs): pos_mapper ) elif rank in config['rank']['decoders']: - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper(