diff --git a/src/mdio/segy/_workers.py b/src/mdio/segy/_workers.py index ae105222..a6721052 100644 --- a/src/mdio/segy/_workers.py +++ b/src/mdio/segy/_workers.py @@ -6,7 +6,9 @@ from typing import TYPE_CHECKING import numpy as np +from segy import SegyFile from segy.arrays import HeaderArray +from zarr import open_group as zarr_open_group from mdio.core.config import MDIOSettings from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper @@ -14,7 +16,6 @@ from mdio.segy.file import SegyFileWrapper if TYPE_CHECKING: - from segy import SegyFile from zarr import Array as zarr_Array from zarr.core.config import config as zarr_config @@ -68,30 +69,70 @@ def header_scan_worker( return HeaderArray(trace_header) # wrap back so we can use aliases -def trace_worker( # noqa: PLR0913 - segy_file: SegyFile, - data_array: zarr_Array, - header_array: zarr_Array | None, - raw_header_array: zarr_Array | None, - region: dict[str, slice], +# Per-worker process state populated once by `trace_worker_init`. Keeping the SEG-Y handle, +# Zarr array handles, and the (compressed, in-memory) grid map here lets us pickle them a single +# time per worker via the pool initializer instead of once per submitted block. The grid map is +# retained as a compressed in-memory Zarr array and sliced lazily per region, so each worker only +# materializes its own block rather than the full dense map. +_worker_state: dict[str, object] = {} + + +def trace_worker_init( # noqa: PLR0913 + segy_file_kwargs: SegyFileArguments, + output_path: str, + storage_options: dict[str, object] | None, + use_consolidated: bool, + data_variable_name: str, grid_map: zarr_Array, -) -> SummaryStatistics | None: +) -> None: + """Initialize per-process state for trace ingestion workers. + + Used as the `ProcessPoolExecutor` initializer so the SEG-Y file, Zarr output handles, and grid + map are opened/transferred once per worker process rather than re-pickled for every block. + + Args: + segy_file_kwargs: Arguments to open the SegyFile instance. + output_path: POSIX path to the output MDIO Zarr store. + storage_options: fsspec storage options for the output store. + use_consolidated: Whether to open the group with consolidated metadata (Zarr V2). + data_variable_name: Name of the data variable in the dataset. + grid_map: Compressed in-memory Zarr array mapping live traces to their positions. + """ + # Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__CPU_COUNT` environment variable. + # The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage. + zarr_config.set({"threading.max_workers": 1}) + + zarr_group = zarr_open_group( + output_path, + mode="r+", + storage_options=storage_options, + use_consolidated=use_consolidated, + ) + + _worker_state["segy_file"] = SegyFile(**segy_file_kwargs) + _worker_state["data_array"] = zarr_group[data_variable_name] + _worker_state["header_array"] = zarr_group.get("headers") + _worker_state["raw_header_array"] = zarr_group.get("raw_headers") + _worker_state["grid_map"] = grid_map + + +def trace_worker(region: dict[str, slice]) -> SummaryStatistics | None: """Writes a subset of traces from a region of the dataset of Zarr file. + Reads its shared inputs (SEG-Y handle, Zarr arrays, grid map) from the per-process state set up + by `trace_worker_init`, so only the lightweight `region` is pickled per block. + Args: - segy_file: The opened SEG-Y file. - data_array: Zarr array for writing trace data. - header_array: Zarr array for writing trace headers (or None if not needed). - raw_header_array: Zarr array for writing raw headers (or None if not needed). region: Region of the dataset to write to. - grid_map: Zarr array mapping live traces to their positions in the dataset. Returns: SummaryStatistics object containing statistics about the written traces. """ - # Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__CPU_COUNT` environment variable. - # The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage. - zarr_config.set({"threading.max_workers": 1}) + segy_file: SegyFile = _worker_state["segy_file"] + data_array: zarr_Array = _worker_state["data_array"] + header_array: zarr_Array | None = _worker_state["header_array"] + raw_header_array: zarr_Array | None = _worker_state["raw_header_array"] + grid_map: zarr_Array = _worker_state["grid_map"] region_slices = tuple(region.values()) local_grid_map = grid_map[region_slices[:-1]] # minus last (vertical) axis diff --git a/src/mdio/segy/blocked_io.py b/src/mdio/segy/blocked_io.py index 891b1f43..2eddef4c 100644 --- a/src/mdio/segy/blocked_io.py +++ b/src/mdio/segy/blocked_io.py @@ -12,7 +12,6 @@ import zarr from dask.array import Array from dask.array import map_blocks -from segy import SegyFile from tqdm.auto import tqdm from zarr import open_group as zarr_open_group @@ -23,6 +22,7 @@ from mdio.core.config import MDIOSettings from mdio.core.indexing import ChunkIterator from mdio.segy._workers import trace_worker +from mdio.segy._workers import trace_worker_init from mdio.segy.creation import SegyPartRecord from mdio.segy.creation import concat_files from mdio.segy.creation import serialize_to_segy_stack @@ -82,47 +82,44 @@ def to_zarr( # noqa: PLR0913, PLR0915 num_chunks = chunk_iter.num_chunks zarr_format = zarr.config.get("default_zarr_format") + use_consolidated = zarr_format == ZarrFormat.V2 - # Open zarr group once in main process + # Open zarr group once in main process (used for final stats update below). storage_options = _normalize_storage_options(output_path) zarr_group = zarr_open_group( output_path.as_posix(), mode="r+", storage_options=storage_options, - use_consolidated=zarr_format == ZarrFormat.V2, + use_consolidated=use_consolidated, ) - # Get array handles from the opened group - data_array = zarr_group[data_variable_name] - header_array = zarr_group.get("headers") - raw_header_array = zarr_group.get("raw_headers") - # For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default # 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows. num_workers = min(num_chunks, settings.import_cpus) context = mp.get_context("spawn") - # Use initializer to open segy file once per worker + # Open the SEG-Y file, Zarr output handles, and transfer the compressed grid map once per worker + # via the initializer. The grid map stays a compressed in-memory Zarr array and is sliced lazily + # inside each worker, so we avoid both re-pickling per block and materializing the full dense map. executor = ProcessPoolExecutor( max_workers=num_workers, mp_context=context, + initializer=trace_worker_init, + initargs=( + segy_file_kwargs, + output_path.as_posix(), + storage_options, + use_consolidated, + data_variable_name, + grid_map, + ), ) - segy_file = SegyFile(**segy_file_kwargs) - with executor: futures = [] for region in chunk_iter: - # Pass zarr array handles directly to workers - future = executor.submit( - trace_worker, - segy_file, - data_array, - header_array, - raw_header_array, - region, - grid_map, - ) + # Only the lightweight region is pickled per block; shared inputs live in worker state. + future = executor.submit(trace_worker, region) futures.append(future) iterable = tqdm(