From 912be01b18d336c0d1ccdad547734c1454ddd63e Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:25:22 +0100 Subject: [PATCH 1/2] Refactor field descriptor array encoding into Field.encode_field_descriptor --- README.md | 6 ++- src/shapefile.py | 115 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 97 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index e433426..17092ad 100644 --- a/README.md +++ b/README.md @@ -1347,7 +1347,11 @@ applies to both reading and writing. True >>> r.close() - +Valid values for encodingErrors are those supported by both `bytes.decode` and `str.encode` in the Python +version being used, e.g. 'strict', 'ignore' or 'replace' +in [CPython 3.9 - 3.14](https://docs.python.org/3/library/stdtypes.html#bytes.decode) +('xmlcharrefreplace' and 'backslashreplace' are only supported by +[`str.encode`](https://docs.python.org/3/library/stdtypes.html#str.encode)). ## Reading Large Shapefiles diff --git a/src/shapefile.py b/src/shapefile.py index d66308e..59ebf8a 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -231,6 +231,44 @@ class FieldType: FIELD_TYPE_ALIASES[c.encode("ascii").upper()] = c +class PossibleDataLoss(Warning): + pass + + +def _largest_valid_truncated_encoding( + s: str, + max_bytes: int, + strict: bool, + encoding: str = "utf8", + encodingErrors: str = "strict", +) -> tuple[bytes, str]: + N = len(s) + for i in reversed(range(0, N + 1)): + trimmed = s[:i] + encoded = trimmed.encode(encoding, encodingErrors) + if len(encoded) <= max_bytes: + if i <= N - 1: + msg = ( + f"Dropped {N - i} code points (e.g. characters)! " + f"{s} was truncated to {trimmed} (discarding: {s[i:]}), " + f"in order to encode it under {max_bytes} bytes for the field or field name. " + f"Used: {encoding=} and {encodingErrors=}. " + ) + if strict: + raise ValueError(f"Data loss. {strict=}.\n{msg}") + else: + warnings.warn( + msg, + category=PossibleDataLoss, + ) + return encoded, trimmed + raise ValueError( + f"Maximum truncation not sufficient to encode below {max_bytes=}. " + f"Could not encode first code point (e.g. character): {s[0]} " + f"to a short enough byte string, using {encoding=}, {encodingErrors=}" + ) + + # Use functional syntax to have an attribute named type, a Python keyword class Field(NamedTuple): name: str @@ -245,6 +283,9 @@ def from_unchecked( field_type: str | bytes | FieldTypeT = "C", size: int = 50, decimal: int = 0, + strict: bool = False, + encoding: str = "utf8", + encodingErrors: str = "strict", ) -> Field: try: type_ = FIELD_TYPE_ALIASES[field_type] @@ -262,10 +303,32 @@ def from_unchecked( # A doctest in README.md previously passed in a string ('40') for size, # so explictly convert name to str, and size and decimal to ints. - return cls( + inst = cls( name=str(name), field_type=type_, size=int(size), decimal=int(decimal) ) + inst.encode_field_descriptor(strict, encoding, encodingErrors) + return inst + + @functools.cache + def encode_field_descriptor( + self, + strict: bool, + encoding: str, + encodingErrors: str, + ) -> bytes: + encoded_name = self.name.encode(encoding, encodingErrors) + encoded_name = encoded_name.replace(b" ", b"_") + encoded_name = encoded_name[:10].ljust(10, b"\x00") + encoded_field_type = self.field_type.encode("ascii") + return pack( + "<10sxc4xBB14x", # Packing the name as "<10sx" adds the null terminator. + encoded_name, + encoded_field_type, + self.size, + self.decimal, + ) + def __repr__(self) -> str: return f'Field(name="{self.name}", field_type=FieldType.{self.field_type}, size={self.size}, decimal={self.decimal})' @@ -2598,15 +2661,16 @@ def _dbfHeader(self) -> None: numFields = (self.__dbfHdrLength - 33) // 32 for __field in range(numFields): encoded_field_tuple: tuple[bytes, bytes, int, int] = unpack( - "<11sc4xBB14x", self.file.read(32) + # Historically the name is a 10 char, null-terminated byte string. + # For clarity we now unpack it as <10sx, + # (instead of <11s, and then having to remove the null + # terminator that never needed to be unpacked in the first place). + "<10sxc4xBB14x", + self.file.read(32), ) encoded_name, encoded_type_char, size, decimal = encoded_field_tuple - if b"\x00" in encoded_name: - idx = encoded_name.index(b"\x00") - else: - idx = len(encoded_name) - 1 - encoded_name = encoded_name[:idx] + encoded_name, __, ___ = encoded_name.partition(b"\x00") name = encoded_name.decode(self.encoding, self.encodingErrors) name = name.lstrip() @@ -2624,11 +2688,11 @@ def _dbfHeader(self) -> None: # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields - self.__fieldLookup = {f[0]: i for i, f in enumerate(self.fields)} + self.__fieldLookup = {f.name: i for i, f in enumerate(self.fields)} # by default, read all fields except the deletion flag, hence "[1:]" # note: recLookup gives the index position of a field inside a _Record list - fieldnames = [f[0] for f in self.fields[1:]] + fieldnames = [f.name for f in self.fields[1:]] __fieldTuples, recLookup, recStruct = self._record_fields(fieldnames) self.__fullRecStruct = recStruct self.__fullRecLookup = recLookup @@ -3831,6 +3895,7 @@ def __init__( encoding: str = "utf-8", encodingErrors: str = "strict", max_num_fields: int = 2046, + strict: bool = False, # Keep kwargs even though unused, to preserve PyShp 2.4 API **kwargs: Any, ): @@ -3838,9 +3903,10 @@ def __init__( self.encoding = encoding self.encodingErrors = encodingErrors + self.max_num_fields = max_num_fields + self.strict = strict self.fields: list[Field] = [] - self.max_num_fields = max_num_fields self.recNum = 0 def field( @@ -3856,7 +3922,15 @@ def field( raise dbfFileException( f".dbf Shapefile Writer reached maximum number of fields: {self.max_num_fields}." ) - field_ = Field.from_unchecked(name, field_type, size, decimal) + field_ = Field.from_unchecked( + name=name, + field_type=field_type, + size=size, + decimal=decimal, + encoding=self.encoding, + encodingErrors=self.encodingErrors, + strict=self.strict, + ) self.fields.append(field_) def _header(self) -> None: @@ -3892,21 +3966,16 @@ def _header(self) -> None: recordLength, ) f.write(header) + # Field descriptors for field in fields: - encoded_name = field.name.encode(self.encoding, self.encodingErrors) - encoded_name = encoded_name.replace(b" ", b"_") - encoded_name = encoded_name[:10].ljust(11).replace(b" ", b"\x00") - encodedFieldType = field.field_type.encode("ascii") - fld = pack( - "<11sc4xBB14x", - encoded_name, - encodedFieldType, - field.size, - field.decimal, + f.write( + field.encode_field_descriptor( + self.strict, self.encoding, self.encodingErrors + ) ) - f.write(fld) - # Terminator + + # Terminator (0x0d from dbf spec https://en.wikipedia.org/wiki/.dbf#File_header) f.write(b"\r") def record( From 2772856f5924b2f3313b4c6b0486a4f4ce3b524b Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:25:36 +0100 Subject: [PATCH 2/2] Make Field roundtrippable. Allow opting out of whitespace underscore replacement and stripping. Rename mangled class methods. --- src/shapefile.py | 143 ++++++++++++++++++++++++-------------- tests/hypothesis_tests.py | 26 ++++++- 2 files changed, 113 insertions(+), 56 deletions(-) diff --git a/src/shapefile.py b/src/shapefile.py index 59ebf8a..84f468d 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -269,13 +269,42 @@ def _largest_valid_truncated_encoding( ) -# Use functional syntax to have an attribute named type, a Python keyword class Field(NamedTuple): name: str field_type: FieldTypeT size: int decimal: int + @classmethod + @functools.cache + def get_struct(cls) -> Struct: + # En/decoding the name as "<10sx" embeds the null terminator. + return Struct("<10sxc4xBB14x") + + @classmethod + def from_byte_stream( + cls, + b_io: ReadableBinStream, + strict: bool = False, + encoding: str = "utf8", + encodingErrors: str = "strict", + strip_leading_whitespace: bool = True, + ) -> Field: + encoded_field_tuple: tuple[bytes, bytes, int, int] + encoded_field_tuple = cls.get_struct().unpack(b_io.read(32)) + encoded_name, encoded_type_char, size, decimal = encoded_field_tuple + + encoded_name, __, ___ = encoded_name.partition(b"\x00") + name = encoded_name.decode(encoding, encodingErrors) + if strip_leading_whitespace: + name = name.lstrip() + + field_type = FIELD_TYPE_ALIASES[encoded_type_char] + + return cls.from_unchecked( + name, field_type, size, decimal, strict, encoding, encodingErrors + ) + @classmethod def from_unchecked( cls, @@ -287,6 +316,17 @@ def from_unchecked( encoding: str = "utf8", encodingErrors: str = "strict", ) -> Field: + + if "\x00" in name: + msg = ( + "Field names should contain null characters " + "as null bytes are used to pad them in the header. " + f"Got: {name=} " + ) + if strict: + raise dbfFileException(msg) + warnings.warn(msg, category=PossibleDataLoss) + try: type_ = FIELD_TYPE_ALIASES[field_type] except KeyError: @@ -307,22 +347,25 @@ def from_unchecked( name=str(name), field_type=type_, size=int(size), decimal=int(decimal) ) - inst.encode_field_descriptor(strict, encoding, encodingErrors) + inst.encode_field_descriptor( + strict=True, encoding=encoding, encodingErrors=encodingErrors + ) return inst @functools.cache def encode_field_descriptor( self, - strict: bool, - encoding: str, - encodingErrors: str, + strict: bool = False, + encoding: str = "utf8", + encodingErrors: str = "strict", + replace_ascii_spaces_with_underscores: bool = True, ) -> bytes: encoded_name = self.name.encode(encoding, encodingErrors) - encoded_name = encoded_name.replace(b" ", b"_") + if replace_ascii_spaces_with_underscores: + encoded_name = encoded_name.replace(b" ", b"_") encoded_name = encoded_name[:10].ljust(10, b"\x00") encoded_field_type = self.field_type.encode("ascii") - return pack( - "<10sxc4xBB14x", # Packing the name as "<10sx" adds the null terminator. + return self.get_struct().pack( encoded_name, encoded_field_type, self.size, @@ -938,7 +981,7 @@ def __init__( self._errors: dict[str, int] = {} # add oid - self.__oid: int = -1 if oid is None else oid + self._oid: int = -1 if oid is None else oid if self.shapeType != NULL and self.shapeType not in Point_shapeTypes: self.bbox: BBox = bbox or self._bbox_from_points() @@ -978,7 +1021,7 @@ def __init__( @property def oid(self) -> int: """The index position of the shape in the original shapefile""" - return self.__oid + return self._oid @property def shapeTypeName(self) -> str: @@ -998,8 +1041,8 @@ def points_3D(self) -> list[Point3D]: def __repr__(self) -> str: class_name = self.__class__.__name__ if class_name == "Shape": - return f"Shape #{self.__oid}: {self.shapeTypeName}" - return f"{class_name} #{self.__oid}" + return f"Shape #{self._oid}: {self.shapeTypeName}" + return f"{class_name} #{self._oid}" def _bbox_from_points(self) -> BBox: xs: list[float] = [] @@ -2129,11 +2172,11 @@ def __init__( :param values: A sequence of values :param oid: The object id, an int (optional) """ - self.__field_positions = field_positions + self._field_positions = field_positions if oid is not None: - self.__oid = oid + self._oid = oid else: - self.__oid = -1 + self._oid = -1 list.__init__(self, values) def __getattr__(self, item: str) -> RecordValue: @@ -2150,7 +2193,7 @@ def __getattr__(self, item: str) -> RecordValue: try: if item == "__setstate__": # Prevent infinite loop from copy.deepcopy() raise AttributeError("_Record does not implement __setstate__") - index = self.__field_positions[item] + index = self._field_positions[item] return list.__getitem__(self, index) except KeyError: raise AttributeError(f"{item} is not a field name") @@ -2170,7 +2213,7 @@ def __setattr__(self, key: str, value: RecordValue) -> None: if key.startswith("_"): # Prevent infinite loop when setting mangled attribute return list.__setattr__(self, key, value) try: - index = self.__field_positions[key] + index = self._field_positions[key] return list.__setitem__(self, index, value) except KeyError: raise AttributeError(f"{key} is not a field name") @@ -2198,7 +2241,7 @@ def __getitem__( return list.__getitem__(self, cast(Union[SupportsIndex, slice], item)) except TypeError: try: - index = self.__field_positions[cast(str, item)] + index = self._field_positions[cast(str, item)] except KeyError: index = None if index is not None: @@ -2232,7 +2275,7 @@ def __setitem__( list.__setitem__(self, *cast(ValidKVTuple, (key, value))) return except TypeError: - index = self.__field_positions.get(cast(str, key)) + index = self._field_positions.get(cast(str, key)) if index is not None: list.__setitem__(self, index, cast(RecordValue, value)) return @@ -2242,14 +2285,14 @@ def __setitem__( @property def oid(self) -> int: """The index position of the record in the original shapefile""" - return self.__oid + return self._oid def as_dict(self, date_strings: bool = False) -> dict[str, RecordValue]: """ Returns this Record as a dictionary using the field names as keys :return: dict """ - dct = {f: self[i] for f, i in self.__field_positions.items()} + dct = {f: self[i] for f, i in self._field_positions.items()} if date_strings: for k, v in dct.items(): if isinstance(v, date): @@ -2257,7 +2300,7 @@ def as_dict(self, date_strings: bool = False) -> dict[str, RecordValue]: return dct def __repr__(self) -> str: - return f"Record #{self.__oid}: {list(self)}" + return f"Record #{self._oid}: {list(self)}" def __dir__(self) -> list[str]: """ @@ -2270,13 +2313,13 @@ def __dir__(self) -> list[str]: dir(type(self)) ) # default list methods and attributes of this class fnames = list( - self.__field_positions.keys() + self._field_positions.keys() ) # plus field names (random order if Python version < 3.6) return default + fnames def __eq__(self, other: Any) -> bool: if isinstance(other, _Record): - if self.__field_positions != other.__field_positions: + if self._field_positions != other._field_positions: return False return list.__eq__(self, other) @@ -2631,14 +2674,16 @@ def __init__( *, encoding: str = "utf-8", encodingErrors: str = "strict", + strict: bool = False, ): super().__init__(file=dbf) self.encoding = encoding self.encodingErrors = encodingErrors + self.strict = strict self.fields: list[Field] = [] - self.__fieldLookup: dict[str, int] = {} + self._fieldLookup: dict[str, int] = {} self._dbfHeader() @@ -2653,34 +2698,26 @@ def _dbfHeader(self) -> None: # read relevant header parts self.file.seek(0) - self.numRecords, self.__dbfHdrLength, self._record_length = cast( + self.numRecords, self._dbfHdrLength, self._record_length = cast( tuple[int, int, int], unpack(" None: # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields - self.__fieldLookup = {f.name: i for i, f in enumerate(self.fields)} + self._fieldLookup = {f.name: i for i, f in enumerate(self.fields)} # by default, read all fields except the deletion flag, hence "[1:]" # note: recLookup gives the index position of a field inside a _Record list fieldnames = [f.name for f in self.fields[1:]] __fieldTuples, recLookup, recStruct = self._record_fields(fieldnames) - self.__fullRecStruct = recStruct - self.__fullRecLookup = recLookup + self._fullRecStruct = recStruct + self._fullRecLookup = recLookup def _record_fmt(self, fields: Container[str] | None = None) -> tuple[str, int]: """Calculates the format and size of a .dbf record. Optional 'fields' arg @@ -2739,7 +2776,7 @@ def _record_fields( recStruct = Struct(fmt) # make sure the given fieldnames exist for name in unique_fields: - if name not in self.__fieldLookup or name == "DeletionFlag": + if name not in self._fieldLookup or name == "DeletionFlag": raise ValueError(f'"{name}" is not a valid field name') # fetch relevant field info tuples fieldTuples = [] @@ -2752,8 +2789,8 @@ def _record_fields( else: # use all the dbf fields fieldTuples = self.fields[1:] # sans deletion flag - recStruct = self.__fullRecStruct - recLookup = self.__fullRecLookup + recStruct = self._fullRecStruct + recLookup = self._fullRecLookup return fieldTuples, recLookup, recStruct def _record( @@ -2867,7 +2904,7 @@ def record(self, i: int = 0, fields: list[str] | None = None) -> _Record | None: i = ensure_within_bounds(i, self.numRecords) recSize = self._record_length self.file.seek(0) - self.file.seek(self.__dbfHdrLength + (i * recSize)) + self.file.seek(self._dbfHdrLength + (i * recSize)) fieldTuples, recLookup, recStruct = self._record_fields(fields) return self._record( oid=i, fieldTuples=fieldTuples, recLookup=recLookup, recStruct=recStruct @@ -2939,7 +2976,7 @@ def iterRecords( elif stop < 0: stop = range(self.numRecords)[stop] recSize = self._record_length - self.file.seek(self.__dbfHdrLength + (start * recSize)) + self.file.seek(self._dbfHdrLength + (start * recSize)) fieldTuples, recLookup, recStruct = self._record_fields(fields) for i in range(start, stop): r = self._record( diff --git a/tests/hypothesis_tests.py b/tests/hypothesis_tests.py index c0f89cd..6e848ef 100644 --- a/tests/hypothesis_tests.py +++ b/tests/hypothesis_tests.py @@ -536,12 +536,15 @@ def test_shx_reader_writer_roundtrip(codes_and_shapes)-> None: } @composite -def dbf_field(draw): +def dbf_fields(draw): field_type, bounds_dict = draw(sampled_from(list(DBF_FIELD_TYPES.items()))) name = draw( text( - alphabet=characters(codec="ascii"), + alphabet=characters( + codec="ascii", + exclude_characters=["\x00"], + ), min_size=1, max_size=10, ) @@ -556,6 +559,23 @@ def dbf_field(draw): return {"name": name, "field_type": field_type, "size": size, "decimal": decimal} + +@pytest.mark.hypothesis +@given(field_kwargs=dbf_fields()) +def test_dbf_Field_roundtrips( + field_kwargs: dict, +) -> None: + expected = shp.Field.from_unchecked(**field_kwargs) + stream = io.BytesIO() + encoded = expected.encode_field_descriptor(replace_ascii_spaces_with_underscores=False) + stream.write(encoded) + stream.seek(0) + actual = shp.Field.from_byte_stream(stream, strip_leading_whitespace=False) + assert isinstance(actual, shp.Field) + assert actual.name == expected.name + assert actual[1:] == expected[1:] + + ascii_printable = string.ascii_letters + string.digits + string.punctuation + " " def record_value_for_field(name: str, field_type: str, size: int, decimal: int = 0): @@ -596,7 +616,7 @@ def _dbf_fields_and_record_strategy( max_records=20, ): - fields = draw(lists(dbf_field(), min_size=1, max_size=max_fields)) + fields = draw(lists(dbf_fields(), min_size=1, max_size=max_fields)) record_strategy = tuples(*(record_value_for_field(**field) for field in fields))