Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 107 additions & 12 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import itertools
import os
import random
import time
import uuid
import warnings
from abc import ABC, abstractmethod
Expand All @@ -31,6 +33,7 @@
from pydantic import Field

import pyiceberg.expressions.parser as parser
from pyiceberg.exceptions import CommitFailedException
from pyiceberg.expressions import AlwaysFalse, AlwaysTrue, And, BooleanExpression, EqualTo, IsNull, Or, Reference
from pyiceberg.expressions.visitors import (
ResidualEvaluator,
Expand Down Expand Up @@ -205,6 +208,22 @@ class TableProperties:
MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep"
MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1

# --- Commit retry configuration ---------------------------------------
# These keys are read from the table properties by the commit retry loop
# in Transaction.commit_transaction; each *_DEFAULT is the fallback value.

# Number of retries allowed after the first commit attempt fails
# (the retry loop iterates num_retries + 1 times in total).
COMMIT_NUM_RETRIES = "commit.retry.num-retries"
COMMIT_NUM_RETRIES_DEFAULT = 4

# Base wait between attempts, in milliseconds; the retry loop doubles it
# for each successive attempt.
COMMIT_MIN_RETRY_WAIT_MS = "commit.retry.min-wait-ms"
COMMIT_MIN_RETRY_WAIT_MS_DEFAULT = 100

# Upper bound, in milliseconds, applied to the doubled wait (before jitter).
COMMIT_MAX_RETRY_WAIT_MS = "commit.retry.max-wait-ms"
COMMIT_MAX_RETRY_WAIT_MS_DEFAULT = 60000

# Once this many milliseconds have elapsed since the first attempt, a
# failed commit is re-raised instead of retried.
COMMIT_TOTAL_RETRY_TIME_MS = "commit.retry.total-timeout-ms"
COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT = 1800000 # 30 minutes

# --- Write isolation configuration ------------------------------------
# Property keys selecting the isolation level for delete and update writes.
# NOTE(review): the code that resolves these keys into a level lives in the
# snapshot producers, which are not visible here - confirm the lookup there.
WRITE_DELETE_ISOLATION_LEVEL = "write.delete.isolation-level"
WRITE_UPDATE_ISOLATION_LEVEL = "write.update.isolation-level"
WRITE_ISOLATION_LEVEL_DEFAULT = "serializable"


class Transaction:
_table: Table
Expand All @@ -223,6 +242,7 @@ def __init__(self, table: Table, autocommit: bool = False):
self._autocommit = autocommit
self._updates = ()
self._requirements = ()
self._snapshot_producers: list[Any] = []

@property
def table_metadata(self) -> TableMetadata:
Expand Down Expand Up @@ -265,6 +285,10 @@ def _stage(

return self

def _register_snapshot_producer(self, producer: Any) -> None:
"""Register a snapshot producer for retry support."""
self._snapshot_producers.append(producer)

def _apply(
self,
updates: tuple[TableUpdate, ...],
Expand Down Expand Up @@ -589,7 +613,12 @@ def dynamic_partition_overwrite(
delete_filter = self._build_partition_predicate(
partition_records=partitions_to_overwrite, spec=self.table_metadata.spec(), schema=self.table_metadata.schema()
)
self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties, branch=branch)
self.delete(
delete_filter=delete_filter,
snapshot_properties=snapshot_properties,
branch=branch,
_isolation_level_property=TableProperties.WRITE_UPDATE_ISOLATION_LEVEL,
)

with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
append_files.commit_uuid = append_snapshot_commit_uuid
Expand Down Expand Up @@ -682,6 +711,7 @@ def overwrite(
case_sensitive=case_sensitive,
snapshot_properties=snapshot_properties,
branch=branch,
_isolation_level_property=TableProperties.WRITE_UPDATE_ISOLATION_LEVEL,
)

with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
Expand All @@ -699,6 +729,7 @@ def delete(
snapshot_properties: dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
branch: str | None = MAIN_BRANCH,
_isolation_level_property: str | None = None,
) -> None:
"""
Shorthand for deleting record from a table.
Expand Down Expand Up @@ -726,6 +757,8 @@ def delete(
delete_filter = _parse_row_filter(delete_filter)

with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot:
if _isolation_level_property is not None:
delete_snapshot._isolation_level_property = _isolation_level_property
delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)

# Check if there are any files that require an actual rewrite of a data file
Expand Down Expand Up @@ -781,7 +814,10 @@ def delete(
with self.update_snapshot(
snapshot_properties=snapshot_properties, branch=branch
).overwrite() as overwrite_snapshot:
if _isolation_level_property is not None:
overwrite_snapshot._isolation_level_property = _isolation_level_property
overwrite_snapshot.commit_uuid = commit_uuid
overwrite_snapshot.delete_by_predicate(delete_filter, case_sensitive)
for original_data_file, replaced_data_files in replaced_files:
overwrite_snapshot.delete_data_file(original_data_file)
for replaced_data_file in replaced_data_files:
Expand Down Expand Up @@ -1036,17 +1072,78 @@ def commit_transaction(self) -> Table:
The table with the updates applied.
"""
if len(self._updates) > 0:
self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),)
self._table._do_commit( # pylint: disable=W0212
updates=self._updates,
requirements=self._requirements,
from pyiceberg.utils.properties import property_as_int

properties = self._table.metadata.properties
num_retries_val = property_as_int(
properties, TableProperties.COMMIT_NUM_RETRIES, TableProperties.COMMIT_NUM_RETRIES_DEFAULT
)
num_retries = num_retries_val if num_retries_val is not None else TableProperties.COMMIT_NUM_RETRIES_DEFAULT
min_wait_val = property_as_int(
properties, TableProperties.COMMIT_MIN_RETRY_WAIT_MS, TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT
)
min_wait_ms = min_wait_val if min_wait_val is not None else TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT
max_wait_val = property_as_int(
properties, TableProperties.COMMIT_MAX_RETRY_WAIT_MS, TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT
)
max_wait_ms = max_wait_val if max_wait_val is not None else TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT
total_timeout_val = property_as_int(
properties, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT
)
total_timeout_ms = (
total_timeout_val if total_timeout_val is not None else TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT
)
start_time = time.monotonic()
self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),)

try:
for attempt in range(num_retries + 1):
try:
self._table._do_commit( # pylint: disable=W0212
updates=self._updates,
requirements=self._requirements,
)
self._cleanup_uncommitted_manifests()
break
except CommitFailedException:
elapsed_ms = (time.monotonic() - start_time) * 1000
if attempt == num_retries or not self._snapshot_producers or elapsed_ms >= total_timeout_ms:
raise

wait = min(min_wait_ms * (2**attempt), max_wait_ms)
jitter = random.uniform(0, 0.25 * wait)
time.sleep((wait + jitter) / 1000.0)

self._table.refresh()
self._rebuild_snapshot_updates()
except Exception:
for producer in self._snapshot_producers:
producer._clean_all_uncommitted()
raise

self._updates = ()
self._requirements = ()

return self._table

def _cleanup_uncommitted_manifests(self) -> None:
"""Clean up manifests from failed retry attempts after a successful commit."""
for producer in self._snapshot_producers:
producer._cleanup_uncommitted()

def _rebuild_snapshot_updates(self) -> None:
    """Regenerate the staged snapshot changes before a commit is retried.

    Snapshot-related updates (``AddSnapshotUpdate``, ``SetSnapshotRefUpdate``)
    and ``AssertRefSnapshotId`` requirements from the previous attempt are
    dropped. Each registered producer is then refreshed, validated for
    concurrency conflicts and committed again, and its new updates and
    requirements are staged on this transaction.
    """
    from pyiceberg.table.update import AddSnapshotUpdate, AssertRefSnapshotId, SetSnapshotRefUpdate

    # Everything that is not snapshot-specific survives the retry untouched.
    stale_update_types = (AddSnapshotUpdate, SetSnapshotRefUpdate)
    kept_updates = [staged for staged in self._updates if not isinstance(staged, stale_update_types)]
    self._updates = tuple(kept_updates)

    kept_requirements = [needed for needed in self._requirements if not isinstance(needed, AssertRefSnapshotId)]
    self._requirements = tuple(kept_requirements)

    # Order matters per producer: refresh state, validate, then rebuild.
    for registered in self._snapshot_producers:
        registered._refresh_for_retry()
        registered._validate_concurrency()
        rebuilt_updates, rebuilt_requirements = registered._commit()
        self._stage(rebuilt_updates, rebuilt_requirements)


class CreateTableTransaction(Transaction):
"""A transaction that involves the creation of a new table."""
Expand Down Expand Up @@ -2072,13 +2169,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu
# The lambda created here is run in multiple threads.
# So we avoid creating _EvaluatorExpression methods bound to a single
# shared instance across multiple threads.
return lambda datafile: (
residual_evaluator_of(
spec=spec,
expr=self.row_filter,
case_sensitive=self.case_sensitive,
schema=self.table_metadata.schema(),
)
return lambda datafile: residual_evaluator_of(
spec=spec,
expr=self.row_filter,
case_sensitive=self.case_sensitive,
schema=self.table_metadata.schema(),
)

@staticmethod
Expand Down
7 changes: 7 additions & 0 deletions pyiceberg/table/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def __repr__(self) -> str:
return f"Operation.{self.name}"


class IsolationLevel(str, Enum):
    """Isolation level used when validating concurrent writes.

    Members are ``str`` subclasses, so each one compares equal to its raw
    string value and can be built from that value, e.g.
    ``IsolationLevel("snapshot")``.
    """

    # NOTE(review): the exact conflict checks performed for each level are
    # implemented elsewhere; confirm there before relying on the difference.
    SERIALIZABLE = "serializable"
    SNAPSHOT = "snapshot"


class UpdateMetrics:
added_file_size: int
removed_file_size: int
Expand Down
Loading