Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions TPTBox/core/bids_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,15 +1750,15 @@ def filter(
"""
if self._flatten:
assert isinstance(self.candidates, list)
for bids_file in self.candidates.copy():
if not bids_file.do_filter(key, filter_fun, required=required):
self.candidates.remove(bids_file)
# list comprehension is O(n); the old copy()+list.remove() loop was O(n^2)
self.candidates = [f for f in self.candidates if f.do_filter(key, filter_fun, required=required)]
else:
assert isinstance(self.candidates, dict)
for sequences, bids_files in self.candidates.copy().items():
# print(sequences, list(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files))
if not any(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files):
self.candidates.pop(sequences)
self.candidates = {
seq: bids_files
for seq, bids_files in self.candidates.items()
if any(f.do_filter(key, filter_fun, required=required) for f in bids_files)
}

def filter_format(self, filter_fun: list[str] | str | typing.Callable[[str | object], bool]) -> None:
"""Keep only files whose format label satisfies *filter_fun*.
Expand Down Expand Up @@ -1807,15 +1807,15 @@ def filter_non_existence(
"""
if self._flatten:
assert isinstance(self.candidates, list)
for bids_file in self.candidates.copy():
if bids_file.do_filter(key, filter_fun, required=required):
self.candidates.remove(bids_file)
# list comprehension is O(n); the old copy()+list.remove() loop was O(n^2)
self.candidates = [f for f in self.candidates if not f.do_filter(key, filter_fun, required=required)]
else:
assert isinstance(self.candidates, dict)
for sequences, bids_files in self.candidates.copy().items():
# print(sequences, list(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files))
if any(bids_file.do_filter(key, filter_fun, required=required) for bids_file in bids_files):
self.candidates.pop(sequences)
self.candidates = {
seq: bids_files
for seq, bids_files in self.candidates.items()
if not any(f.do_filter(key, filter_fun, required=required) for f in bids_files)
}

def filter_dixon_only_inphase(self) -> None:
"""Remove Dixon files that are fat, water, out-of-phase, or difference images.
Expand Down
9 changes: 9 additions & 0 deletions TPTBox/core/nii_poi_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,15 @@ def global_to_local(self, x: COORDINATE) -> tuple:
a = self.rotation.T @ (np.array(x) - self.origin) / np.array(self.zoom)
return tuple(round(float(v), 7) for v in a)

def global_to_local_arr(self, coords: np.ndarray) -> np.ndarray:
"""Vectorized :meth:`global_to_local` for an ``(N, 3)`` array of world coordinates.

Equivalent to applying ``global_to_local`` to each row but in a single batched
inverse-affine matmul.
"""
a = (np.asarray(coords, dtype=float) - np.asarray(self.origin)) @ np.asarray(self.rotation) / np.asarray(self.zoom)
return np.round(a, 7)

def local_to_global(self, x: COORDINATE) -> tuple:
"""Convert voxel (local) coordinates to world (RAS/LPS) coordinates.

Expand Down
53 changes: 31 additions & 22 deletions TPTBox/core/nii_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
np_filter_connected_components,
np_get_connected_components_center_of_mass,
np_is_empty,
np_isin,
np_map_labels,
np_map_labels_based_on_majority_label_mask_overlap,
np_point_coordinates,
Expand Down Expand Up @@ -2092,10 +2093,11 @@ def truncate_labels_beyond_reference_(
flip = self.orientation[axis_] != axis # Check orientation for flipping
# Get the array data
np_array = self.get_array()
np_array_cond = self.extract_label(idx).get_seg_array()
# both masks come directly from np_array via np_isin (avoids two extract_label round-trips)
np_array_cond = np_isin(np_array, idx)

# Find the lowest point (smallest index) along the axis where `not_above` exists
threshold = np.where(self.extract_label(not_beyond).get_seg_array() == 1)
threshold = np.where(np_isin(np_array, not_beyond))
if len(threshold[axis_]) == 0:
return self if inplace else self.copy()
flip_up = flip
Expand All @@ -2115,7 +2117,7 @@ def truncate_labels_beyond_reference_(
mask = np.broadcast_to(mask, self.shape)

# Replace values of `idx` with `fill` in the masked region
np_array = np.where((np_array_cond == 1) & mask, fill, np_array)
np_array = np.where(np_array_cond & mask, fill, np_array)

# Update the NIfTI object with the modified array
return self.set_array(np_array, inplace=inplace)
Expand Down Expand Up @@ -2253,7 +2255,8 @@ def map_labels(self, label_map:LABEL_MAP , verbose:logging=True, inplace=False)
If inplace is True, returns the current NIfTI image object with mapped labels. Otherwise, returns a new NIfTI image object with mapped labels.
"""
data_orig = self.get_seg_array()
labels_before = [v for v in np_unique(data_orig) if v > 0]
# the before/after np_unique scans are only used for the verbose log line; skip them otherwise
labels_before = [v for v in np_unique(data_orig) if v > 0] if verbose else None
# enforce keys to be str to support both str and int
label_map_ = {
(v_name2idx[k] if k in v_name2idx else int(k)): (
Expand All @@ -2263,15 +2266,16 @@ def map_labels(self, label_map:LABEL_MAP , verbose:logging=True, inplace=False)
}
log.print("label_map_ =", label_map_, verbose=verbose)
data = np_map_labels(data_orig, label_map_)
labels_after = [v for v in np_unique(data) if v > 0]
log.print(
"N =",
len(label_map_),
"labels reassigned, before labels: ",
labels_before,
" after: ",
labels_after,verbose=verbose
)
if verbose:
labels_after = [v for v in np_unique(data) if v > 0]
log.print(
"N =",
len(label_map_),
"labels reassigned, before labels: ",
labels_before,
" after: ",
labels_after,verbose=verbose
)
nii = data.astype(np.uint16), self.affine, self.header
if inplace:
self.nii = nii
Expand Down Expand Up @@ -2685,19 +2689,22 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_la
seg_arr = self.get_seg_array()

if isinstance(label, Sequence):
label_int:list[int] = [idx.value if isinstance(idx,Enum) else idx for idx in label]
assert 0 not in label_int, 'Zero label does not make sense. This is the background'
seg_arr = np_extract_label(seg_arr, label_int, to_label=1, inplace=True)
labels:int|list[int] = [idx.value if isinstance(idx,Enum) else idx for idx in label]
assert 0 not in labels, 'Zero label does not make sense. This is the background'
else:
if isinstance(label,Enum):
label = label.value
if isinstance(label,str):
label = int(label)

assert label != 0, 'Zero label does not make sense. This is the background'
seg_arr = np_extract_label(seg_arr, label, to_label=1, inplace=True)
labels = label
if keep_label:
seg_arr = seg_arr * self.get_seg_array()
# keep the original label values where in `labels`, zero everywhere else.
# single get_seg_array() copy + one np_isin mask (faster than extract + a second copy/multiply)
seg_arr[~np_isin(seg_arr, labels)] = 0
else:
seg_arr = np_extract_label(seg_arr, labels, to_label=1, inplace=True)
return self.set_array(seg_arr,inplace=inplace)
def ravel(self,order:Literal["K", "A", "C", "F"] | None="C")->np.ndarray:
"""Return a contiguous flattened array.
Expand All @@ -2719,15 +2726,17 @@ def extract_label_(self, label: int | Enum | Sequence[int] | Sequence[Enum], kee
def remove_labels(self,label:int|Enum|Sequence[int]|Sequence[Enum], inplace=False, verbose:logging=True, removed_to_label=0) -> Self:
"""If this NII is a segmentation you can single out one label."""
assert label != 0, 'Zero label does not make sens. This is the background'
seg_arr = self.get_seg_array()
if not isinstance(label,Sequence):
label = [label] # type: ignore
flat: list[int] = []
for l in label:
if isinstance(l, list):
for g in l:
seg_arr[seg_arr == g] = removed_to_label
flat.extend(g.value if isinstance(g, Enum) else g for g in l)
else:
seg_arr[seg_arr == l] = removed_to_label
flat.append(l.value if isinstance(l, Enum) else l)
# one np_map_labels gather is constant-time in the number of labels (a per-label
# `seg_arr == l` loop costs one full pass per label).
seg_arr = np_map_labels(self.get_seg_array(), dict.fromkeys(flat, removed_to_label))
return self.set_array(seg_arr,inplace=inplace, verbose=verbose)
def remove_labels_(self, label: int | Enum | Sequence[int] | Sequence[Enum], removed_to_label=0, verbose: logging = True) -> Self:
"""In-place variant of `remove_labels`."""
Expand Down
84 changes: 63 additions & 21 deletions TPTBox/core/np_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,40 @@
INTARRAY = Union[UINTARRAY, NDArray[INT]]


def np_isin(arr: np.ndarray, labels, invert: bool = False) -> np.ndarray:
"""Fast ``np.isin`` for non-negative integer label arrays via a boolean lookup table.

For unsigned-integer segmentation masks this is ~3-6x faster than ``np.isin`` when testing
membership in more than one label, because it replaces the general algorithm with a single
``lut[arr]`` gather. Falls back to ``np.isin`` for non-unsigned dtypes, negative labels, or
very large label ranges; uses ``arr == label`` for the single-label case.

Args:
arr (np.ndarray): Input array.
labels: A label or iterable of labels to test membership against.
invert (bool, optional): If True, return the complement (equivalent to
``np.isin(arr, labels, invert=True)``). Defaults to False.

Returns:
np.ndarray: Boolean mask, same shape as ``arr``.
"""
if not isinstance(labels, (list, tuple, np.ndarray)):
labels = [labels]
if len(labels) == 0:
return np.ones(arr.shape, dtype=bool) if invert else np.zeros(arr.shape, dtype=bool)
if len(labels) == 1:
res = arr == labels[0]
return ~res if invert else res
if np.issubdtype(arr.dtype, np.unsignedinteger) and min(int(x) for x in labels) >= 0:
m = max(int(arr.max()), int(max(labels))) + 1
if m < 2**20: # keep the lookup table small (same threshold as np_unique's bincount path)
lut = np.zeros(m, dtype=bool)
lut[np.asarray(labels)] = True
res = lut[arr]
return ~res if invert else res
return np.isin(arr, labels, invert=invert)


def np_extract_label(
arr: np.ndarray,
label: int | list[int],
Expand Down Expand Up @@ -69,7 +103,7 @@ def np_extract_label(

if isinstance(label, list):
assert 0 not in label, "label 0 is not supported in list mode"
arr_msk = np.isin(arr, label)
arr_msk = np_isin(arr, label)
arr[arr_msk] = to_label
arr[~arr_msk] = 0
return arr
Expand Down Expand Up @@ -125,10 +159,12 @@ def np_volume(arr: UINTARRAY, include_zero: bool = False) -> dict[int, int]:
Returns:
dict[int, int]: Mapping from label value to number of voxels with that label.
"""
# np.bincount wins decisively when there are many labels (e.g. connected-component maps);
# cc3d statistics is faster for the few-label case typical of anatomical segmentations.
counts = np.bincount(arr.ravel()) if int(arr.max()) > 256 else cc3dstatistics(arr, use_crop=not include_zero)["voxel_counts"]
if include_zero:
return {idx: i for idx, i in dict(enumerate(cc3dstatistics(arr, use_crop=False)["voxel_counts"])).items() if i > 0}
else:
return {idx: i for idx, i in dict(enumerate(cc3dstatistics(arr)["voxel_counts"])).items() if i > 0 and idx != 0}
return {idx: i for idx, i in enumerate(counts) if i > 0}
return {idx: i for idx, i in enumerate(counts) if i > 0 and idx != 0}
Comment on lines +162 to +167


def np_is_empty(arr: UINTARRAY | INTARRAY) -> bool:
Expand Down Expand Up @@ -253,8 +289,8 @@ def np_center_of_mass(arr: UINTARRAY) -> dict[int, COORDINATE]:
"""
stats = cc3dstatistics(arr, use_crop=False)
# Does not use the other calls for speed reasons
unique = [idx for idx, i in enumerate(stats["voxel_counts"]) if i > 0 and idx != 0]
return {idx: v for idx, v in enumerate(stats["centroids"]) if idx in unique}
vc = stats["voxel_counts"]
return {idx: v for idx, v in enumerate(stats["centroids"]) if idx != 0 and vc[idx] > 0}


def np_bounding_boxes(arr: UINTARRAY) -> dict[int, tuple[slice, slice, slice]]:
Expand All @@ -270,8 +306,8 @@ def np_bounding_boxes(arr: UINTARRAY) -> dict[int, tuple[slice, slice, slice]]:
"""
stats = cc3dstatistics(arr)
# Does not use the other calls for speed reasons
unique = [idx for idx, i in enumerate(stats["voxel_counts"]) if i > 0 and idx != 0]
return {idx: v for idx, v in enumerate(stats["bounding_boxes"]) if idx in unique}
vc = stats["voxel_counts"]
return {idx: v for idx, v in enumerate(stats["bounding_boxes"]) if idx != 0 and vc[idx] > 0}


def np_contacts(arr: UINTARRAY, connectivity: int) -> dict[tuple[int, int], int]:
Expand Down Expand Up @@ -383,14 +419,14 @@ def np_erode_msk_euclid(arr: np.ndarray, n_pixel: int = 3, use_crop=True, labels
if use_crop:
arr_bin = arr.copy()
if labels is not None:
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
arr_bin[np_isin(arr_bin, labels, invert=True)] = 0
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
arrc = arr[crop]
else:
arrc = arr
if labels is not None:
arrc = arrc.copy()
arrc[np.isin(arrc, labels, invert=True)] = 0
arrc[np_isin(arrc, labels, invert=True)] = 0

if mask is not None:
mask = mask.copy()
Expand Down Expand Up @@ -429,14 +465,14 @@ def np_dilate_msk_euclid(arr: np.ndarray, n_pixel: int = 3, use_crop=True, label
if use_crop:
arr_bin = arr.copy()
if labels is not None:
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
arr_bin[np_isin(arr_bin, labels, invert=True)] = 0
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
arrc = arr[crop]
else:
arrc = arr
if labels is not None:
arrc = arrc.copy()
arrc[np.isin(arr_bin, labels, invert=True)] = 0
arrc[np_isin(arr_bin, labels, invert=True)] = 0
Comment on lines 471 to +475
if mask is not None:
mask[mask != 0] = 1
if use_crop:
Expand Down Expand Up @@ -500,7 +536,7 @@ def np_dilate_msk(
if use_crop:
# try:
arr_bin = arr.copy()
arr_bin[np.isin(arr_bin, labels, invert=True)] = 0
arr_bin[np_isin(arr_bin, labels, invert=True)] = 0
crop = np_bbox_binary(arr_bin, px_dist=1 + n_pixel, raise_error=False)
arrc = arr[crop]
else:
Expand All @@ -521,8 +557,7 @@ def np_dilate_msk(
out = arrc
for _ in range(n_pixel):
for i in labels:
data = out.copy()
data[i != data] = 0
data = out == i # boolean mask; _binary_dilation casts to bool anyway, so this is exact and avoids a full copy
if use_crop:
lcrop = np_bbox_binary(data, px_dist=2 + n_pixel, raise_error=False)
data = data[lcrop]
Expand Down Expand Up @@ -575,7 +610,7 @@ def np_erode_msk(
labels: list[int] = _to_labels(arr, label_ref)

if use_crop:
crop = np_bbox_binary(np.isin(arr, labels, invert=False), px_dist=1 + n_pixel, raise_error=False)
crop = np_bbox_binary(np_isin(arr, labels, invert=False), px_dist=1 + n_pixel, raise_error=False)
arrc = arr[crop]
else:
arrc = arr
Expand Down Expand Up @@ -703,9 +738,16 @@ def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray =
assert len(px_dist) == n, f"dimension mismatch, got img shape {shp} and px_dist {px_dist}"

bbox: list[float] = []
for ax in itertools.combinations(reversed(range(n)), n - 1):
nonzero = np.any(a=img, axis=ax)
bbox.extend(np.where(nonzero)[0][[0, -1]]) # type: ignore
if n == 3:
# 2 full passes instead of 3: two axis extents come from a shared 2D projection (cheap),
# only the third axis needs a second full reduction.
p = np.any(img, axis=2)
for nonzero in (np.any(p, axis=1), np.any(p, axis=0), np.any(img, axis=(0, 1))):
bbox.extend(np.where(nonzero)[0][[0, -1]]) # type: ignore
else:
for ax in itertools.combinations(reversed(range(n)), n - 1):
nonzero = np.any(a=img, axis=ax)
bbox.extend(np.where(nonzero)[0][[0, -1]]) # type: ignore
out: tuple[slice, ...] = tuple(
slice(
max(bbox[i] - px_dist[i // 2], 0),
Expand Down Expand Up @@ -867,7 +909,7 @@ def np_connected_components(
labels: Sequence[int] = _to_labels(arr, label_ref)
if include_zero:
arr[arr == 0] = arr.max() + 1
arr[np.isin(arr, labels, invert=True)] = 0
arr[np_isin(arr, labels, invert=True)] = 0
cc_map, n = _connected_components(arr, connectivity=connectivity, return_N=True)
return cc_map, n

Expand Down Expand Up @@ -952,7 +994,7 @@ def np_filter_connected_components(

arr2 = arr.copy()
labels: Sequence[int] = _to_labels(arr, label_ref)
arr2[np.isin(arr2, labels, invert=True)] = 0 # type:ignore
arr2[np_isin(arr2, labels, invert=True)] = 0 # type:ignore

labels_out, n = _connected_components(arr2, connectivity=connectivity, return_N=True)
largest_k_components_org = largest_k_components
Expand Down
Loading
Loading