Skip to content

Commit abc928d

Browse files
committed
Revert to previous implementation without class Item
1 parent 70d3733 commit abc928d

File tree

4 files changed

+33
-51
lines changed

4 files changed

+33
-51
lines changed

examples/fodo.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from schema.DriftElement import DriftElement
1212
from schema.QuadrupoleElement import QuadrupoleElement
1313

14-
from schema.Item import Item
1514
from schema.Line import Line
1615

1716

@@ -45,11 +44,11 @@ def main():
4544
# Create line with all elements
4645
line = Line(
4746
line=[
48-
Item(item=drift1),
49-
Item(item=quad1),
50-
Item(item=drift2),
51-
Item(item=quad2),
52-
Item(item=drift3),
47+
drift1,
48+
quad1,
49+
drift2,
50+
quad2,
51+
drift3,
5352
]
5453
)
5554
# Serialize to YAML

schema/Item.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

schema/Line.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from pydantic import BaseModel, ConfigDict
2-
from typing import List
1+
from pydantic import BaseModel, ConfigDict, Field
2+
from typing import Annotated, List, Literal, Union
33

4-
5-
from schema.Item import Item
4+
from schema.BaseElement import BaseElement
5+
from schema.ThickElement import ThickElement
6+
from schema.DriftElement import DriftElement
7+
from schema.QuadrupoleElement import QuadrupoleElement
68

79

810
class Line(BaseModel):
@@ -12,7 +14,20 @@ class Line(BaseModel):
1214
# not only when an instance of Line is created
1315
model_config = ConfigDict(validate_assignment=True)
1416

15-
line: List[Item]
17+
kind: Literal["Line"] = "Line"
18+
19+
line: List[
20+
Annotated[
21+
Union[
22+
BaseElement,
23+
ThickElement,
24+
DriftElement,
25+
QuadrupoleElement,
26+
"Line",
27+
],
28+
Field(discriminator="kind"),
29+
]
30+
]
1631

1732

1833
# Avoid circular import issues

tests/test_schema.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from schema.DriftElement import DriftElement
1212
from schema.QuadrupoleElement import QuadrupoleElement
1313

14-
from schema.Item import Item
1514
from schema.Line import Line
1615

1716

@@ -105,22 +104,18 @@ def test_QuadrupoleElement():
105104
def test_Line():
106105
# Create first line with one base element
107106
element1 = BaseElement(name="element1")
108-
item1 = Item(item=element1)
109-
line1 = Line(line=[item1])
110-
assert item1.item == element1
111-
assert line1.line == [item1]
107+
line1 = Line(line=[element1])
108+
assert line1.line == [element1]
112109
# Extend first line with one thick element
113110
element2 = ThickElement(name="element2", length=2.0)
114-
item2 = Item(item=element2)
115-
line1.line.extend([item2])
116-
assert line1.line == [item1, item2]
111+
line1.line.extend([element2])
112+
assert line1.line == [element1, element2]
117113
# Create second line with one drift element
118114
element3 = DriftElement(name="element3", length=3.0)
119-
line2 = Line(line=[Item(item=element3)])
115+
line2 = Line(line=[element3])
120116
# Extend first line with second line
121117
line1.line.extend(line2.line)
122-
assert line1.line[:2] == [item1, item2]
123-
assert line1.line[2].item == element3
118+
assert line1.line == [element1, element2, element3]
124119

125120

126121
def test_yaml():
@@ -129,7 +124,7 @@ def test_yaml():
129124
# Create one thick element
130125
element2 = ThickElement(name="element2", length=2.0)
131126
# Create line with both elements
132-
line = Line(line=[Item(item=element1), Item(item=element2)])
127+
line = Line(line=[element1, element2])
133128
# Serialize the Line object to YAML
134129
yaml_data = yaml.dump(line.model_dump(), default_flow_style=False)
135130
print(f"\n{yaml_data}")
@@ -154,7 +149,7 @@ def test_json():
154149
# Create one thick element
155150
element2 = ThickElement(name="element2", length=2.0)
156151
# Create line with both elements
157-
line = Line(line=[Item(item=element1), Item(item=element2)])
152+
line = Line(line=[element1, element2])
158153
# Serialize the Line object to JSON
159154
json_data = json.dumps(line.model_dump(), sort_keys=True, indent=2)
160155
print(f"\n{json_data}")

0 commit comments

Comments
 (0)