Skip to content

Commit cd03e2f

Browse files
authored
Flight split based on condition (#457)
1 parent e7870cb commit cd03e2f

File tree

3 files changed

+105
-7
lines changed

3 files changed

+105
-7
lines changed

src/traffic/core/flight.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,14 +1456,27 @@ def sliding_windows(
14561456
yield from after.sliding_windows(duration_, step_)
14571457

14581458
@overload
1459-
def split(self, value: int, unit: str) -> FlightIterator: ...
1459+
def split(
1460+
self,
1461+
value: int,
1462+
unit: str,
1463+
condition: None | Callable[["Flight", "Flight"], bool] = None,
1464+
) -> FlightIterator: ...
14601465

14611466
@overload
1462-
def split(self, value: str, unit: None = None) -> FlightIterator: ...
1467+
def split(
1468+
self,
1469+
value: str,
1470+
unit: None = None,
1471+
condition: None | Callable[["Flight", "Flight"], bool] = None,
1472+
) -> FlightIterator: ...
14631473

14641474
@flight_iterator
14651475
def split(
1466-
self, value: Union[int, str] = 10, unit: Optional[str] = None
1476+
self,
1477+
value: Union[int, str] = 10,
1478+
unit: Optional[str] = None,
1479+
condition: None | Callable[["Flight", "Flight"], bool] = None,
14671480
) -> Iterator["Flight"]:
14681481
"""Iterates on legs of a Flight based on the distribution of timestamps.
14691482
@@ -1476,13 +1489,51 @@ def split(
14761489
``np.timedelta64``);
14771490
- in the pandas style: ``Flight.split('10T')`` (see ``pd.Timedelta``)
14781491
1492+
If the `condition` parameter is set, the flight is split between two
1493+
segments only if `condition(f1, f2)` is verified.
1494+
1495+
Example:
1496+
1497+
.. code:: python
1498+
1499+
def no_split_below_5000ft(f1, f2):
1500+
first = f1.data.iloc[-1].altitude >= 5000
1501+
second = f2.data.iloc[0].altitude >= 5000
1502+
return first or second
1503+
1504+
# would yield many segments
1505+
belevingsvlucht.query('altitude > 2000').split('1 min')
1506+
1507+
# yields only one segment
1508+
belevingsvlucht.query('altitude > 2000').split(
1509+
'1 min', condition = no_split_below_5000ft
1510+
)
1511+
14791512
"""
14801513
if isinstance(value, int) and unit is None:
14811514
# default value is 10 m
14821515
unit = "m"
14831516

1484-
for data in _split(self.data, value, unit):
1485-
yield self.__class__(data)
1517+
if condition is None:
1518+
for data in _split(self.data, value, unit):
1519+
yield self.__class__(data)
1520+
1521+
else:
1522+
previous = None
1523+
for data in _split(self.data, value, unit):
1524+
if previous is None:
1525+
previous = self.__class__(data)
1526+
else:
1527+
latest = self.__class__(data)
1528+
if condition(previous, latest):
1529+
yield previous
1530+
previous = latest
1531+
else:
1532+
previous = self.__class__(
1533+
pd.concat([previous.data, data])
1534+
)
1535+
if previous is not None:
1536+
yield previous
14861537

14871538
def max_split(
14881539
self,

src/traffic/core/iterator.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,24 @@ def min(self, key: str = "duration") -> Optional["Flight"]:
232232
"""
233233
return min(self, key=lambda x: getattr(x, key), default=None)
234234

235+
def map(
236+
self, fun: Callable[["Flight"], Optional["Flight"]]
237+
) -> "FlightIterator":
238+
"""Applies a function on each segment of an Iterator.
239+
240+
For instance:
241+
242+
>>> flight.split("10min").map(lambda f: f.resample("2s")).all()
243+
244+
"""
245+
246+
def aux(self: FlightIterator) -> Iterator["Flight"]:
247+
for segment in self:
248+
if (result := fun(segment)) is not None:
249+
yield result
250+
251+
return flight_iterator(aux)(self)
252+
235253
def __call__(
236254
self,
237255
fun: Callable[..., "LazyTraffic"],
@@ -272,8 +290,6 @@ def flight_iterator(
272290
fun.__annotations__["return"] == Iterator["Flight"]
273291
or eval(fun.__annotations__["return"]) == Iterator["Flight"]
274292
):
275-
print(eval(fun.__annotations__["return"]))
276-
print(Iterator["Flight"])
277293
raise TypeError(msg)
278294

279295
@functools.wraps(fun, updated=("__dict__", "__annotations__"))

tests/test_flight.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,34 @@ def test_label() -> None:
913913
ils = set(ils for ils in labelled.ILS_unique if ils is not None)
914914
assert ils == {"14", "28"}
915915
assert labelled.duration_min > pd.Timedelta("2 min 30 s")
916+
917+
918+
def test_split_condition() -> None:
919+
def no_split_below_5000(f1: Flight, f2: Flight) -> bool:
920+
return ( # type: ignore
921+
f1.data.iloc[-1].altitude >= 5000
922+
or f2.data.iloc[0].altitude >= 5000
923+
)
924+
925+
f_max = (
926+
belevingsvlucht.query("altitude > 2000") # type: ignore
927+
.split(
928+
"1 min",
929+
condition=no_split_below_5000,
930+
)
931+
.max()
932+
)
933+
934+
assert f_max is not None
935+
assert f_max.start - belevingsvlucht.start < pd.Timedelta("5 min")
936+
assert belevingsvlucht.stop - f_max.stop < pd.Timedelta("10 min")
937+
938+
939+
def test_split_map() -> None:
940+
result = (
941+
belevingsvlucht.aligned_on_ils("EHLE")
942+
.map(lambda f: f.resample("10s"))
943+
.all()
944+
)
945+
assert result is not None
946+
assert 140 <= len(result) <= 160

0 commit comments

Comments
 (0)