diff --git a/src/shapefile.py b/src/shapefile.py
index ac64591..d66308e 100644
--- a/src/shapefile.py
+++ b/src/shapefile.py
@@ -3943,9 +3943,9 @@ def record(
         else:
             # Blank fields for empty record
             record = ["" for _ in range(fieldCount)]
-        self.__dbfRecord(record)
+        self._record(record)
 
-    def __dbfRecord(self, record: list[RecordValue]) -> None:
+    def _record(self, record: list[RecordValue]) -> None:
         """Writes the dbf records."""
         f = self.file
         if self.recNum == 0:
diff --git a/tests/hypothesis_tests.py b/tests/hypothesis_tests.py
index 00b7505..7a7ebe3 100644
--- a/tests/hypothesis_tests.py
+++ b/tests/hypothesis_tests.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import datetime
 import io
 import itertools
+import string
 
 import pytest
 from hypothesis import HealthCheck, given, settings
@@ -16,6 +18,9 @@
     one_of,
     tuples,
     sampled_from,
+    text,
+    characters,
+    dates,
 )
 
 import shapefile as shp
@@ -518,3 +523,148 @@ def test_shx_reader_writer_roundtrip(codes_and_shapes)-> None:
         assert r.offsets == offsets_B
 
         assert r.shape_lengths_B == sizes_B
+
+
+DBF_FIELD_TYPES = {
+    "C": {},
+    # Numeric widths stop at 22: per the original author's note a larger
+    # limit hits a float precision limit in hypothesis, e.g.:
+    #   hypothesis.errors.InvalidArgument: max_value=100000000000000000000000
+    #   cannot be exactly represented as a float of width 64 - use
+    #   max_value=1e+23 instead.
+    "N": {"max_decimal": 20, "max_length": 22},
+    "F": {"max_decimal": 20, "max_length": 22},
+    "L": {"max_length": 1},
+    "D": {"min_length": 8, "max_length": 8},
+}
+
+
+@composite
+def dbf_field(draw):
+    """Draw the keyword arguments describing one dbf field."""
+    field_type, bounds_dict = draw(sampled_from(list(DBF_FIELD_TYPES.items())))
+
+    name = draw(
+        text(
+            alphabet=characters(codec="ascii"),
+            min_size=1,
+            max_size=10,
+        )
+    )
+
+    max_length = bounds_dict.get("max_length", 254)
+    min_length = bounds_dict.get("min_length", 1)
+    max_decimal = bounds_dict.get("max_decimal", 0)
+    size = draw(integers(min_value=min_length, max_value=max_length))
+    # decimal <= size - 3 keeps int_digits >= 2 in record_value_for_field.
+    decimal = draw(integers(min_value=0, max_value=max(0, min(size - 3, max_decimal))))
+
+    return {"name": name, "field_type": field_type, "size": size, "decimal": decimal}
+
+
+# NOTE(review): space is excluded (see the disabled `+ " "`); presumably
+# because padding spaces would not round-trip - confirm.
+ascii_printable = string.ascii_letters + string.digits + string.punctuation  # + " "
+
+
+def _yyyymmdd(d: datetime.date) -> str:
+    """Format a date as YYYYMMDD with a zero-padded four-digit year.
+
+    strftime("%Y") is avoided because its padding of years before 1000 is
+    platform-dependent.
+    """
+    return f"{d.year:04d}{d.month:02d}{d.day:02d}"
+
+
+def record_value_for_field(name: str, field_type: str, size: int, decimal: int = 0):
+    """Return a hypothesis strategy for values of a field with this spec.
+
+    ``name`` is unused; it is accepted so that a dict drawn from ``dbf_field``
+    can be passed as ``**field``.
+    """
+    if field_type == "C":
+        return text(
+            alphabet=ascii_printable,
+            min_size=0,
+            max_size=size,
+        )
+    if field_type in {"N", "F"}:
+        # The decimal point takes one character of the field width.
+        int_digits = size if decimal == 0 else size - decimal - 1
+        # The negative bound has one digit fewer (presumably for the sign).
+        min_int = -(10 ** (int_digits - 1) - 1)
+        max_int = 10**int_digits - 1
+
+        if decimal == 0:
+            return integers(min_value=min_int, max_value=max_int)
+
+        # Max finite float: 2**1023 * (2 - 2**(-52))
+        return floats(
+            min_value=min_int - 1,
+            max_value=max_int + 1,
+            exclude_min=True,
+            exclude_max=True,
+        )
+    if field_type == "L":
+        return sampled_from([True, False, None])
+    if field_type == "D":
+        # Dates are offered both as date objects and as YYYYMMDD strings.
+        return one_of(dates(), dates().map(_yyyymmdd))
+
+    raise ValueError(f"Unsupported: {field_type=}")
+
+
+@composite
+def dbf_fields_and_records(
+    draw,
+    max_fields=10,  # In DbfWriter.__init__, max_num_fields: int = 2046,
+    max_records=20,
+):
+    """Draw a list of field specs and a list of records matching them."""
+    fields = draw(lists(dbf_field(), min_size=1, max_size=max_fields))
+
+    record_strategy = tuples(*(record_value_for_field(**field) for field in fields))
+
+    records = draw(lists(record_strategy, min_size=0, max_size=max_records))
+
+    return fields, records
+
+
+@pytest.mark.hypothesis
+@given(fields_and_records=dbf_fields_and_records())
+def test_dbf_reader_writer_roundtrip(fields_and_records) -> None:
+    """Fields and records written with DbfWriter are read back by DbfReader."""
+    fields, records = fields_and_records
+    stream = io.BytesIO()
+    with shp.DbfWriter(dbf=stream) as dbf_w:
+        for field in fields:
+            dbf_w.field(**field)
+        for record in records:
+            dbf_w.record(*record)
+    stream.seek(0)
+    with shp.DbfReader(dbf=stream) as r:
+        actual_fields = list(r.fields)[1:]  # skip deletion flag
+        # Check lengths explicitly: padding a shorter side with None (as
+        # zip_longest does) would fail with an unrelated error.
+        assert len(actual_fields) == len(fields), f"{len(actual_fields)=}, {len(fields)=}"
+        for f_r, f_w in zip(actual_fields, fields):
+            actual_field_dict = f_r._asdict()
+            for k in ("field_type", "size", "decimal"):
+                assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}"
+        actual_records = [list(rec) for rec in r.records()]
+        assert len(actual_records) == len(records), f"{len(actual_records)=}, {len(records)=}"
+        for exp_rec, actual_rec in zip(records, actual_records):
+            assert len(actual_rec) == len(exp_rec), f"{actual_rec=}, {exp_rec=}"
+            for expected, actual, field in zip(exp_rec, actual_rec, fields):
+                field_type = field["field_type"]
+                decimal = field["decimal"]
+                if field_type == "D":
+                    # Compare dates in their YYYYMMDD string form.
+                    if isinstance(expected, datetime.date):
+                        expected = _yyyymmdd(expected)
+                    if isinstance(actual, datetime.date):
+                        actual = _yyyymmdd(actual)
+                elif field_type in ("N", "F"):
+                    # Round the expectation to the field's decimal places.
+                    expected = float(format(expected, f".{decimal}f"))
+                assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}"