diff --git a/realtime_decoder/decoder_process.py b/realtime_decoder/decoder_process.py index 2823924..c5ba64d 100644 --- a/realtime_decoder/decoder_process.py +++ b/realtime_decoder/decoder_process.py @@ -3,6 +3,8 @@ import glob import numpy as np +from mpi4py import MPI + from realtime_decoder import ( base, utils, position, messages, transitions, binary_record, taskstate @@ -18,11 +20,17 @@ class DecoderMPISendInterface(base.StandardMPISendInterface): def __init__(self, comm, rank, config): super().__init__(comm, rank, config) + # NOTE: each send path used to call msg.tobytes() per tick, which + # allocates a fresh bytes object every time. Sending [msg, MPI.BYTE] + # hands MPI the numpy buffer directly — zero-copy, no GC pressure. + # Wire format is unchanged (raw bytes), so receivers built around + # `bytearray(... .itemsize)` + `np.frombuffer` continue to work. + def send_posterior(self, dest, msg): """Send a message containing posterior data""" self.comm.Send( - buf=msg.tobytes(), + buf=[msg, MPI.BYTE], dest=dest, tag=messages.MPIMessageTag.POSTERIOR ) @@ -32,7 +40,7 @@ def send_velocity_position(self, dest, msg): velocity data""" self.comm.Send( - msg.tobytes(), + [msg, MPI.BYTE], dest=dest, tag=messages.MPIMessageTag.VEL_POS ) @@ -42,7 +50,7 @@ def send_dropped_spikes(self, dest, msg): spikes""" self.comm.Send( - msg.tobytes(), + [msg, MPI.BYTE], dest=dest, tag=messages.MPIMessageTag.DROPPED_SPIKES ) @@ -461,6 +469,22 @@ def __init__( self._init_timings() self._set_up_trodes() + # Pre-allocated scratch buffers reused on every LFP tick (~167 Hz at + # 30 kHz spike clock with a 180-sample bin). Allocating these per + # tick generates ~500+ short-lived numpy objects/sec, which is the + # kind of churn that triggers gen-2 GC pauses long enough to show up + # as the 100 ms tail-latency spikes the lab has seen. Reusing them + # in-place keeps the steady-state allocation rate near zero. 
+ self._enc_cred_intervals = np.zeros( + self.p['cred_int_bufsize'], dtype=int + ) + self._enc_argmaxes = np.zeros( + self.p['cred_int_bufsize'], dtype=int + ) + self._spike_mask = np.zeros( + self._spike_buf.shape[0], dtype=bool + ) + def next_iter(self): """Run one iteration processing any available neural data""" @@ -808,32 +832,35 @@ def _process_lfp_timestamp(self, timestamp): """Process a new LFP timestamp by triggering an updated estimate of the posterior""" - # these are default values. if there are relevant spikes - # in the time bin of interest, these will be populated - # accordingly - enc_cred_intervals = np.zeros(self.p['cred_int_bufsize'], dtype=int) - enc_argmaxes = np.zeros(self.p['cred_int_bufsize'], dtype=int) + # Reuse persistent scratch buffers (allocated once in __init__) + # rather than allocating per-tick. Zeroing in place is ~free; the + # numpy/gc overhead of `np.zeros(N)` at ~167 Hz is not. + self._enc_cred_intervals.fill(0) + self._enc_argmaxes.fill(0) lb = int(timestamp - self.p['tbin_delay_samples'] - self.p['tbin_samples']) ub = int(timestamp - self.p['tbin_delay_samples']) + # Compute the bin-membership mask in place into a preallocated + # bool array. `out=` avoids allocating a fresh result array; the + # two comparison operands are still per-tick temporaries. spikes_in_bin_mask = np.logical_and( self._spike_buf[:, 0] >= lb, - self._spike_buf[:, 0] < ub + self._spike_buf[:, 0] < ub, + out=self._spike_mask, ) - if np.sum(spikes_in_bin_mask) > 0: + if np.any(spikes_in_bin_mask): # these spikes are being used. mark them with a 1 self._spike_buf[spikes_in_bin_mask, 4] = 1 - spikes_before = np.atleast_2d( - self._spike_buf[spikes_in_bin_mask] - ) + # Boolean indexing a 2D array with a 1D bool mask already + # returns 2D, so np.atleast_2d here was a redundant no-op call + # on every tick. Drop it. 
+ spikes_before = self._spike_buf[spikes_in_bin_mask] unique_inds = self._get_unique(spikes_before[:, 0]) #NOTE(DS): to get rid of duplicated spikes - spikes_after = np.atleast_2d( - spikes_before[unique_inds] - ) + spikes_after = spikes_before[unique_inds] num_before = len(spikes_before) num_after = len(spikes_after) @@ -848,10 +875,13 @@ def _process_lfp_timestamp(self, timestamp): # main process will check for non-nan elements order = np.argsort(spikes_after[:, 0]) ordered_spikes = spikes_after[order] + cred_int_max = self.p['cred_int_max'] + cred_int_bufsize = self.p['cred_int_bufsize'] for ii, data in enumerate(ordered_spikes): - if data[3] <= self.p['cred_int_max']: - enc_cred_intervals[ii % self.p['cred_int_bufsize']] = data[1] - enc_argmaxes[ii % self.p['cred_int_bufsize']] = np.argmax(data[5:]) + if data[3] <= cred_int_max: + slot = ii % cred_int_bufsize + self._enc_cred_intervals[slot] = data[1] + self._enc_argmaxes[slot] = np.argmax(data[5:]) # Note: the decoder can automatically handle the no-spike case spikes_in_bin_count = num_after @@ -867,7 +897,7 @@ def _process_lfp_timestamp(self, timestamp): spikes_in_bin_count = 0 t0 = time.time_ns() posterior, likelihood = self._decoder.compute_posterior( - np.atleast_2d(self._spike_buf[spikes_in_bin_mask]) + self._spike_buf[spikes_in_bin_mask] ) t1 = time.time_ns() self._time_posterior(lb, ub, t0, t1) @@ -890,8 +920,8 @@ def _process_lfp_timestamp(self, timestamp): self._posterior_msg[0]['velocity'] = self._current_vel self._posterior_msg[0]['cred_int_post'] = cred_int_post self._posterior_msg[0]['cred_int_lk'] = cred_int_lk - self._posterior_msg[0]['enc_cred_intervals'] = enc_cred_intervals - self._posterior_msg[0]['enc_argmaxes'] = enc_argmaxes + self._posterior_msg[0]['enc_cred_intervals'] = self._enc_cred_intervals + self._posterior_msg[0]['enc_argmaxes'] = self._enc_argmaxes self._posterior_msg[0]['spike_count'] = spikes_in_bin_count self.send_interface.send_posterior( 
self._config['rank']['supervisor'][0], self._posterior_msg