Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions probeflow/gui/dialogs/fft_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def __init__(
self._mains_artists: list = []
self._mains_preview_active = False
self._mains_fast_axis = "x"
# User-placed streak pairs (positive |q| in nm⁻¹) and drag state.
self._mains_custom_q: list = []
self._mains_drag_idx = None
self._mains_tab_index = None
# Inverse FFT / Fourier reconstruction (Reconstruct tab).
self._fft_selection_overlay = None
self._reconstruct_tab_index = -1
Expand Down Expand Up @@ -798,7 +802,8 @@ def _build_fft_column(self) -> QVBoxLayout:

# Append the Mains tab last so existing tab indices (e.g. _grid_tab_index)
# are unaffected.
self._tab_widget.addTab(self._build_mains_tab(), "⚡ Mains")
self._mains_tab_index = self._tab_widget.addTab(
self._build_mains_tab(), "⚡ Mains")
self._reconstruct_tab_index = self._tab_widget.addTab(
self._build_reconstruct_tab(), "Inverse FFT")

Expand Down Expand Up @@ -1408,6 +1413,9 @@ def _on_press(self, event):
and event.button == 1 and self._fft_selection_overlay is not None):
if self._fft_selection_overlay.on_press(event):
return
# On the Mains tab, grabbing a custom streak line beats panning.
if self._mains_handle_press(event):
return
if (
event.inaxes is self._ax_fft
and event.button == 1
Expand All @@ -1422,10 +1430,13 @@ def _on_press(self, event):

def _on_release(self, event):
self._pan_anchor = None
self._mains_handle_release(event)
if self._fft_selection_overlay is not None:
self._fft_selection_overlay.on_release(event)

def _on_motion(self, event):
if self._mains_handle_motion(event):
return
if (self._fft_selection_overlay is not None
and self._fft_selection_overlay.is_dragging()):
self._fft_selection_overlay.on_motion(event)
Expand Down Expand Up @@ -1624,10 +1635,26 @@ def _on_fft_hist_range_released(self, lo_phys: float, hi_phys: float) -> None:
return
self._fft_drs.set_manual(lo_phys, hi_phys)

# Auto-contrast presets for the (log-scaled) FFT display. A single
# idempotent reset to 0–100 % meant repeated Auto clicks visibly did
# nothing; cycling through progressively tighter percentile windows
# gives each click an effect and returns to the full range.
_FFT_AUTO_PRESETS = (
(0.0, 100.0, "full range"),
(1.0, 99.5, "1–99.5 %"),
(5.0, 98.0, "5–98 %"),
)

def _reset_intensity(self) -> None:
if not self._fft_histogram_is_adjustable():
return
self._fft_drs.reset(0.0, 100.0)
idx = (getattr(self, "_fft_auto_idx", -1) + 1) % len(self._FFT_AUTO_PRESETS)
self._fft_auto_idx = idx
lo, hi, label = self._FFT_AUTO_PRESETS[idx]
self._fft_drs.reset(lo, hi)
status = getattr(self, "_mains_status_lbl", None)
if self._mains_tab_active() and status is not None:
status.setText(f"FFT auto contrast: {label}.")

def _update_info_panel(self):
Ny, Nx = self._arr.shape
Expand Down
207 changes: 190 additions & 17 deletions probeflow/gui/dialogs/fft_viewer_mains_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ def _build_mains_tab(self) -> QWidget:
self._mains_radius_spin.setValue(3)
self._mains_radius_spin.setSuffix(" px")
self._mains_radius_spin.setToolTip(_tip(
"Width of each mains streak notch in FFT pixels."))
"Width of each mains streak notch in FFT pixels. The shaded band "
"around each overlay line shows this width on the FFT."))
# The overlay band must track the radius live, or the control reads
# as doing nothing.
self._mains_radius_spin.valueChanged.connect(self._on_mains_changed)
self._mains_radius_spin.setMaximumWidth(_FIELD_W)

self._mains_min_q_spin = QDoubleSpinBox()
Expand All @@ -162,12 +166,41 @@ def _build_mains_tab(self) -> QWidget:
sl.addWidget(self._mains_min_q_spin, 0, 3)
sl.setColumnStretch(4, 1)

self._mains_fill_cb = QCheckBox("Fill notches to background")
self._mains_fill_cb.setToolTip(_tip(
"Instead of zeroing the notched FFT bins (which leaves black "
"streak-shaped gaps and removes the genuine noise floor too), "
"bring them down to the local background magnitude — only the "
"excess pickup energy is removed and the FFT shows background "
"where the streak was."))
sl.addWidget(self._mains_fill_cb, 1, 0, 1, 5)

# ── custom streak pairs (pickup at non-mains frequencies) ───────────────
custom_row = QHBoxLayout()
self._mains_add_streak_btn = QPushButton("Add streak pair")
self._mains_add_streak_btn.setToolTip(_tip(
"Add a symmetric ±q pair of notch lines for a vertical streak "
"that is not at a mains harmonic. Drag either line on the FFT "
"to put it on the streak; the partner mirrors automatically."))
self._mains_add_streak_btn.clicked.connect(self._on_mains_add_streak)
self._mains_remove_streak_btn = QPushButton("Remove pair")
self._mains_remove_streak_btn.setEnabled(False)
self._mains_remove_streak_btn.setToolTip(_tip(
"Remove the most recently added/dragged custom streak pair."))
self._mains_remove_streak_btn.clicked.connect(self._on_mains_remove_streak)
for b in (self._mains_add_streak_btn, self._mains_remove_streak_btn):
b.setMaximumWidth(130)
custom_row.addWidget(self._mains_add_streak_btn)
custom_row.addWidget(self._mains_remove_streak_btn)
custom_row.addStretch(1)
sl.addLayout(custom_row, 2, 0, 1, 5)

self._mains_residual_cb = QCheckBox("Show residual (removed signal)")
self._mains_residual_cb.setToolTip(_tip(
"Preview the residual (original − filtered) instead of the filtered "
"image. Check it looks like noise/stripes, not real features, "
"before applying."))
sl.addWidget(self._mains_residual_cb, 1, 0, 1, 5)
sl.addWidget(self._mains_residual_cb, 3, 0, 1, 5)

btn_row = QHBoxLayout()
self._mains_preview_btn = QPushButton("Preview")
Expand All @@ -192,7 +225,7 @@ def _build_mains_tab(self) -> QWidget:
btn_row.addWidget(self._mains_clear_btn)
btn_row.addStretch(1)
btn_row.addWidget(self._mains_apply_btn)
sl.addLayout(btn_row, 2, 0, 1, 5)
sl.addLayout(btn_row, 4, 0, 1, 5)
lay.addWidget(sgrp)

lay.addStretch(1)
Expand Down Expand Up @@ -249,9 +282,24 @@ def _axis_segments_with_radial_floor(
segments.append((max(lo, gap), hi))
return [(a, b) for a, b in segments if b > a]

def _mains_q_per_px(self) -> float:
"""FFT-pixel spacing of the fast axis in nm⁻¹ (0.0 when unknown)."""
axis = self._qx if self._mains_fast_axis == "x" else self._qy
if axis is None or len(axis) < 2:
return 0.0
return float(abs(axis[1] - axis[0]))

def _mains_custom_streaks(self) -> list[float]:
"""Positive-|q| positions (nm⁻¹) of the user-placed streak pairs."""
if not hasattr(self, "_mains_custom_q"):
self._mains_custom_q = []
return self._mains_custom_q

def _draw_mains_overlay(self) -> None:
"""Vertical (or horizontal) lines at the predicted mains q positions.
"""Lines at the predicted mains q positions plus user streak pairs.

Each line carries a translucent band of half-width = notch radius so
the removal width is visible and tracks the radius control.
Rebuilt on every FFT redraw (the axes are cleared by ``ax.cla()``).
"""
self._mains_artists = []
Expand All @@ -260,23 +308,43 @@ def _draw_mains_overlay(self) -> None:
if not self._mains_overlay_cb.isChecked() or self._qx is None or self._qy is None:
return
min_q = self._mains_min_q_nm_inv()
for p in self._mains_predictions():
q = p["q_nm_inv"]
half_w = float(self._mains_radius_spin.value()) * self._mains_q_per_px()

def _draw_pair(q: float, color: str, style: str) -> None:
for qq in (q, -q):
if self._mains_fast_axis == "x":
for y0, y1 in self._axis_segments_with_radial_floor(qq, self._qy, min_q):
art, = self._ax_fft.plot(
[qq, qq], [y0, y1], color="#f9e2af", lw=0.9,
ls="--", alpha=0.85, zorder=7,
[qq, qq], [y0, y1], color=color, lw=0.9,
ls=style, alpha=0.85, zorder=7,
)
self._mains_artists.append(art)
if half_w > 0:
band = self._ax_fft.fill_betweenx(
[y0, y1], qq - half_w, qq + half_w,
color=color, alpha=0.16, lw=0, zorder=6,
)
self._mains_artists.append(band)
else:
for x0, x1 in self._axis_segments_with_radial_floor(qq, self._qx, min_q):
art, = self._ax_fft.plot(
[x0, x1], [qq, qq], color="#f9e2af", lw=0.9,
ls="--", alpha=0.85, zorder=7,
[x0, x1], [qq, qq], color=color, lw=0.9,
ls=style, alpha=0.85, zorder=7,
)
self._mains_artists.append(art)
if half_w > 0:
band = self._ax_fft.fill_between(
[x0, x1], qq - half_w, qq + half_w,
color=color, alpha=0.16, lw=0, zorder=6,
)
self._mains_artists.append(band)

for p in self._mains_predictions():
_draw_pair(p["q_nm_inv"], "#f9e2af", "--")
# User streak pairs: solid cyan so they read as hand-placed, not
# predicted; draggable on the FFT while the Mains tab is active.
for q in self._mains_custom_streaks():
_draw_pair(q, "#89dceb", "-")

def _on_mains_changed(self) -> None:
"""Fast path: refresh the overlay + status when a control changes."""
Expand All @@ -297,14 +365,20 @@ def _on_mains_changed(self) -> None:
def _update_mains_status(self) -> None:
if not getattr(self, "_mains_status_lbl", None):
return
n_custom = len(self._mains_custom_streaks())
custom_note = (
f" +{n_custom} custom streak pair(s)." if n_custom else ""
)
if self._mains_speed_m_per_s() is None:
self._mains_status_lbl.setText(
"Scan speed unavailable; enter nm/s to show the mains overlay.")
"Scan speed unavailable; enter nm/s to show the mains overlay."
+ custom_note)
return
preds = self._mains_predictions()
if not preds:
self._mains_status_lbl.setText(
"No mains harmonics fall within this FFT (check speed/frequency).")
"No mains harmonics fall within this FFT (check speed/frequency)."
+ custom_note)
return
src = "ROI" if self._fft_source == "active_roi" else "whole image"
if self._mains_harmonics() is None and len(preds) > 4:
Expand All @@ -319,6 +393,95 @@ def _update_mains_status(self) -> None:
floor = f" |q|≥{min_q:.2f} nm⁻¹." if min_q > 0 else ""
self._mains_status_lbl.setText(f"FFT source: {src}. " + " · ".join(parts) + floor)

# ── Custom streak pairs ───────────────────────────────────────────────────

def _mains_extra_streaks_px(self) -> list[int]:
dq = self._mains_q_per_px()
if dq <= 0:
return []
seen: set[int] = set()
out: list[int] = []
for q in self._mains_custom_streaks():
px = int(round(abs(float(q)) / dq))
if px > 0 and px not in seen:
seen.add(px)
out.append(px)
return out

def _on_mains_add_streak(self) -> None:
axis = self._qx if self._mains_fast_axis == "x" else self._qy
if axis is None or len(axis) < 2:
self._mains_status_lbl.setText("Open an FFT first.")
return
q_max = float(np.nanmax(np.abs(axis)))
# Stagger new pairs so consecutive adds don't stack invisibly.
n = len(self._mains_custom_streaks())
q_new = q_max * (0.35 + 0.12 * (n % 5))
self._mains_custom_streaks().append(q_new)
self._mains_remove_streak_btn.setEnabled(True)
self._on_mains_changed()
self._mains_status_lbl.setText(
f"Streak pair added at ±{q_new:.2f} nm⁻¹ — drag either line onto "
"the streak."
)

def _on_mains_remove_streak(self) -> None:
streaks = self._mains_custom_streaks()
if streaks:
streaks.pop()
self._mains_remove_streak_btn.setEnabled(bool(streaks))
self._on_mains_changed()

def _mains_tab_active(self) -> bool:
idx = getattr(self, "_mains_tab_index", None)
tabs = getattr(self, "_tab_widget", None)
return idx is not None and tabs is not None and tabs.currentIndex() == idx

def _mains_handle_press(self, event) -> bool:
"""Begin dragging the custom streak pair nearest the click, if close."""
self._mains_drag_idx = getattr(self, "_mains_drag_idx", None)
if (not self._mains_tab_active() or event.inaxes is not self._ax_fft
or event.button != 1 or not self._mains_custom_streaks()):
return False
coord = event.xdata if self._mains_fast_axis == "x" else event.ydata
if coord is None:
return False
axis = self._qx if self._mains_fast_axis == "x" else self._qy
tol = max(3.0 * self._mains_q_per_px(),
0.02 * float(np.nanmax(np.abs(axis))))
best_idx, best_dist = None, tol
for i, q in enumerate(self._mains_custom_streaks()):
dist = abs(abs(float(coord)) - q)
if dist <= best_dist:
best_idx, best_dist = i, dist
if best_idx is None:
return False
self._mains_drag_idx = best_idx
return True

def _mains_handle_motion(self, event) -> bool:
idx = getattr(self, "_mains_drag_idx", None)
if idx is None:
return False
if event.inaxes is not self._ax_fft:
return True # keep the drag captured while skimming the edge
coord = event.xdata if self._mains_fast_axis == "x" else event.ydata
if coord is None:
return True
dq = self._mains_q_per_px()
streaks = self._mains_custom_streaks()
if 0 <= idx < len(streaks):
streaks[idx] = max(abs(float(coord)), dq)
self._on_mains_changed()
return True

def _mains_handle_release(self, _event) -> bool:
if getattr(self, "_mains_drag_idx", None) is None:
return False
self._mains_drag_idx = None
self._update_mains_status()
return True

def _mains_op_params(self) -> dict:
params = {
"scan_speed_m_per_s": self._mains_speed_m_per_s(),
Expand All @@ -331,6 +494,12 @@ def _mains_op_params(self) -> dict:
"snap_window_px": 2,
"notch_shape": "streak",
"min_q_nm_inv": self._mains_min_q_nm_inv(),
"extra_streaks_px": self._mains_extra_streaks_px(),
"notch_fill": (
"background"
if getattr(self, "_mains_fill_cb", None) is not None
and self._mains_fill_cb.isChecked() else "zero"
),
"fft_source": self._fft_source,
}
if self._fft_source == "active_roi" and self._roi_id is not None:
Expand All @@ -339,8 +508,9 @@ def _mains_op_params(self) -> dict:

def _on_mains_preview(self) -> None:
v = self._mains_speed_m_per_s()
if not v:
self._mains_status_lbl.setText("Enter a scan speed (nm/s) first.")
if not v and not self._mains_extra_streaks_px():
self._mains_status_lbl.setText(
"Enter a scan speed (nm/s) or add a custom streak pair first.")
return
arr = self._get_image_fn() if self._get_image_fn is not None else self._full_arr
if arr is None:
Expand All @@ -354,7 +524,9 @@ def _on_mains_preview(self) -> None:
mains_frequency_hz=p["mains_frequency_hz"], harmonics=p["harmonics"],
notch_radius_px=p["notch_radius_px"], fast_axis=p["fast_axis"],
snap_window_px=p["snap_window_px"], notch_shape=p["notch_shape"],
min_q_nm_inv=p["min_q_nm_inv"])
min_q_nm_inv=p["min_q_nm_inv"],
extra_streaks_px=p["extra_streaks_px"],
notch_fill=p["notch_fill"])
except Exception as exc:
self._mains_status_lbl.setText(f"Preview failed: {exc}")
return
Expand All @@ -377,8 +549,9 @@ def _on_mains_apply(self) -> None:
if self._apply_correction_fn is None:
self._mains_status_lbl.setText("Apply is unavailable in this context.")
return
if self._mains_speed_m_per_s() is None:
self._mains_status_lbl.setText("Enter a scan speed (nm/s) first.")
if self._mains_speed_m_per_s() is None and not self._mains_extra_streaks_px():
self._mains_status_lbl.setText(
"Enter a scan speed (nm/s) or add a custom streak pair first.")
return
if self._mains_preview_active:
self._hide_fft_preview()
Expand Down
Loading
Loading