Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions dataframely/_storage/parquet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

import inspect
from collections.abc import Iterable
from typing import Any

Expand Down Expand Up @@ -53,13 +54,13 @@ def write_frame(
def scan_frame(self, **kwargs: Any) -> tuple[pl.LazyFrame, SerializedSchema | None]:
source = kwargs.pop("source")
lf = pl.scan_parquet(source, **kwargs)
metadata = _read_serialized_schema(source)
metadata = _read_serialized_schema(source, **_metadata_read_options(kwargs))
return lf, metadata

def read_frame(self, **kwargs: Any) -> tuple[pl.DataFrame, SerializedSchema | None]:
source = kwargs.pop("source")
df = pl.read_parquet(source, **kwargs)
metadata = _read_serialized_schema(source)
metadata = _read_serialized_schema(source, **_metadata_read_options(kwargs))
return df, metadata

# ------------------------------ Collections ---------------------------------------
Expand Down Expand Up @@ -147,6 +148,7 @@ def _collection_from_parquet(
# between lazy and eager reads
data = {}
collection_types = []
metadata_options = _metadata_read_options(kwargs)

fs: AbstractFileSystem = url_to_fs(path)[0]
for key in members:
Expand All @@ -159,12 +161,16 @@ def _collection_from_parquet(
else pl.read_parquet(scan_path, **kwargs).lazy()
)
if is_file:
collection_types.append(_read_serialized_collection(source_path))
collection_types.append(
_read_serialized_collection(source_path, **metadata_options)
)
else:
prefix = get_file_prefix(fs)
for file in fs.glob(fs.sep.join([source_path, "**", "*.parquet"])):
collection_types.append(
_read_serialized_collection(f"{prefix}{file}")
_read_serialized_collection(
f"{prefix}{file}", **metadata_options
)
)

return data, collection_types
Expand Down Expand Up @@ -234,7 +240,7 @@ def scan_failure_info(
file = kwargs.pop("file")

# Meta data
metadata = pl.read_parquet_metadata(file)
metadata = pl.read_parquet_metadata(file, **_metadata_read_options(kwargs))
serialized_schema = assert_failure_info_metadata(
metadata.get(SCHEMA_METADATA_KEY)
)
Expand All @@ -245,11 +251,22 @@ def scan_failure_info(
return lf, serialized_rules, serialized_schema


def _read_serialized_collection(path: str) -> SerializedCollection | None:
meta = pl.read_parquet_metadata(path)
_METADATA_READ_PARAMS = frozenset(
inspect.signature(pl.read_parquet_metadata).parameters
) - {"source"}


def _metadata_read_options(kwargs: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in kwargs.items() if k in _METADATA_READ_PARAMS}


def _read_serialized_collection(
path: str, **read_options: Any
) -> SerializedCollection | None:
meta = pl.read_parquet_metadata(path, **read_options)
return meta.get(COLLECTION_METADATA_KEY)


def _read_serialized_schema(path: str) -> SerializedSchema | None:
meta = pl.read_parquet_metadata(path)
def _read_serialized_schema(path: str, **read_options: Any) -> SerializedSchema | None:
meta = pl.read_parquet_metadata(path, **read_options)
return meta.get(SCHEMA_METADATA_KEY)
5 changes: 4 additions & 1 deletion dataframely/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,17 +1273,20 @@ def _validate_input_keys(cls, data: Mapping[str, FrameType], /) -> None:

def read_parquet_metadata_collection(
source: str | Path | IO[bytes] | bytes,
**kwargs: Any,
) -> type[Collection] | None:
"""Read a dataframely Collection type from the metadata of a parquet file.

Args:
source: Path to a parquet file or a file-like object that contains the metadata.
kwargs: Additional keyword arguments passed directly to
:meth:`polars.read_parquet_metadata`.

Returns:
The collection that was serialized to the metadata. `None` if no collection
metadata is found or the deserialization fails.
"""
metadata = pl.read_parquet_metadata(source)
metadata = pl.read_parquet_metadata(source, **kwargs)
if (schema_metadata := metadata.get(COLLECTION_METADATA_KEY)) is not None:
return deserialize_collection(schema_metadata, strict=False)
return None
Expand Down
5 changes: 4 additions & 1 deletion dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,17 +1416,20 @@ def _rules_match(lhs: dict[str, Rule], rhs: dict[str, Rule]) -> bool:

def read_parquet_metadata_schema(
source: str | Path | IO[bytes] | bytes,
**kwargs: Any,
) -> type[Schema] | None:
"""Read a dataframely schema from the metadata of a parquet file.

Args:
source: Path to a parquet file or a file-like object that contains the metadata.
kwargs: Additional keyword arguments passed directly to
:meth:`polars.read_parquet_metadata`.

Returns:
The schema that was serialized to the metadata. `None` if no schema metadata
is found or the deserialization fails.
"""
metadata = pl.read_parquet_metadata(source)
metadata = pl.read_parquet_metadata(source, **kwargs)

if (schema_metadata := metadata.get(SCHEMA_METADATA_KEY)) is not None:
return deserialize_schema(schema_metadata, strict=False)
Expand Down
24 changes: 24 additions & 0 deletions tests/collection/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

import uuid
from typing import Any

import polars as pl
Expand Down Expand Up @@ -498,6 +499,29 @@ def test_read_invalid_parquet_metadata_collection(
assert collection is None


@pytest.mark.s3
def test_read_parquet_metadata_collection_uses_storage_options(
s3_isolated: tuple[str, dict[str, str]],
) -> None:
"""`read_parquet_metadata_collection` must forward `storage_options` to the read."""
# Arrange
bucket, storage_options = s3_isolated
path = f"{bucket}/{uuid.uuid4()}/df.parquet"
pl.DataFrame({"a": [1, 2, 3]}).write_parquet(
path,
metadata={COLLECTION_METADATA_KEY: MyCollection.serialize()},
storage_options=storage_options,
)

# Act
collection = dy.read_parquet_metadata_collection(
path, storage_options=storage_options
)

# Assert
assert collection is not None


@pytest.mark.parametrize(
"any_tmp_path",
["tmp_path", pytest.param("s3_tmp_path", marks=pytest.mark.s3)],
Expand Down
43 changes: 43 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,49 @@ def s3_tmp_path(s3_server: str, s3_bucket: str, monkeypatch: pytest.MonkeyPatch)
return f"{s3_bucket}/{str(uuid.uuid4())}"


@pytest.fixture()
def s3_isolated(
s3_server: str, monkeypatch: pytest.MonkeyPatch
) -> Iterator[tuple[str, dict[str, str]]]:
"""A freshly-named bucket that is only reachable via the returned ``storage_options``.

Polars caches object stores per bucket, and these caches live Rust-side and are not
cleared by ``monkeypatch.delenv``. A bucket configured once from ``AWS_*`` env vars
(e.g. by :func:`s3_tmp_path`) therefore stays reachable without ``storage_options``,
which would let a read silently succeed even if ``storage_options`` was dropped. A
unique bucket has no such cached store, so reaching it requires forwarding
``storage_options`` to every read.
"""
for var in (
"AWS_ENDPOINT_URL",
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_ALLOW_HTTP",
"AWS_S3_ALLOW_UNSAFE_RENAME",
"AWS_DEFAULT_REGION",
"AWS_REGION",
):
monkeypatch.delenv(var, raising=False)
Comment on lines +60 to +69

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this has the desired effect. Polars likely only reads these variables at startup and does not re-read them for every read. Could we check that?

I don't have a good alternative idea but I think all of those tests are moot if I'm right.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! So yes, the tests were indeed moot as Polars caches an object store per bucket which isn't cleared by this monkeypatch. Other tests use the same bucket so it stayed reachable even without storage_options. This is per bucket though so I fixed this by creating a fresh bucket that is never env configured hence no cache and so requires storage_options forwarding. Ran the tests without the fix which then failed as expected. Let me know what you think.

bucket = f"isolated-{uuid.uuid4()}"
client = boto3.client(
"s3", endpoint_url=s3_server, aws_access_key_id="", aws_secret_access_key=""
)
client.create_bucket(Bucket=bucket)
yield (
f"s3://{bucket}",
{
"aws_access_key_id": "testing",
"aws_secret_access_key": "testing",
"aws_endpoint_url": s3_server,
"aws_region": "us-east-1",
"aws_allow_http": "true",
},
)
for obj in client.list_objects_v2(Bucket=bucket).get("Contents", []):
client.delete_object(Bucket=bucket, Key=obj["Key"])
client.delete_bucket(Bucket=bucket)


@pytest.fixture()
def any_tmp_path(request: pytest.FixtureRequest) -> str:
return str(request.getfixturevalue(request.param))
31 changes: 31 additions & 0 deletions tests/failure_info/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

import uuid

import polars as pl
import pytest
from fsspec import AbstractFileSystem, url_to_fs
Comment thread
mattijsdp marked this conversation as resolved.
Expand Down Expand Up @@ -165,6 +167,35 @@ def test_write_parquet_custom_metadata(
assert pl.read_parquet_metadata(p)["custom"] == "test"


@pytest.mark.s3
@pytest.mark.parametrize("lazy", [True, False])
def test_read_parquet_uses_storage_options_for_metadata(
s3_isolated: tuple[str, dict[str, str]],
lazy: bool,
) -> None:
"""`storage_options` must reach the rule/schema metadata read, not just the data
read."""
# Arrange
bucket, storage_options = s3_isolated
df = pl.DataFrame({"a": [4, 5, 6, 6, 7, 8], "b": [1, 2, 3, 4, 5, 6]})
_, failure = MySchema.filter(df)
path = f"{bucket}/{uuid.uuid4()}/failure.parquet"
failure.write_parquet(path, storage_options=storage_options)

# Act
# Reading failure info always reads the rule/schema metadata, so a successful read
# proves the metadata read used the forwarded `storage_options`.
if lazy:
read = dy.FailureInfo.scan_parquet(path, storage_options=storage_options)
else:
read = dy.FailureInfo.read_parquet(path, storage_options=storage_options)

# Assert
assert_frame_equal(failure._lf, read._lf)
Comment thread
mattijsdp marked this conversation as resolved.
assert failure._rule_columns == read._rule_columns
assert MySchema.matches(read.schema)


def test_write_parquet_fails_without_mkdir(tmp_path: str) -> None:
# Arrange
df = pl.DataFrame(
Expand Down
55 changes: 55 additions & 0 deletions tests/schema/test_read_write_parquet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

import uuid
from pathlib import Path

import polars as pl
import pytest
from fsspec import url_to_fs
from polars.testing import assert_frame_equal

import dataframely as dy
from dataframely._storage.parquet import SCHEMA_METADATA_KEY
Expand Down Expand Up @@ -59,3 +61,56 @@ def test_write_parquet_fails_without_mkdir(tmp_path: str) -> None:
# Act / Assert
with pytest.raises(FileNotFoundError):
MySchema.write_parquet(df, file=p)


# --------------------------------- STORAGE OPTIONS ---------------------------------- #


@pytest.mark.s3
@pytest.mark.parametrize("lazy", [True, False])
def test_read_parquet_uses_storage_options_for_metadata(
s3_isolated: tuple[str, dict[str, str]],
lazy: bool,
) -> None:
"""`storage_options` must reach the embedded schema metadata read, not just the
data read."""
# Arrange
bucket, storage_options = s3_isolated
df = MySchema.validate(pl.DataFrame({"a": [1, 2, 3]}), cast=True)
path = f"{bucket}/{uuid.uuid4()}/df.parquet"
MySchema.write_parquet(df, file=path, storage_options=storage_options)

# Act
# `validation="forbid"` only returns if the metadata schema is read and matches, so
# a passing read proves the metadata read used the forwarded `storage_options`.
if lazy:
out: pl.DataFrame = MySchema.scan_parquet(
path, validation="forbid", storage_options=storage_options
).collect()
else:
out = MySchema.read_parquet(
path, validation="forbid", storage_options=storage_options
)

# Assert
assert_frame_equal(df, out)


@pytest.mark.s3
def test_read_parquet_metadata_schema_uses_storage_options(
s3_isolated: tuple[str, dict[str, str]],
) -> None:
"""`read_parquet_metadata_schema` must forward `storage_options` to the read."""
# Arrange
bucket, storage_options = s3_isolated
path = f"{bucket}/{uuid.uuid4()}/df.parquet"
MySchema.write_parquet(
MySchema.create_empty(), file=path, storage_options=storage_options
)

# Act
schema = dy.read_parquet_metadata_schema(path, storage_options=storage_options)

# Assert
assert schema is not None
assert schema.matches(MySchema)
Loading