Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3943,9 +3943,9 @@ def record(
else:
# Blank fields for empty record
record = ["" for _ in range(fieldCount)]
self.__dbfRecord(record)
self._record(record)

def __dbfRecord(self, record: list[RecordValue]) -> None:
def _record(self, record: list[RecordValue]) -> None:
"""Writes the dbf records."""
f = self.file
if self.recNum == 0:
Expand Down
121 changes: 121 additions & 0 deletions tests/hypothesis_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import datetime
import io
import itertools
import string

import pytest
from hypothesis import HealthCheck, given, settings
Expand All @@ -16,6 +18,9 @@
one_of,
tuples,
sampled_from,
text,
characters,
dates,
)

import shapefile as shp
Expand Down Expand Up @@ -518,3 +523,119 @@ def test_shx_reader_writer_roundtrip(codes_and_shapes)-> None:
assert r.offsets == offsets_B
assert r.shape_lengths_B == sizes_B



# Per-type bounds used when generating dbf field definitions.  Keys that are
# missing from a type's dict fall back to the defaults applied in dbf_field().
DBF_FIELD_TYPES = {
    "C": {},
    # N and F lengths are capped at 22 so that the power-of-ten float bounds
    # derived from the field size stay exactly representable.  Larger bounds
    # make hypothesis fail with, e.g.:
    #   hypothesis.errors.InvalidArgument: max_value=100000000000000000000000
    #   cannot be exactly represented as a float of width 64 - use
    #   max_value=1e+23 instead.
    "N": {"max_decimal" : 20, "max_length": 22},
    "F": {"max_decimal" : 20, "max_length": 22},
    "L": {"max_length": 1},
    "D": {"min_length": 8, "max_length": 8},
}

@composite
def dbf_field(draw):
    """Draw one random dbf field definition.

    The result is a dict with the keys ``name``, ``field_type``, ``size`` and
    ``decimal``, i.e. the keyword arguments the round-trip test passes to the
    writer's ``field`` method.

    The field type is sampled from ``DBF_FIELD_TYPES``; the per-type bounds
    found there restrict ``size`` and ``decimal``.  Missing bounds default to
    a length between 1 and 254 and to zero decimals.
    """
    type_options = list(DBF_FIELD_TYPES.items())
    field_type, limits = draw(sampled_from(type_options))

    # Names are 1 to 10 characters taken from the whole ASCII codec.
    name_strategy = text(
        alphabet=characters(codec="ascii"),
        min_size=1,
        max_size=10,
    )
    name = draw(name_strategy)

    longest = limits.get("max_length", 254)
    shortest = limits.get("min_length", 1)
    size = draw(integers(min_value=shortest, max_value=longest))

    # The decimal count may use at most ``size - 3`` characters, never more
    # than the per-type cap, and is clamped so it cannot become negative.
    decimal_cap = min(size - 3, limits.get("max_decimal", 0))
    if decimal_cap < 0:
        decimal_cap = 0
    decimal = draw(integers(min_value=0, max_value=decimal_cap))

    return dict(name=name, field_type=field_type, size=size, decimal=decimal)

# Alphabet for generated "C" field values: ASCII letters, digits and
# punctuation.  The space character is left out (the trailing `+ " "` is
# disabled).
# NOTE(review): presumably because blanks in character fields do not survive
# the write/read round trip unchanged - confirm.
ascii_printable = string.ascii_letters + string.digits + string.punctuation #+ " "

def record_value_for_field(name: str, field_type: str, size: int, decimal: int = 0):
    """Return a hypothesis strategy producing values that fit the given field.

    The parameters mirror the keys of the dicts built by ``dbf_field`` so a
    field definition can be splatted straight into this function; ``name`` is
    accepted for that reason only and is not used.

    * ``"C"``: text of at most ``size`` characters from ``ascii_printable``.
    * ``"N"`` / ``"F"``: integers when ``decimal`` is 0, otherwise floats.
    * ``"L"``: one of ``True``, ``False`` or ``None``.
    * ``"D"``: either a ``date`` object or its ``YYYYMMDD`` string form.

    Raises:
        ValueError: if ``field_type`` is none of the types listed above.
    """
    if field_type == "L":
        return sampled_from([True, False, None])

    if field_type == "D":
        return one_of(dates(), dates().map(lambda d: d.strftime("%Y%m%d")))

    if field_type == "C":
        return text(alphabet=ascii_printable, min_size=0, max_size=size)

    if field_type in {"N", "F"}:
        has_fraction = decimal != 0
        # With decimals, one character goes to the decimal point and
        # ``decimal`` characters to the fractional digits.
        whole_digits = size - decimal - 1 if has_fraction else size

        # The lower bound uses one digit fewer than the upper bound,
        # presumably to leave a character for the minus sign.
        lowest = -(10 ** (whole_digits - 1) - 1)
        highest = 10 ** whole_digits - 1

        if not has_fraction:
            return integers(min_value=lowest, max_value=highest)

        # Floats are drawn from the open interval (lowest - 1, highest + 1).
        # For reference, the max finite float is 2**1023 * (2 - 2**(-52)).
        return floats(
            min_value=lowest - 1,
            max_value=highest + 1,
            exclude_min=True,
            exclude_max=True,
        )

    raise ValueError(f"Unsupported: {field_type=}")


@composite
def dbf_fields_and_records(
    draw,
    max_fields=10,
    max_records=20,
):
    """Draw a list of field definitions plus records matching those fields.

    Args:
        max_fields: upper bound on the number of generated fields (at least
            one field is always drawn).  Per the original note, the writer's
            own limit is ``max_num_fields: int = 2046`` in
            ``DbfWriter.__init__``.
        max_records: upper bound on the number of generated records; zero
            records is allowed.

    Returns:
        A ``(fields, records)`` pair.  Each record is a tuple holding one
        value per field, in field order.
    """
    fields = draw(lists(dbf_field(), min_size=1, max_size=max_fields))

    # One value strategy per field, combined into a single row strategy.
    value_strategies = [record_value_for_field(**spec) for spec in fields]
    row_strategy = tuples(*value_strategies)

    records = draw(lists(row_strategy, min_size=0, max_size=max_records))

    return fields, records



@pytest.mark.hypothesis
@given(fields_and_records=dbf_fields_and_records())
def test_dbf_reader_writer_roundtrip(fields_and_records) -> None:
    """Write generated fields and records to an in-memory dbf and read them back.

    Checks that every field's type, size and decimal count, and every record
    value, come back as written.  Dates are compared in ``YYYYMMDD`` string
    form and numbers are compared after rounding the expected value to the
    field's decimal count.
    """
    fields, records = fields_and_records

    buffer = io.BytesIO()
    with shp.DbfWriter(dbf=buffer) as writer:
        for spec in fields:
            writer.field(**spec)
        for row in records:
            writer.record(*row)

    buffer.seek(0)
    with shp.DbfReader(dbf=buffer) as r:
        actual_fields = iter(r.fields)
        next(actual_fields)  # skip deletion flag

        # zip_longest (not zip) so a differing number of fields is not
        # silently ignored.
        for f_r, f_w in itertools.zip_longest(actual_fields, fields):
            actual_field_dict = f_r._asdict()
            for k in ("field_type", "size", "decimal"):
                assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}"

        for written_row, read_row in itertools.zip_longest(records, r.records()):
            for expected, actual, spec in itertools.zip_longest(written_row, read_row, fields):
                field_type = spec["field_type"]
                places = spec["decimal"]

                if field_type == "D":
                    # Dates may be generated (or read) as date objects or as
                    # strings; normalise both sides to the string form.
                    if isinstance(expected, datetime.date):
                        expected = expected.strftime("%Y%m%d")
                    if isinstance(actual, datetime.date):
                        actual = actual.strftime("%Y%m%d")
                elif field_type in ("N", "F"):
                    # Round the expectation to the precision the field stores.
                    # NOTE(review): with zero decimals ``expected`` is an int
                    # that is routed through float here; ints above 2**53 may
                    # not survive that exactly - confirm this is intended.
                    expected = float(format(expected, f".{places}f"))
                # "C" (and "L") values are compared verbatim, without stripping.

                assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}"
Loading