From 9533e6fba719475865c8eade49ba320eb161bc83 Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Mon, 8 Jun 2026 11:46:56 -0700 Subject: [PATCH] Use tmp_path fixture in test_serialization.py --- tests/test_serialization.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 7e05b67..5f6d037 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,9 +1,7 @@ -import os - import pals -def test_yaml(): +def test_yaml(tmp_path): # Create one base element element1 = pals.Marker(name="element1") # Create one thick element @@ -11,17 +9,15 @@ def test_yaml(): # Create line with both elements line = pals.BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to YAML - test_file = "line.pals.yaml" + test_file = tmp_path / "line.pals.yaml" line.to_file(test_file) # Read the YAML data from the test file loaded_line = pals.BeamLine.from_file(test_file) - # Remove the test file - os.remove(test_file) # Validate loaded BeamLine object assert line == loaded_line -def test_json(): +def test_json(tmp_path): # Create one base element element1 = pals.Marker(name="element1") # Create one thick element @@ -29,17 +25,15 @@ def test_json(): # Create line with both elements line = pals.BeamLine(name="line", line=[element1, element2]) # Serialize the BeamLine object to JSON - test_file = "line.pals.json" + test_file = tmp_path / "line.pals.json" line.to_file(test_file) # Read the JSON data from the test file loaded_line = pals.BeamLine.from_file(test_file) - # Remove the test file - os.remove(test_file) # Validate loaded BeamLine object assert line == loaded_line -def test_comprehensive_lattice(): +def test_comprehensive_lattice(tmp_path): """Test a comprehensive lattice using every PALS element at least once""" # Create elements in alphabetical order for easy maintenance @@ -214,7 +208,7 @@ def test_comprehensive_lattice(): ) # Write to temporary file - yaml_file = "comprehensive_lattice.pals.yaml" + yaml_file = tmp_path / "comprehensive_lattice.pals.yaml" lattice.to_file(yaml_file) # Read back from file @@ -272,7 +266,7 @@ def test_comprehensive_lattice(): assert unionele_loaded.elements[1].length == 0.1 # Write to temporary file - json_file = "comprehensive_lattice.pals.json" + json_file = tmp_path / "comprehensive_lattice.pals.json" lattice.to_file(json_file) # Read back from file @@ -328,7 +322,3 @@ def test_comprehensive_lattice(): assert unionele_loaded_json.elements[1].name == "union_drift" assert unionele_loaded_json.elements[1].kind == "Drift" assert unionele_loaded_json.elements[1].length == 0.1 - - # Clean up temporary files - os.remove(yaml_file) - os.remove(json_file)