Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datashare-python/datashare_python/cli/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@
_TASK_ID_HELP = "task ID"
_WATCH_HELP = "watch a task until it's complete"

TaskArgs = str
StrTaskArgs = str

task_app = AsyncTyper(name="task")


@task_app.async_command(help=_START_HELP)
async def start(
name: Annotated[str, typer.Argument(help=_NAME_HELP)],
args: Annotated[TaskArgs, typer.Argument(help=_ARGS_HELP)] = None,
args: Annotated[StrTaskArgs, typer.Argument(help=_ARGS_HELP)] = None,
group: Annotated[
str | None,
typer.Option("--group", "-g", help=_GROUP_HELP),
Expand Down
1 change: 1 addition & 0 deletions datashare-python/datashare_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULT_NAMESPACE = "datashare-default"

METADATA_JSON = "metadata.json"
MANIFEST_JSON = "manifest.json"

TIKA_METADATA_RESOURCENAME = "tika_metadata_resourcename"

Expand Down
2 changes: 1 addition & 1 deletion datashare-python/datashare_python/logging_.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _encode_value(value: Any) -> str:
return "true" if value else "false"
if isinstance(value, numbers.Number):
return str(value)
return json.dumps(value).decode()
return json.dumps(value)


def _json_formatter(datefmt: str) -> BaseJsonFormatter:
Expand Down
131 changes: 125 additions & 6 deletions datashare-python/datashare_python/objects.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import logging
import os
from abc import ABC
from asyncio import Lock
from collections.abc import Awaitable, Callable
from dataclasses import InitVar, dataclass, field
from datetime import UTC, datetime
from enum import StrEnum, unique
from io import BytesIO
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal, Self, TypeVar, cast
from typing import Annotated, Any, ClassVar, Generic, Literal, Self, TypeVar, cast

import langcodes
from icij_common.registrable import Registrable
from lru import LRU
from pydantic_core import PydanticCustomError, ValidationError, core_schema
from pydantic_core.core_schema import PlainValidatorFunctionSchema
from pydantic_core.core_schema import (
PlainValidatorFunctionSchema,
)
from pydantic_extra_types.language_code import LanguageName
from temporalio import workflow

Expand All @@ -34,12 +38,17 @@
from icij_common.pydantic_utils import (
icij_config,
lowercamel_case_config,
make_enum_discriminator,
merge_configs,
no_enum_values_config,
tagged_union,
)
from pydantic import (
AfterValidator,
AliasChoices,
BeforeValidator,
ConfigDict,
Discriminator,
Field,
GetCoreSchemaHandler,
TypeAdapter,
Expand Down Expand Up @@ -256,13 +265,123 @@ def _is_absolute_path(v: bytes | BytesIO | Path) -> Any:
return v


@dataclass(frozen=True)
class DocArtifact:
class ArtifactType(StrEnum):
STRUCTURE = "structure"
ASR_TRANSCRIPTION = "transcription"


class ManifestEntryStatus(StrEnum):
COMPLETE = "complete"


class TaskArgs(DatashareModel, ABC):
    """Abstract base class for task arguments recorded in a manifest."""

    def as_manifest_task_input(self) -> dict[str, Any]:
        """Return the arguments as a dict, dumped by alias.

        Base implementation: subclasses whose input is too large to be dumped
        should override this and pop the large keys.
        """
        return self.model_dump(by_alias=True)


A = TypeVar("A", bound=TaskArgs)


class ManifestEntry(DatashareModel, Generic[A], ABC):
    """Abstract manifest entry: a status, an optional label and the task input."""

    status: ManifestEntryStatus
    label: str | None = None
    # Validated from either "taskInput" or "input"; serialized as "taskInput"
    # when dumping by alias.
    input: Annotated[
        dict[str, Any] | None,
        Field(
            validation_alias=AliasChoices("taskInput", "input"),
            serialization_alias="taskInput",
        ),
    ]

    @classmethod
    def complete(cls, args: A, label: str | None = None, **kwargs) -> Self:
        """Build an entry with the COMPLETE status from task arguments.

        ``args`` is converted with ``as_manifest_task_input``; any extra keyword
        argument is forwarded to the constructor.
        """
        task_input = args.as_manifest_task_input()
        return cls(
            status=ManifestEntryStatus.COMPLETE,
            label=label,
            input=task_input,
            **kwargs,
        )


class PaginationType(StrEnum):
FILESYSTEM = "filesystem"
BYTE_RANGES = "byteRanges"


class BasePagination(DatashareModel, Registrable, ABC):
    """Abstract base of the pagination models.

    Concrete variants register themselves with
    ``BasePagination.register(<PaginationType>)``.
    """

    # NOTE(review): ``registry_key`` and ``type`` are ClassVars assigned a
    # pydantic ``Field(...)`` object, so reading them on the class presumably
    # yields that object rather than its default — confirm this is what
    # Registrable and the tagged union expect.
    registry_key: ClassVar[str] = Field(frozen=True, default="type")

    # presumably the number of pages — TODO confirm
    total: int
    type: ClassVar[PaginationType] = Field(frozen=True)


def _validate_pages_range(v: Any) -> None:
if not isinstance(v, list):
msg = f"expected a list, got {type(v)}"
raise TypeError(msg)
previous_end = None
for page_i, (start, end) in enumerate(v):
if not start <= end:
msg = "end of page must be >= start"
raise ValueError(msg)
if previous_end is not None and previous_end != start:
msg = (
f"start of page {page_i} doesn't match end of previous "
f"page {previous_end}"
)
raise ValueError(msg)
return v


PagesRange = Annotated[list[tuple[int, int]], AfterValidator(_validate_pages_range)]


# Pagination variant registered under PaginationType.FILESYSTEM; it declares no
# field of its own.
@BasePagination.register(PaginationType.FILESYSTEM)
class FilesystemPagination(BasePagination):
    # NOTE(review): ``type`` is a ClassVar holding a pydantic ``Field(...)``
    # object, not the enum member itself — confirm consumers read ``.default``.
    type: ClassVar[PaginationType] = Field(
        default=PaginationType.FILESYSTEM, frozen=True
    )


# Pagination variant registered under PaginationType.BYTE_RANGES; pages are
# described by a list of (start, end) pairs.
@BasePagination.register(PaginationType.BYTE_RANGES)
class ByteRangesPagination(BasePagination):
    # NOTE(review): ``type`` is a ClassVar holding a pydantic ``Field(...)``
    # object, not the enum member itself — confirm consumers read ``.default``.
    type: ClassVar[PaginationType] = Field(
        default=PaginationType.BYTE_RANGES, frozen=True
    )
    # One (start, end) pair per page, validated by ``PagesRange``
    byte_ranges: PagesRange

    @model_validator(mode="after")
    def byte_ranges_length_should_match_total(self) -> Self:
        # Reject instances whose number of ranges differs from ``total``
        if len(self.byte_ranges) != self.total:
            msg = (
                f"byte_ranges must match total. Found {len(self.byte_ranges)} for"
                f" byte_ranges and {self.total} for total."
            )
            raise ValueError(msg)
        return self


pagination_discriminator = make_enum_discriminator("type", PaginationType)
Pagination = Annotated[
tagged_union(BasePagination.__subclasses__(), lambda x: x.type),

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crucial: the classvar + custom serializer route dumps type correctly, but the union still can't parse
it back. The tag extractor lambda x: x.type returns the FieldInfo (since type is a ClassVar set to
Field(...)), not the enum, so every Pagination payload fails with union_tag_invalid. Pulling the
default off the FieldInfo fixes both directions (verified round-trip on both subclasses):

Suggested change
tagged_union(BasePagination.__subclasses__(), lambda x: x.type),
tagged_union(BasePagination.__subclasses__(), lambda x: x.type.default),

A round-trip test through the Pagination adapter would lock this down.

Discriminator(pagination_discriminator),
]


class DocArtifact(BaseModel, ABC):
# This object is not used for serde, just as a container, it's OK to allow
# arbitrary types (to allow storing BytesIO)
model_config = ConfigDict(arbitrary_types_allowed=True)

project: str
doc_id: str
artifact: Annotated[bytes | BytesIO | Path, AfterValidator(_is_absolute_path)]
filename: str
metadata_key: str
filename: ClassVar[str] # Override this
type: ClassVar[ArtifactType] # Override this
manifest_entry: ManifestEntry


@unique
Expand Down
77 changes: 59 additions & 18 deletions datashare-python/datashare_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
SyncProgressRateHandler,
)

from .constants import METADATA_JSON
from .constants import MANIFEST_JSON, METADATA_JSON
from .objects import DocArtifact, DocumentLocation, FilesystemDocument
from .types_ import RawAsyncProgressHandler

Expand Down Expand Up @@ -338,34 +338,75 @@ def _metadata_path(doc_id: str, *, project: str) -> Path:
return metadata_path


def _read_artifact_metadata(root: Path, artifact: DocArtifact) -> dict:
m_path = root / _metadata_path(artifact.doc_id, project=artifact.project)
def _manifest_path(doc_id: str, *, project: str) -> Path:
    """Return the path of the manifest file inside the document's artifacts dir."""
    return artifacts_dir(doc_id, project=project) / MANIFEST_JSON


def _read_artifact_manifest(root: Path, artifact: DocArtifact) -> dict:
    """Load the artifact manifest of a document, falling back to its metadata file.

    The manifest file is tried first; the metadata file is only looked up when
    the manifest is missing.

    Raises:
        FileNotFoundError: if neither file exists under ``root``.
    """
    manifest_file = root / _manifest_path(artifact.doc_id, project=artifact.project)
    if manifest_file.exists():
        return json.loads(manifest_file.read_text())
    legacy_file = root / _metadata_path(artifact.doc_id, project=artifact.project)
    if legacy_file.exists():
        return json.loads(legacy_file.read_text())
    msg = f"couldn't find manifest nor metadata for {artifact.doc_id}"
    raise FileNotFoundError(msg)


def write_artifact(root: Path, artifact: DocArtifact) -> Path:
# TODO: WARNING many writers could write at the same time, to avoid inconsistent
# states we should handle this somehow
artif_dir = root / artifacts_dir(artifact.doc_id, project=artifact.project)
artif_dir.mkdir(exist_ok=True, parents=True)
# TODO: if transcriptions are too large we could also serialize them
# as jsonl
artifact_path: Path = artif_dir / artifact.filename
match artifact.artifact:
artifact_path = artif_dir / artifact.filename
# Read the metadata first (things could go wrong here in case someone is reading
# at the same time). We read in a backward compatible way and write to that same
# location. We don't take responsibility for migrating the data, the DS back will
# do it.
manifest_path, manifest = _read_manifest_backward_compatible(root, artifact)
is_legacy = manifest_path.name == "metadata.json"
# Pop the status key from the manifest before writing
manifest_entry = manifest.get(artifact.type)
if manifest_entry is not None and not is_legacy:
manifest[artifact.type].pop("status", None)
manifest_path.write_text(json.dumps(manifest))
Comment thread
ClemDoum marked this conversation as resolved.
# Write the artifact
_write_artifact_bytes(artifact_path, artifact.artifact)
# Update the manifest entry with details and new states
if is_legacy:
manifest_entry = str(artifact_path.relative_to(artif_dir))
else:
manifest_entry = artifact.manifest_entry.model_dump(mode="json", by_alias=True)
manifest[artifact.type] = manifest_entry
manifest_path.write_text(json.dumps(manifest))
return artifact_path.relative_to(artif_dir)


def _read_manifest_backward_compatible(
    root: Path, artifact: DocArtifact
) -> tuple[Path, dict[str, Any]]:
    """Locate and load the manifest of a document, supporting the metadata file.

    Returns:
        A ``(path, content)`` pair. ``path`` is the manifest file when it exists,
        otherwise the metadata file when that one exists. When neither exists,
        the manifest path is returned together with an empty dict.
    """
    doc_id, project = artifact.doc_id, artifact.project
    preferred = root / _manifest_path(doc_id, project=project)
    if not preferred.exists():
        legacy = root / _metadata_path(doc_id, project=project)
        if legacy.exists():
            return legacy, _read_artifact_manifest(root, artifact)
        # Nothing on disk yet: start from an empty manifest at the new location
        return preferred, {}
    return preferred, _read_artifact_manifest(root, artifact)


def _write_artifact_bytes(path: Path, artifact: bytes | BytesIO | Path) -> None:
match artifact:
case bytes():
artifact_path.write_bytes(artifact.artifact)
path.write_bytes(artifact)
case BytesIO():
with artifact_path.open("wb") as f:
f.write(artifact.artifact.read())
with path.open("wb") as f:
f.write(artifact.read())
case Path():
artifact_path.unlink(missing_ok=True)
shutil.move(artifact.artifact, artifact_path)
path.unlink(missing_ok=True)
shutil.move(artifact, path)
case _:
msg = f"unsupported artifact type: {artifact.artifact.__class__.__name__}"
msg = f"unsupported artifact type: {artifact.__class__.__name__}"
raise ValueError(msg)
meta_path = root / _metadata_path(artifact.doc_id, project=artifact.project)
meta = _read_artifact_metadata(root, artifact) if meta_path.exists() else dict()
meta[artifact.metadata_key] = artifact.filename
meta_path.write_text(json.dumps(meta))
return artifact_path.relative_to(artif_dir)


def debuggable_name(
Expand Down
2 changes: 1 addition & 1 deletion datashare-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
"alive-progress~=3.2",
"aiohttp~=3.11",
"hatchling~=1.27",
"icij-common[elasticsearch]~=0.8.2",
"icij-common[elasticsearch]~=0.8.3",
"langcodes~=3.5",
"python-json-logger~=4.0",
"pyyaml~=6.0",
Expand Down
20 changes: 20 additions & 0 deletions datashare-python/tests/test_objects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import re
from datetime import datetime
from pathlib import Path
Expand All @@ -6,6 +7,8 @@
from datashare_python.conftest import TEST_PROJECT
from datashare_python.constants import TIKA_METADATA_RESOURCENAME
from datashare_python.objects import (
BasePagination,
ByteRangesPagination,
DatashareLanguage,
Document,
DocumentLocation,
Expand Down Expand Up @@ -94,3 +97,20 @@ def test_invalid_datashare_language_should_raise(
# When/Then
with pytest.raises(ValidationError, match=expected_msg):
type_adapter.validate_python(language)


def test_pagination_serde() -> None:
    """A byte-ranges pagination should survive a JSON round trip."""
    # Given
    byte_ranges = [(0, 1), (1, 2), (2, 3)]
    pagination = ByteRangesPagination(total=3, byte_ranges=byte_ranges)
    adapter = TypeAdapter(BasePagination)
    expected_payload = {
        "type": "byteRanges",
        "total": 3,
        "byteRanges": [[0, 1], [1, 2], [2, 3]],
    }
    # When
    as_json = pagination.model_dump_json(by_alias=True)
    round_tripped = adapter.validate_json(as_json)
    # Then
    assert json.loads(as_json) == expected_payload
    assert round_tripped == pagination
Loading
Loading