Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions realtime_decoder/decoder_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import glob
import numpy as np

from mpi4py import MPI

from realtime_decoder import (
base, utils, position, messages, transitions,
binary_record, taskstate
Expand All @@ -18,11 +20,17 @@ class DecoderMPISendInterface(base.StandardMPISendInterface):
def __init__(self, comm, rank, config):
super().__init__(comm, rank, config)

# NOTE: each send path used to call msg.tobytes() per tick, which
# allocates a fresh bytes object every time. Sending [msg, MPI.BYTE]
# hands MPI the numpy buffer directly, so no intermediate copy is made.
# This assumes msg is a contiguous numpy array (TODO confirm for every
# caller). The wire format is unchanged (raw bytes), so receivers built
# around `bytearray(... .itemsize)` + `np.frombuffer` continue to work.

def send_posterior(self, dest, msg):
"""Send a message containing posterior data"""

self.comm.Send(
buf=msg.tobytes(),
buf=[msg, MPI.BYTE],
dest=dest,
tag=messages.MPIMessageTag.POSTERIOR
)
Expand All @@ -32,7 +40,7 @@ def send_velocity_position(self, dest, msg):
velocity data"""

self.comm.Send(
msg.tobytes(),
[msg, MPI.BYTE],
dest=dest,
tag=messages.MPIMessageTag.VEL_POS
)
Expand All @@ -42,7 +50,7 @@ def send_dropped_spikes(self, dest, msg):
spikes"""

self.comm.Send(
msg.tobytes(),
[msg, MPI.BYTE],
dest=dest,
tag=messages.MPIMessageTag.DROPPED_SPIKES
)
Expand Down Expand Up @@ -461,6 +469,22 @@ def __init__(
self._init_timings()
self._set_up_trodes()

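# Pre-allocated scratch buffers reused on every LFP tick (~167 Hz at a
# 30 kHz spike clock with a 180-sample bin). Allocating these per tick
# creates ~500 short-lived numpy arrays/sec; reusing them in place keeps
# the steady-state allocation rate near zero.
# NOTE(review): the link between that churn and the ~100 ms tail-latency
# spikes the lab has seen (attributed here to gen-2 GC pauses) is a
# hypothesis -- plain ndarrays are, as far as I know, not tracked by the
# cyclic GC. Confirm with a profile before relying on this explanation.
# NOTE(review): _spike_mask is sized from _spike_buf as it exists at this
# point; if _spike_buf is ever resized, the mask must be reallocated to
# match -- TODO confirm against the buffer-growth code.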
self._enc_cred_intervals = np.zeros(
self.p['cred_int_bufsize'], dtype=int
)
self._enc_argmaxes = np.zeros(
self.p['cred_int_bufsize'], dtype=int
)
self._spike_mask = np.zeros(
self._spike_buf.shape[0], dtype=bool
)

def next_iter(self):
"""Run one iteration processing any available neural data"""

Expand Down Expand Up @@ -808,32 +832,35 @@ def _process_lfp_timestamp(self, timestamp):
"""Process a new LFP timestamp by triggering an updated
estimate of the posterior"""

# these are default values. if there are relevant spikes
# in the time bin of interest, these will be populated
# accordingly
enc_cred_intervals = np.zeros(self.p['cred_int_bufsize'], dtype=int)
enc_argmaxes = np.zeros(self.p['cred_int_bufsize'], dtype=int)
# Reuse persistent scratch buffers (allocated once in __init__)
# rather than allocating per-tick. Zeroing in place is ~free; the
# numpy/gc overhead of `np.zeros(N)` at ~167 Hz is not.
self._enc_cred_intervals.fill(0)
self._enc_argmaxes.fill(0)

lb = int(timestamp - self.p['tbin_delay_samples'] - self.p['tbin_samples'])
ub = int(timestamp - self.p['tbin_delay_samples'])
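# Write the bin-membership mask into the preallocated bool array.
# `out=` avoids allocating the result array only; the two comparison
# operands below are still fresh temporaries on every call. `out` must
# have the same length as `self._spike_buf[:, 0]`, and the returned
# array is `self._spike_mask` itself, so the mask is only valid until
# the next tick overwrites it.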
spikes_in_bin_mask = np.logical_and(
self._spike_buf[:, 0] >= lb,
self._spike_buf[:, 0] < ub
self._spike_buf[:, 0] < ub,
out=self._spike_mask,
)

if np.sum(spikes_in_bin_mask) > 0:
if np.any(spikes_in_bin_mask):

# these spikes are being used. mark them with a 1
self._spike_buf[spikes_in_bin_mask, 4] = 1

spikes_before = np.atleast_2d(
self._spike_buf[spikes_in_bin_mask]
)
# Boolean indexing a 2D array with a 1D bool mask already returns a
# 2D array, and np.atleast_2d returns such input unchanged, so the
# wrapper was a redundant call on every tick. Drop it.
spikes_before = self._spike_buf[spikes_in_bin_mask]

unique_inds = self._get_unique(spikes_before[:, 0]) #NOTE(DS): to get rid of duplicated spikes
spikes_after = np.atleast_2d(
spikes_before[unique_inds]
)
spikes_after = spikes_before[unique_inds]

num_before = len(spikes_before)
num_after = len(spikes_after)
Expand All @@ -848,10 +875,13 @@ def _process_lfp_timestamp(self, timestamp):
# main process will check for non-nan elements
order = np.argsort(spikes_after[:, 0])
ordered_spikes = spikes_after[order]
cred_int_max = self.p['cred_int_max']
cred_int_bufsize = self.p['cred_int_bufsize']
for ii, data in enumerate(ordered_spikes):
if data[3] <= self.p['cred_int_max']:
enc_cred_intervals[ii % self.p['cred_int_bufsize']] = data[1]
enc_argmaxes[ii % self.p['cred_int_bufsize']] = np.argmax(data[5:])
if data[3] <= cred_int_max:
slot = ii % cred_int_bufsize
self._enc_cred_intervals[slot] = data[1]
self._enc_argmaxes[slot] = np.argmax(data[5:])

# Note: the decoder can automatically handle the no-spike case
spikes_in_bin_count = num_after
Expand All @@ -867,7 +897,7 @@ def _process_lfp_timestamp(self, timestamp):
spikes_in_bin_count = 0
t0 = time.time_ns()
posterior, likelihood = self._decoder.compute_posterior(
np.atleast_2d(self._spike_buf[spikes_in_bin_mask])
self._spike_buf[spikes_in_bin_mask]
)
t1 = time.time_ns()
self._time_posterior(lb, ub, t0, t1)
Expand All @@ -890,8 +920,8 @@ def _process_lfp_timestamp(self, timestamp):
self._posterior_msg[0]['velocity'] = self._current_vel
self._posterior_msg[0]['cred_int_post'] = cred_int_post
self._posterior_msg[0]['cred_int_lk'] = cred_int_lk
self._posterior_msg[0]['enc_cred_intervals'] = enc_cred_intervals
self._posterior_msg[0]['enc_argmaxes'] = enc_argmaxes
self._posterior_msg[0]['enc_cred_intervals'] = self._enc_cred_intervals
self._posterior_msg[0]['enc_argmaxes'] = self._enc_argmaxes
self._posterior_msg[0]['spike_count'] = spikes_in_bin_count
self.send_interface.send_posterior(
self._config['rank']['supervisor'][0], self._posterior_msg
Expand Down