23 changes: 23 additions & 0 deletions README.md
@@ -58,6 +58,7 @@ There are several python expressions and idioms that are translated behind your
|List Comprehension | `[j.pt() for j in jets]` | `jets.Select(lambda j: j.pt())` |
|List Comprehension | `[j.pt() for j in jets if abs(j.eta()) < 2.4]` | `jets.Where(lambda j: abs(j.eta()) < 2.4).Select(lambda j: j.pt())` |
|Literal List Comprehension|`[i for i in [1, 2, 3]]`|`[1, 2, 3]`|
|Literal Set Comprehension|`{i for i in [1, 2, 3]}`|`{1, 2, 3}`|
| Data Classes<br>(typed) | `@dataclass`<br>`class my_data:`<br>`x: ObjectStream[Jets]`<br><br>`Select(lambda e: my_data(x=e.Jets()).x)` | `Select(lambda e: {'x': e.Jets()}.x)` |
| Named Tuple<br>(typed) | `class my_data(NamedTuple):`<br>`x: ObjectStream[Jets]`<br><br>`Select(lambda e: my_data(x=e.Jets()).x)` | `Select(lambda e: {'x': e.Jets()}.x)` |
|List Membership|`p.absPdgId() in [35, 51]`|`p.absPdgId() == 35 or p.absPdgId() == 51`|
@@ -69,6 +70,28 @@ For `any`/`all`, generator/list comprehensions over a literal (or captured literal
are first expanded to a literal list and then reduced as usual. For example,
`any(f(a) for a in [1, 2])` is treated like `any([f(1), f(2)])`.

Set comprehensions are supported only when every iterable in the comprehension is a
literal (or a captured literal constant). In that case, FuncADL expands the
comprehension at transformation time and emits a literal set expression in the AST.
If the expansion produces no elements, FuncADL emits a call to `set()` instead,
because Python has no empty set literal (`{}` is an empty dictionary).

If a set comprehension iterates over a non-literal stream (for example
`{j.pt() for j in jets}`), FuncADL raises a `ValueError` rather than guessing a
backend-specific representation.
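
The expansion described above can be sketched with the standard `ast` module alone. This is a minimal illustration, not FuncADL's implementation: the helper `lower_literal_set_comp` and the `_Substitute` transformer are hypothetical names, and the sketch handles only a single generator with a simple loop variable and no `if` clauses.

```python
import ast
import copy


class _Substitute(ast.NodeTransformer):
    """Replace loads of one variable name with a fixed expression (hypothetical helper)."""

    def __init__(self, name: str, value: ast.expr):
        self.name = name
        self.value = value

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id == self.name and isinstance(node.ctx, ast.Load):
            return copy.deepcopy(self.value)
        return node


def lower_literal_set_comp(source: str) -> str:
    """Expand a one-generator set comprehension over a literal list or tuple.

    Illustration only: FuncADL's real transformer covers more cases.
    """
    comp = ast.parse(source, mode="eval").body
    if not isinstance(comp, ast.SetComp) or len(comp.generators) != 1:
        raise ValueError(f"Expected a simple set comprehension - {source}")

    gen = comp.generators[0]
    is_literal = isinstance(gen.iter, (ast.List, ast.Tuple))
    if not is_literal or not isinstance(gen.target, ast.Name) or gen.ifs:
        raise ValueError(f"Set comprehension requires literal iterables - {source}")

    # Substitute each literal element for the loop variable in the element expression.
    elts = [
        _Substitute(gen.target.id, item).visit(copy.deepcopy(comp.elt))
        for item in gen.iter.elts
    ]

    if elts:
        lowered = ast.Set(elts=elts)
    else:
        # No empty set literal exists in Python, so emit a call to set().
        lowered = ast.Call(func=ast.Name(id="set", ctx=ast.Load()), args=[], keywords=[])
    return ast.unparse(ast.fix_missing_locations(lowered))
```

With this sketch, `lower_literal_set_comp("{i * 2 for i in [1, 2, 3]}")` yields `{1 * 2, 2 * 2, 3 * 2}` (the element expressions are substituted, not evaluated), an empty literal iterable yields `set()`, and `{j.pt() for j in jets}` raises `ValueError`.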

### What goes over the func_adl wire for sets

When `.value()` is called, FuncADL sends the transformed query AST to the backend.
For sets, the AST seen by the backend is:

- `ast.Set(...)` for non-empty literal set comprehensions
- `ast.Call(func=Name('set'), ...)` for an empty set comprehension

So by the time the query is serialized/sent to a backend, there is no `SetComp`
node left for supported cases; it has already been lowered to ordinary AST nodes
that explicitly represent a set value.
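
To see why the two cases use different node types, it can help to inspect what Python's own parser produces for the corresponding source text. The snippet below uses only the standard `ast` module and does not call FuncADL; the variable names are illustrative.

```python
import ast

# A non-empty set display parses to an ast.Set node.
non_empty = ast.parse("{1, 2, 3}", mode="eval").body

# An empty set can only be written as a call to the name `set`.
empty = ast.parse("set()", mode="eval").body

# Empty braces are a dictionary, not a set.
braces = ast.parse("{}", mode="eval").body

node_types = [type(n).__name__ for n in (non_empty, empty, braces)]
print(node_types)  # ['Set', 'Call', 'Dict']
```

A backend that handles `ast.Set` and a zero-argument call to `set` therefore covers both shapes described above.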

## Extensibility

There are several extensibility points:
17 changes: 17 additions & 0 deletions docs/source/generic/query_structure.md
@@ -43,6 +43,23 @@ expressions:
- List comprehensions over literal iterables are expanded directly. For example,
`[i for i in [1, 2, 3]]` becomes `[1, 2, 3]`.
- `any`/`all` over literal lists/tuples are reduced to boolean `or`/`and` expressions.
- Set comprehensions over literal iterables are expanded directly to a set value.

This means patterns like `any(expr(x) for x in LITERAL_LIST)` can be simplified in-query,
as long as the iterable is a literal (or a captured literal constant).
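
As a rough illustration of the `any`/`all` reduction, the sketch below rewrites a call over a literal list or tuple into an `or`/`and` chain using only the standard `ast` module. The helper name `reduce_any_all` is hypothetical, and the handling of empty and single-element lists follows plain Python semantics; how FuncADL treats those edge cases is an assumption here, not something the text above specifies.

```python
import ast


def reduce_any_all(source: str) -> str:
    """Rewrite any([...]) / all([...]) as an or / and chain (illustration only)."""
    call = ast.parse(source, mode="eval").body
    if not (
        isinstance(call, ast.Call)
        and isinstance(call.func, ast.Name)
        and call.func.id in ("any", "all")
        and len(call.args) == 1
        and isinstance(call.args[0], (ast.List, ast.Tuple))
    ):
        raise ValueError(f"Expected any/all over a literal list or tuple - {source}")

    elts = call.args[0].elts
    is_any = call.func.id == "any"

    if not elts:
        # Python semantics: any([]) is False and all([]) is True.
        return str(not is_any)
    if len(elts) == 1:
        # A single element needs no boolean operator.
        return ast.unparse(elts[0])

    op = ast.Or() if is_any else ast.And()
    return ast.unparse(ast.BoolOp(op=op, values=list(elts)))
```

For example, `reduce_any_all("any([f(1), f(2)])")` produces `f(1) or f(2)`, and the `all` form produces `f(1) and f(2)`.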

For set comprehensions, only literal iterables are supported. For example,
`{i * 2 for i in [1, 2, 3]}` is lowered before execution to the equivalent set literal,
`{1 * 2, 2 * 2, 3 * 2}`. If the result is empty, the lowered AST is a call to `set()`,
because Python has no syntax for an empty set literal.

If the iterable is not literal (for example `{j.pt() for j in jets}`), FuncADL raises
a `ValueError` because the generic query representation does not define a stream-level
set-construction operator that all backends can execute consistently.

### What is sent to the backend

When `.value()` is called, FuncADL sends the transformed query AST to the backend
executor. For supported set comprehensions, the backend receives ordinary AST nodes
(an `ast.Set`, or an `ast.Call` to `set` for the empty case) rather than an
`ast.SetComp`. This keeps the wire/query representation explicit and backend-agnostic.
16 changes: 16 additions & 0 deletions func_adl/ast/syntatic_sugar.py
@@ -289,6 +289,22 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:

        return a

    def visit_SetComp(self, node: ast.SetComp) -> Any:
        "Translate a set comprehension into a literal set when possible."
        a = self.generic_visit(node)

        if isinstance(a, ast.SetComp):
            if expanded := self._inline_literal_comprehension(a.elt, a.generators, node):
                return ast.Set(elts=expanded)
            if expanded == []:
                return ast.Call(func=ast.Name(id="set", ctx=ast.Load()), args=[], keywords=[])
            raise ValueError(
                "Set comprehension requires literal iterables so it can be represented"
                f" as a set literal - {ast.unparse(node)}"
            )

        return a

    def visit_Compare(self, node: ast.Compare) -> Any:
        """Expand membership tests of an expression against a constant list
        or tuple/set into a series of comparisons.
21 changes: 21 additions & 0 deletions tests/ast/test_syntatic_sugar.py
@@ -67,6 +67,27 @@ def test_resolve_2generator():
) == ast.dump(a_new)


def test_resolve_literal_set_comp():
    a = ast.parse("{i * 2 for i in [1, 2, 3]}")
    a_new = resolve_syntatic_sugar(a)

    assert ast.dump(ast.parse("{1 * 2, 2 * 2, 3 * 2}")) == ast.dump(a_new)


def test_resolve_set_comp_non_literal_iterable_error():
    a = ast.parse("{j.pt() for j in jets}")

    with pytest.raises(ValueError, match="Set comprehension requires literal iterables"):
        resolve_syntatic_sugar(a)


def test_resolve_set_comp_empty_literal_iterable():
    a = ast.parse("{j for j in []}")
    a_new = resolve_syntatic_sugar(a)

    assert ast.unparse(a_new) == ast.unparse(ast.parse("set()"))


def test_resolve_bad_iterator():
    a = ast.parse("[j.pt() for idx,j in enumerate(jets)]")
    a_new = resolve_syntatic_sugar(a)