Skip to content

Commit 1d62eb5

Browse files
committed
Add BeamLine.from/to_file Functions
Hide more internals of serialization (more user friendly) and prepare for increased complexity of the root structure of PALS.
1 parent ada018d commit 1d62eb5

File tree

5 files changed

+121
-66
lines changed

5 files changed

+121
-66
lines changed

examples/fodo.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import json
2-
import yaml
3-
41
from pals import MagneticMultipoleParameters
52
from pals import Drift
63
from pals import Quadrupole
@@ -45,34 +42,24 @@ def main():
4542
drift3,
4643
],
4744
)
45+
4846
# Serialize to YAML
49-
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
50-
print("Dumping YAML data...")
51-
print(f"{yaml_data}")
52-
# Write YAML data to file
53-
yaml_file = "examples_fodo.yaml"
54-
with open(yaml_file, "w") as file:
55-
file.write(yaml_data)
47+
yaml_file = "examples_fodo.pals.yaml"
48+
line.to_file(yaml_file)
49+
5650
# Read YAML data from file
57-
with open(yaml_file, "r") as file:
58-
yaml_data = yaml.safe_load(file)
59-
# Parse YAML data
60-
loaded_line = BeamLine(**yaml_data)
51+
loaded_line = BeamLine.from_file(yaml_file)
52+
6153
# Validate loaded data
6254
assert line == loaded_line
55+
6356
# Serialize to JSON
64-
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
65-
print("Dumping JSON data...")
66-
print(f"{json_data}")
67-
# Write JSON data to file
68-
json_file = "examples_fodo.json"
69-
with open(json_file, "w") as file:
70-
file.write(json_data)
57+
json_file = "examples_fodo.pals.json"
58+
line.to_file(json_file)
59+
7160
# Read JSON data from file
72-
with open(json_file, "r") as file:
73-
json_data = json.loads(file.read())
74-
# Parse JSON data
75-
loaded_line = BeamLine(**json_data)
61+
loaded_line = BeamLine.from_file(json_file)
62+
7663
# Validate loaded data
7764
assert line == loaded_line
7865

src/pals/functions.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Public, free-standing functions for PALS."""
2+
3+
import os
4+
5+
6+
def inspect_file_extensions(filename: str):
7+
"""Attempt to strip two levels of file extensions to determine the schema.
8+
9+
filename examples: fodo.pals.yaml, fodo.pals.json, ...
10+
"""
11+
file_noext, extension = os.path.splitext(filename)
12+
file_noext_noext, extension_inner = os.path.splitext(file_noext)
13+
14+
if extension_inner != ".pals":
15+
raise RuntimeError(
16+
f"inspect_file_extensions: No support for file {filename} with extension {extension}. "
17+
f"PALS files must end in .pals.json or .pals.yaml or similar."
18+
)
19+
20+
return {
21+
"file_noext": file_noext,
22+
"extension": extension,
23+
"file_noext_noext": file_noext_noext,
24+
"extension_inner": extension_inner,
25+
}
26+
27+
28+
def load_file_to_dict(filename: str) -> dict:
29+
# Attempt to strip two levels of file extensions to determine the schema.
30+
# Examples: fodo.pals.yaml, fodo.pals.json, ...
31+
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
32+
filename
33+
).values()
34+
35+
# examples: fodo.pals.yaml, fodo.pals.json
36+
with open(filename, "r") as file:
37+
if extension == ".json":
38+
import json
39+
40+
pals_data = json.loads(file.read())
41+
42+
elif extension == ".yaml":
43+
import yaml
44+
45+
pals_data = yaml.safe_load(file)
46+
47+
# TODO: toml, xml
48+
49+
else:
50+
raise RuntimeError(
51+
f"load_file_to_dict: No support for PALS file {filename} with extension {extension} yet."
52+
)
53+
54+
return pals_data
55+
56+
57+
def store_dict_to_file(filename: str, pals_dict: dict):
58+
file_noext, extension, file_noext_noext, extension_inner = inspect_file_extensions(
59+
filename
60+
).values()
61+
62+
# examples: fodo.pals.yaml, fodo.pals.json
63+
if extension == ".json":
64+
import json
65+
66+
json_data = json.dumps(pals_dict, sort_keys=True, indent=2)
67+
with open(filename, "w") as file:
68+
file.write(json_data)
69+
70+
elif extension == ".yaml":
71+
import yaml
72+
73+
yaml_data = yaml.dump(pals_dict, default_flow_style=False)
74+
with open(filename, "w") as file:
75+
file.write(yaml_data)
76+
77+
# TODO: toml, xml
78+
79+
else:
80+
raise RuntimeError(
81+
f"store_dict_to_file: No support for PALS file {filename} with extension {extension} yet."
82+
)

src/pals/kinds/BeamLine.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .all_elements import get_all_elements_as_annotation
55
from .mixin import BaseElement
6+
from ..functions import load_file_to_dict, store_dict_to_file
67

78

89
class BeamLine(BaseElement):
@@ -25,3 +26,14 @@ def model_dump(self, *args, **kwargs):
2526
from pals.kinds.mixin.all_element_mixin import dump_element_list
2627

2728
return dump_element_list(self, "line", *args, **kwargs)
29+
30+
@staticmethod
31+
def from_file(filename: str) -> "BeamLine":
32+
"""Load a BeamLine from a text file"""
33+
pals_dict = load_file_to_dict(filename)
34+
return BeamLine(**pals_dict)
35+
36+
def to_file(self, filename: str):
37+
"""Save a BeamLine to a text file"""
38+
pals_dict = self.model_dump()
39+
store_dict_to_file(filename, pals_dict)

tests/test_elements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_Quadrupole():
105105
assert element.ElectricMultipoleP.En2 == element_electric_multipole_En2
106106
assert element.ElectricMultipoleP.Es2 == element_electric_multipole_Es2
107107
assert element.ElectricMultipoleP.tilt2 == element_electric_multipole_tilt2
108-
# Serialize the BeamLine object to YAML
108+
# Serialize the element to YAML
109109
yaml_data = yaml.dump(element.model_dump(), default_flow_style=False)
110110
print(f"\n{yaml_data}")
111111

tests/test_serialization.py

Lines changed: 14 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import json
21
import os
3-
import yaml
42

53
import pals
64

@@ -13,17 +11,10 @@ def test_yaml():
1311
# Create line with both elements
1412
line = pals.BeamLine(name="line", line=[element1, element2])
1513
# Serialize the BeamLine object to YAML
16-
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
17-
print(f"\n{yaml_data}")
18-
# Write the YAML data to a test file
19-
test_file = "line.yaml"
20-
with open(test_file, "w") as file:
21-
file.write(yaml_data)
14+
test_file = "line.pals.yaml"
15+
line.to_file(test_file)
2216
# Read the YAML data from the test file
23-
with open(test_file, "r") as file:
24-
yaml_data = yaml.safe_load(file)
25-
# Parse the YAML data back into a BeamLine object
26-
loaded_line = pals.BeamLine(**yaml_data)
17+
loaded_line = pals.BeamLine.from_file(test_file)
2718
# Remove the test file
2819
os.remove(test_file)
2920
# Validate loaded BeamLine object
@@ -38,17 +29,10 @@ def test_json():
3829
# Create line with both elements
3930
line = pals.BeamLine(name="line", line=[element1, element2])
4031
# Serialize the BeamLine object to JSON
41-
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
42-
print(f"\n{json_data}")
43-
# Write the JSON data to a test file
44-
test_file = "line.json"
45-
with open(test_file, "w") as file:
46-
file.write(json_data)
32+
test_file = "line.pals.json"
33+
line.to_file(test_file)
4734
# Read the JSON data from the test file
48-
with open(test_file, "r") as file:
49-
json_data = json.loads(file.read())
50-
# Parse the JSON data back into a BeamLine object
51-
loaded_line = pals.BeamLine(**json_data)
35+
loaded_line = pals.BeamLine.from_file(test_file)
5236
# Remove the test file
5337
os.remove(test_file)
5438
# Validate loaded BeamLine object
@@ -224,21 +208,16 @@ def test_comprehensive_lattice():
224208
],
225209
)
226210

227-
# Test serialization to YAML
228-
yaml_data = yaml.dump(lattice.model_dump(), default_flow_style=False)
229-
print(f"\nComprehensive lattice YAML:\n{yaml_data}")
230-
231211
# Write to temporary file
232-
yaml_file = "comprehensive_lattice.yaml"
233-
with open(yaml_file, "w") as file:
234-
file.write(yaml_data)
212+
yaml_file = "comprehensive_lattice.pals.yaml"
213+
lattice.to_file(yaml_file)
235214

236215
# Read back from file
237216
with open(yaml_file, "r") as file:
238-
loaded_yaml_data = yaml.safe_load(file)
217+
print(f"\nComprehensive lattice YAML:\n{file.read()}")
239218

240219
# Deserialize back to Python object using Pydantic model logic
241-
loaded_lattice = pals.BeamLine(**loaded_yaml_data)
220+
loaded_lattice = pals.BeamLine.from_file(yaml_file)
242221

243222
# Verify the loaded lattice has the correct structure and parameter groups
244223
assert len(loaded_lattice.line) == 31 # Should have 31 elements
@@ -284,21 +263,16 @@ def test_comprehensive_lattice():
284263
assert unionele_loaded.elements[1].kind == "Drift"
285264
assert unionele_loaded.elements[1].length == 0.1
286265

287-
# Test serialization to JSON
288-
json_data = json.dumps(lattice.model_dump(), sort_keys=True, indent=2)
289-
print(f"\nComprehensive lattice JSON:\n{json_data}")
290-
291266
# Write to temporary file
292-
json_file = "comprehensive_lattice.json"
293-
with open(json_file, "w") as file:
294-
file.write(json_data)
267+
json_file = "comprehensive_lattice.pals.json"
268+
lattice.to_file(json_file)
295269

296270
# Read back from file
297271
with open(json_file, "r") as file:
298-
loaded_json_data = json.loads(file.read())
272+
print(f"\nComprehensive lattice JSON:\n{file.read()}")
299273

300274
# Deserialize back to Python object using Pydantic model logic
301-
loaded_lattice_json = pals.BeamLine(**loaded_json_data)
275+
loaded_lattice_json = pals.BeamLine.from_file(json_file)
302276

303277
# Verify the loaded lattice has the correct structure and parameter groups
304278
assert len(loaded_lattice_json.line) == 31 # Should have 31 elements

0 commit comments

Comments
 (0)