diff --git a/dataframely/_storage/parquet.py b/dataframely/_storage/parquet.py index 655c595..805efbd 100644 --- a/dataframely/_storage/parquet.py +++ b/dataframely/_storage/parquet.py @@ -1,6 +1,7 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +import inspect from collections.abc import Iterable from typing import Any @@ -53,13 +54,13 @@ def write_frame( def scan_frame(self, **kwargs: Any) -> tuple[pl.LazyFrame, SerializedSchema | None]: source = kwargs.pop("source") lf = pl.scan_parquet(source, **kwargs) - metadata = _read_serialized_schema(source) + metadata = _read_serialized_schema(source, **_metadata_read_options(kwargs)) return lf, metadata def read_frame(self, **kwargs: Any) -> tuple[pl.DataFrame, SerializedSchema | None]: source = kwargs.pop("source") df = pl.read_parquet(source, **kwargs) - metadata = _read_serialized_schema(source) + metadata = _read_serialized_schema(source, **_metadata_read_options(kwargs)) return df, metadata # ------------------------------ Collections --------------------------------------- @@ -147,6 +148,7 @@ def _collection_from_parquet( # between lazy and eager reads data = {} collection_types = [] + metadata_options = _metadata_read_options(kwargs) fs: AbstractFileSystem = url_to_fs(path)[0] for key in members: @@ -159,12 +161,16 @@ def _collection_from_parquet( else pl.read_parquet(scan_path, **kwargs).lazy() ) if is_file: - collection_types.append(_read_serialized_collection(source_path)) + collection_types.append( + _read_serialized_collection(source_path, **metadata_options) + ) else: prefix = get_file_prefix(fs) for file in fs.glob(fs.sep.join([source_path, "**", "*.parquet"])): collection_types.append( - _read_serialized_collection(f"{prefix}{file}") + _read_serialized_collection( + f"{prefix}{file}", **metadata_options + ) ) return data, collection_types @@ -234,7 +240,7 @@ def scan_failure_info( file = kwargs.pop("file") # Meta data - metadata = pl.read_parquet_metadata(file) + metadata = pl.read_parquet_metadata(file, **_metadata_read_options(kwargs)) serialized_schema = assert_failure_info_metadata( metadata.get(SCHEMA_METADATA_KEY) ) @@ -245,11 +251,22 @@ def scan_failure_info( return lf, serialized_rules, serialized_schema -def _read_serialized_collection(path: str) -> SerializedCollection | None: - meta = pl.read_parquet_metadata(path) +_METADATA_READ_PARAMS = frozenset( + inspect.signature(pl.read_parquet_metadata).parameters +) - {"source"} + + +def _metadata_read_options(kwargs: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in kwargs.items() if k in _METADATA_READ_PARAMS} + + +def _read_serialized_collection( + path: str, **read_options: Any +) -> SerializedCollection | None: + meta = pl.read_parquet_metadata(path, **read_options) return meta.get(COLLECTION_METADATA_KEY) -def _read_serialized_schema(path: str) -> SerializedSchema | None: - meta = pl.read_parquet_metadata(path) +def _read_serialized_schema(path: str, **read_options: Any) -> SerializedSchema | None: + meta = pl.read_parquet_metadata(path, **read_options) return meta.get(SCHEMA_METADATA_KEY) diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 43aab21..6476f20 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -1273,17 +1273,20 @@ def _validate_input_keys(cls, data: Mapping[str, FrameType], /) -> None: def read_parquet_metadata_collection( source: str | Path | IO[bytes] | bytes, + **kwargs: Any, ) -> type[Collection] | None: """Read a dataframely Collection type from the metadata of a parquet file. Args: source: Path to a parquet file or a file-like object that contains the metadata. + kwargs: Additional keyword arguments passed directly to + :meth:`polars.read_parquet_metadata`. Returns: The collection that was serialized to the metadata. `None` if no collection metadata is found or the deserialization fails. """ - metadata = pl.read_parquet_metadata(source) + metadata = pl.read_parquet_metadata(source, **kwargs) if (schema_metadata := metadata.get(COLLECTION_METADATA_KEY)) is not None: return deserialize_collection(schema_metadata, strict=False) return None diff --git a/dataframely/schema.py b/dataframely/schema.py index c7fc92c..d624ae6 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -1416,17 +1416,20 @@ def _rules_match(lhs: dict[str, Rule], rhs: dict[str, Rule]) -> bool: def read_parquet_metadata_schema( source: str | Path | IO[bytes] | bytes, + **kwargs: Any, ) -> type[Schema] | None: """Read a dataframely schema from the metadata of a parquet file. Args: source: Path to a parquet file or a file-like object that contains the metadata. + kwargs: Additional keyword arguments passed directly to + :meth:`polars.read_parquet_metadata`. Returns: The schema that was serialized to the metadata. `None` if no schema metadata is found or the deserialization fails. """ - metadata = pl.read_parquet_metadata(source) + metadata = pl.read_parquet_metadata(source, **kwargs) if (schema_metadata := metadata.get(SCHEMA_METADATA_KEY)) is not None: return deserialize_schema(schema_metadata, strict=False) diff --git a/tests/collection/test_storage.py b/tests/collection/test_storage.py index 9554ca5..58823a7 100644 --- a/tests/collection/test_storage.py +++ b/tests/collection/test_storage.py @@ -1,6 +1,7 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +import uuid from typing import Any import polars as pl @@ -498,6 +499,29 @@ def test_read_invalid_parquet_metadata_collection( assert collection is None +@pytest.mark.s3 +def test_read_parquet_metadata_collection_uses_storage_options( + s3_isolated: tuple[str, dict[str, str]], +) -> None: + """`read_parquet_metadata_collection` must forward `storage_options` to the read.""" + # Arrange + bucket, storage_options = s3_isolated + path = f"{bucket}/{uuid.uuid4()}/df.parquet" + pl.DataFrame({"a": [1, 2, 3]}).write_parquet( + path, + metadata={COLLECTION_METADATA_KEY: MyCollection.serialize()}, + storage_options=storage_options, + ) + + # Act + collection = dy.read_parquet_metadata_collection( + path, storage_options=storage_options + ) + + # Assert + assert collection is not None + + @pytest.mark.parametrize( "any_tmp_path", ["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)], diff --git a/tests/conftest.py b/tests/conftest.py index 75828f5..4f84c67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,6 +44,49 @@ def s3_tmp_path(s3_server: str, s3_bucket: str, monkeypatch: pytest.MonkeyPatch) return f"{s3_bucket}/{str(uuid.uuid4())}" +@pytest.fixture() +def s3_isolated( + s3_server: str, monkeypatch: pytest.MonkeyPatch +) -> Iterator[tuple[str, dict[str, str]]]: + """A freshly-named bucket that is only reachable via the returned ``storage_options``. + + Polars caches object stores per bucket, and these caches live Rust-side and are not + cleared by ``monkeypatch.delenv``. A bucket configured once from ``AWS_*`` env vars + (e.g. by :func:`s3_tmp_path`) therefore stays reachable without ``storage_options``, + which would let a read silently succeed even if ``storage_options`` was dropped. A + unique bucket has no such cached store, so reaching it requires forwarding + ``storage_options`` to every read. + """ + for var in ( + "AWS_ENDPOINT_URL", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_ALLOW_HTTP", + "AWS_S3_ALLOW_UNSAFE_RENAME", + "AWS_DEFAULT_REGION", + "AWS_REGION", + ): + monkeypatch.delenv(var, raising=False) + bucket = f"isolated-{uuid.uuid4()}" + client = boto3.client( + "s3", endpoint_url=s3_server, aws_access_key_id="", aws_secret_access_key="" + ) + client.create_bucket(Bucket=bucket) + yield ( + f"s3://{bucket}", + { + "aws_access_key_id": "testing", + "aws_secret_access_key": "testing", + "aws_endpoint_url": s3_server, + "aws_region": "us-east-1", + "aws_allow_http": "true", + }, + ) + for obj in client.list_objects_v2(Bucket=bucket).get("Contents", []): + client.delete_object(Bucket=bucket, Key=obj["Key"]) + client.delete_bucket(Bucket=bucket) + + @pytest.fixture() def any_tmp_path(request: pytest.FixtureRequest) -> str: return str(request.getfixturevalue(request.param)) diff --git a/tests/failure_info/test_storage.py b/tests/failure_info/test_storage.py index 10b3e7d..3823342 100644 --- a/tests/failure_info/test_storage.py +++ b/tests/failure_info/test_storage.py @@ -1,6 +1,8 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +import uuid + import polars as pl import pytest from fsspec import AbstractFileSystem, url_to_fs @@ -165,6 +167,35 @@ def test_write_parquet_custom_metadata( assert pl.read_parquet_metadata(p)["custom"] == "test" +@pytest.mark.s3 +@pytest.mark.parametrize("lazy", [True, False]) +def test_read_parquet_uses_storage_options_for_metadata( + s3_isolated: tuple[str, dict[str, str]], + lazy: bool, +) -> None: + """`storage_options` must reach the rule/schema metadata read, not just the data + read.""" + # Arrange + bucket, storage_options = s3_isolated + df = pl.DataFrame({"a": [4, 5, 6, 6, 7, 8], "b": [1, 2, 3, 4, 5, 6]}) + _, failure = MySchema.filter(df) + path = f"{bucket}/{uuid.uuid4()}/failure.parquet" + failure.write_parquet(path, storage_options=storage_options) + + # Act + # Reading failure info always reads the rule/schema metadata, so a successful read + # proves the metadata read used the forwarded `storage_options`. + if lazy: + read = dy.FailureInfo.scan_parquet(path, storage_options=storage_options) + else: + read = dy.FailureInfo.read_parquet(path, storage_options=storage_options) + + # Assert + assert_frame_equal(failure._lf, read._lf) + assert failure._rule_columns == read._rule_columns + assert MySchema.matches(read.schema) + + def test_write_parquet_fails_without_mkdir(tmp_path: str) -> None: # Arrange df = pl.DataFrame( diff --git a/tests/schema/test_read_write_parquet.py b/tests/schema/test_read_write_parquet.py index 4b83031..22bae98 100644 --- a/tests/schema/test_read_write_parquet.py +++ b/tests/schema/test_read_write_parquet.py @@ -1,11 +1,13 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +import uuid from pathlib import Path import polars as pl import pytest from fsspec import url_to_fs +from polars.testing import assert_frame_equal import dataframely as dy from dataframely._storage.parquet import SCHEMA_METADATA_KEY @@ -59,3 +61,56 @@ def test_write_parquet_fails_without_mkdir(tmp_path: str) -> None: # Act / Assert with pytest.raises(FileNotFoundError): MySchema.write_parquet(df, file=p) + + +# --------------------------------- STORAGE OPTIONS ---------------------------------- # + + +@pytest.mark.s3 +@pytest.mark.parametrize("lazy", [True, False]) +def test_read_parquet_uses_storage_options_for_metadata( + s3_isolated: tuple[str, dict[str, str]], + lazy: bool, +) -> None: + """`storage_options` must reach the embedded schema metadata read, not just the + data read.""" + # Arrange + bucket, storage_options = s3_isolated + df = MySchema.validate(pl.DataFrame({"a": [1, 2, 3]}), cast=True) + path = f"{bucket}/{uuid.uuid4()}/df.parquet" + MySchema.write_parquet(df, file=path, storage_options=storage_options) + + # Act + # `validation="forbid"` only returns if the metadata schema is read and matches, so + # a passing read proves the metadata read used the forwarded `storage_options`. + if lazy: + out: pl.DataFrame = MySchema.scan_parquet( + path, validation="forbid", storage_options=storage_options + ).collect() + else: + out = MySchema.read_parquet( + path, validation="forbid", storage_options=storage_options + ) + + # Assert + assert_frame_equal(df, out) + + +@pytest.mark.s3 +def test_read_parquet_metadata_schema_uses_storage_options( + s3_isolated: tuple[str, dict[str, str]], +) -> None: + """`read_parquet_metadata_schema` must forward `storage_options` to the read.""" + # Arrange + bucket, storage_options = s3_isolated + path = f"{bucket}/{uuid.uuid4()}/df.parquet" + MySchema.write_parquet( + MySchema.create_empty(), file=path, storage_options=storage_options + ) + + # Act + schema = dy.read_parquet_metadata_schema(path, storage_options=storage_options) + + # Assert + assert schema is not None + assert schema.matches(MySchema)