Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions src/mdio/segy/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
from typing import TYPE_CHECKING

import numpy as np
from segy import SegyFile
from segy.arrays import HeaderArray
from zarr import open_group as zarr_open_group

from mdio.core.config import MDIOSettings
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
from mdio.segy.file import SegyFileArguments
from mdio.segy.file import SegyFileWrapper

if TYPE_CHECKING:
from segy import SegyFile
from zarr import Array as zarr_Array

from zarr.core.config import config as zarr_config
Expand Down Expand Up @@ -68,30 +69,70 @@ def header_scan_worker(
return HeaderArray(trace_header) # wrap back so we can use aliases


def trace_worker( # noqa: PLR0913
segy_file: SegyFile,
data_array: zarr_Array,
header_array: zarr_Array | None,
raw_header_array: zarr_Array | None,
region: dict[str, slice],
# Per-worker process state populated once by `trace_worker_init`. The SEG-Y handle and the Zarr
# array handles are opened inside each worker process, and the (compressed, in-memory) grid map
# is handed over through the pool initializer arguments, so all of them are set up a single time
# per worker instead of being pickled again for every submitted block. The grid map is retained
# as a compressed in-memory Zarr array and sliced lazily per region, so each worker only
# materializes its own block rather than the full dense map.
_worker_state: dict[str, object] = {}


def trace_worker_init(  # noqa: PLR0913
    segy_file_kwargs: SegyFileArguments,
    output_path: str,
    storage_options: dict[str, object] | None,
    use_consolidated: bool,
    data_variable_name: str,
    grid_map: zarr_Array,
) -> None:
    """Populate the module-level `_worker_state` for a trace ingestion worker process.

    Intended to run as the `ProcessPoolExecutor` initializer: it opens the output Zarr group,
    constructs the SEG-Y file object, and stores those handles together with the grid map in
    `_worker_state`, where `trace_worker` looks them up for every region it processes.

    Args:
        segy_file_kwargs: Keyword arguments forwarded to `SegyFile(...)`.
        output_path: Path of the output MDIO Zarr store, as passed to `zarr.open_group`.
        storage_options: fsspec storage options for the output store, or None.
        use_consolidated: Forwarded to `zarr.open_group` as `use_consolidated`.
        data_variable_name: Name of the data variable array inside the output group.
        grid_map: Zarr array mapping live traces to their positions; stored as-is.
    """
    # Pin Zarr to a single thread so the worker count configured through the
    # `MDIO__IMPORT__CPU_COUNT` environment variable is honored. The Zarr 3 engine is
    # multi-threaded, which can cause resource contention and unpredictable memory usage.
    zarr_config.set({"threading.max_workers": 1})

    output_group = zarr_open_group(
        output_path,
        mode="r+",
        storage_options=storage_options,
        use_consolidated=use_consolidated,
    )

    state = _worker_state
    state["segy_file"] = SegyFile(**segy_file_kwargs)
    state["data_array"] = output_group[data_variable_name]

    # Header arrays are optional; `.get` yields None when the group does not contain them.
    optional_arrays = (("header_array", "headers"), ("raw_header_array", "raw_headers"))
    for state_key, array_name in optional_arrays:
        state[state_key] = output_group.get(array_name)

    state["grid_map"] = grid_map


def trace_worker(region: dict[str, slice]) -> SummaryStatistics | None:
"""Writes a subset of traces from a region of the dataset of Zarr file.

Reads its shared inputs (SEG-Y handle, Zarr arrays, grid map) from the per-process state set up
by `trace_worker_init`, so only the lightweight `region` is pickled per block.

Args:
segy_file: The opened SEG-Y file.
data_array: Zarr array for writing trace data.
header_array: Zarr array for writing trace headers (or None if not needed).
raw_header_array: Zarr array for writing raw headers (or None if not needed).
region: Region of the dataset to write to.
grid_map: Zarr array mapping live traces to their positions in the dataset.

Returns:
SummaryStatistics object containing statistics about the written traces.
"""
# Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__CPU_COUNT` environment variable.
# The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage.
zarr_config.set({"threading.max_workers": 1})
segy_file: SegyFile = _worker_state["segy_file"]
data_array: zarr_Array = _worker_state["data_array"]
header_array: zarr_Array | None = _worker_state["header_array"]
raw_header_array: zarr_Array | None = _worker_state["raw_header_array"]
grid_map: zarr_Array = _worker_state["grid_map"]

region_slices = tuple(region.values())
local_grid_map = grid_map[region_slices[:-1]] # minus last (vertical) axis
Expand Down
39 changes: 18 additions & 21 deletions src/mdio/segy/blocked_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import zarr
from dask.array import Array
from dask.array import map_blocks
from segy import SegyFile
from tqdm.auto import tqdm
from zarr import open_group as zarr_open_group

Expand All @@ -23,6 +22,7 @@
from mdio.core.config import MDIOSettings
from mdio.core.indexing import ChunkIterator
from mdio.segy._workers import trace_worker
from mdio.segy._workers import trace_worker_init
from mdio.segy.creation import SegyPartRecord
from mdio.segy.creation import concat_files
from mdio.segy.creation import serialize_to_segy_stack
Expand Down Expand Up @@ -82,47 +82,44 @@ def to_zarr( # noqa: PLR0913, PLR0915
num_chunks = chunk_iter.num_chunks

zarr_format = zarr.config.get("default_zarr_format")
use_consolidated = zarr_format == ZarrFormat.V2

# Open zarr group once in main process
# Open zarr group once in main process (used for final stats update below).
storage_options = _normalize_storage_options(output_path)
zarr_group = zarr_open_group(
output_path.as_posix(),
mode="r+",
storage_options=storage_options,
use_consolidated=zarr_format == ZarrFormat.V2,
use_consolidated=use_consolidated,
)

# Get array handles from the opened group
data_array = zarr_group[data_variable_name]
header_array = zarr_group.get("headers")
raw_header_array = zarr_group.get("raw_headers")

# For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
# 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
num_workers = min(num_chunks, settings.import_cpus)
context = mp.get_context("spawn")

# Use initializer to open segy file once per worker
# Open the SEG-Y file, Zarr output handles, and transfer the compressed grid map once per worker
# via the initializer. The grid map stays a compressed in-memory Zarr array and is sliced lazily
# inside each worker, so we avoid both re-pickling per block and materializing the full dense map.
executor = ProcessPoolExecutor(
max_workers=num_workers,
mp_context=context,
initializer=trace_worker_init,
initargs=(
segy_file_kwargs,
output_path.as_posix(),
storage_options,
use_consolidated,
data_variable_name,
grid_map,
),
)

segy_file = SegyFile(**segy_file_kwargs)

with executor:
futures = []
for region in chunk_iter:
# Pass zarr array handles directly to workers
future = executor.submit(
trace_worker,
segy_file,
data_array,
header_array,
raw_header_array,
region,
grid_map,
)
# Only the lightweight region is pickled per block; shared inputs live in worker state.
future = executor.submit(trace_worker, region)
futures.append(future)

iterable = tqdm(
Expand Down