Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ classifiers = [
]

[project.optional-dependencies]
test = ["pytest"]
test = ["pytest", "numpy"]

[project.urls]
Documentation = "https://pals-project.readthedocs.io"
Expand Down
67 changes: 65 additions & 2 deletions src/pals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ def load_file_to_dict(filename: str) -> dict:
return pals_data


def _numpy_to_native(obj):
"""Convert a numpy scalar/array to its Python-native equivalent.

Returns ``None`` when the object is not a numpy type or when numpy is not
installed; callers use that to decide whether to fall back to the default
serializer behavior. numpy is an optional dependency.
"""
try:
import numpy as np
except ImportError:
return None

if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.generic):
return obj.item()
return None


def store_dict_to_file(filename: str, pals_dict: dict):
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
filename
Expand All @@ -63,14 +82,58 @@ def store_dict_to_file(filename: str, pals_dict: dict):
if extension == ".json":
import json

json_data = json.dumps(pals_dict, sort_keys=False, indent=2)
def _json_default(obj):
native = _numpy_to_native(obj)
if native is not None:
return native
raise TypeError(
f"Object of type {type(obj).__name__} is not JSON serializable"
)

json_data = json.dumps(
pals_dict, sort_keys=False, indent=2, default=_json_default
)
with open(filename, "w") as file:
file.write(json_data)

elif extension == ".yaml":
import yaml

yaml_data = yaml.dump(pals_dict, default_flow_style=False, sort_keys=False)
# Subclass the safe dumper so numpy representers are scoped to PALS
# serialization and do not leak into the global pyyaml state used by
# other code in the same process.
class _PALSDumper(yaml.SafeDumper):
pass

try:
import numpy as np
except ImportError:
np = None

if np is not None:

def _represent_numpy_scalar(dumper, value):
native = value.item()
if isinstance(native, bool):
return dumper.represent_bool(native)
if isinstance(native, int):
return dumper.represent_int(native)
if isinstance(native, float):
return dumper.represent_float(native)
return dumper.represent_data(native)

def _represent_numpy_array(dumper, value):
return dumper.represent_list(value.tolist())

_PALSDumper.add_multi_representer(np.generic, _represent_numpy_scalar)
_PALSDumper.add_representer(np.ndarray, _represent_numpy_array)

yaml_data = yaml.dump(
pals_dict,
Dumper=_PALSDumper,
default_flow_style=False,
sort_keys=False,
)
with open(filename, "w") as file:
file.write(yaml_data)

Expand Down
73 changes: 73 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import pals


Expand Down Expand Up @@ -322,3 +324,74 @@ def test_comprehensive_lattice(tmp_path):
assert unionele_loaded_json.elements[1].name == "union_drift"
assert unionele_loaded_json.elements[1].kind == "Drift"
assert unionele_loaded_json.elements[1].length == 0.1


def _build_numpy_lattice(np):
"""Build a small lattice using numpy-typed scalar values throughout."""
quad = pals.Quadrupole(
name="q_np",
length=np.float64(0.061),
MagneticMultipoleP=pals.MagneticMultipoleParameters(
Bn1=np.float64(-26.0), Bs1=np.float32(0.5), Kn0=np.int64(-1)
),
)
oct_ = pals.Octupole(
name="o_np",
length=np.float64(0.25),
ElectricMultipoleP=pals.ElectricMultipoleParameters(
En3=np.float64(0.75), Es3=np.float32(0.125)
),
)
return pals.BeamLine(name="line_np", line=[quad, oct_])


def test_yaml_roundtrip_with_numpy(tmp_path):
"""Round-trip numpy-typed values through YAML (regression for issue #67).

Writing YAML with numpy-typed values must not produce !!python/object tags.
Round-tripping must yield Python-native floats with the correct numeric values.
"""
np = pytest.importorskip("numpy")

line = _build_numpy_lattice(np)
yaml_file = tmp_path / "numpy_roundtrip.pals.yaml"
line.to_file(yaml_file)

with open(yaml_file, "r") as f:
text = f.read()

# The bug symptom: YAML contains opaque numpy object tags.
assert "!!python/object" not in text, (
f"YAML output still contains unsafe numpy object tags:\n{text}"
)
assert "numpy" not in text, f"YAML output still references numpy:\n{text}"

loaded = pals.BeamLine.from_file(yaml_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
assert loaded_quad.MagneticMultipoleP.Bs1 == 0.5
assert loaded_quad.MagneticMultipoleP.Kn0 == -1

loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
assert type(loaded_oct.ElectricMultipoleP.En3) is float


def test_json_roundtrip_with_numpy(tmp_path):
"""Round-trip numpy-typed values through JSON (companion for issue #67).

JSON also needs to handle numpy values cleanly (defense-in-depth).
"""
np = pytest.importorskip("numpy")

line = _build_numpy_lattice(np)
json_file = tmp_path / "numpy_roundtrip.pals.json"
line.to_file(json_file)

loaded = pals.BeamLine.from_file(json_file)
loaded_quad = loaded.line[0]
assert loaded_quad.MagneticMultipoleP.Bn1 == -26.0
assert type(loaded_quad.MagneticMultipoleP.Bn1) is float
loaded_oct = loaded.line[1]
assert loaded_oct.ElectricMultipoleP.En3 == 0.75
Loading