Skip to content

Commit e638c54

Browse files
committed
Update types to work with non required params
1 parent 81d2b20 commit e638c54

1 file changed

Lines changed: 19 additions & 6 deletions

File tree

fhirpathpy/__init__.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable, Mapping
4-
from typing import Any, TypeAlias, TypeVar, cast
3+
from collections.abc import Callable, Mapping, Sequence
4+
from typing import Any, Protocol, TypeAlias, TypeVar, cast
55

66
from fhirpathpy.engine import do_eval
77
from fhirpathpy.engine.invocations.constants import constants
@@ -144,11 +144,24 @@ def compile(
144144

145145
InputType = TypeVar("InputType")
146146
OutputType = TypeVar("OutputType")
147+
# Contravariant: a callable accepting a wider input type is a valid subtype.
148+
_I_contra = TypeVar("_I_contra", contravariant=True)
149+
# Covariant: a callable returning a narrower output type is a valid subtype.
150+
# Sequence (read-only) is required here because list is invariant and rejects covariant TypeVars.
151+
_O_co = TypeVar("_O_co", covariant=True)
152+
153+
154+
class CompiledFirst(Protocol[_I_contra, _O_co]):
155+
def __call__(self, resource: _I_contra, context: ContextType = ...) -> _O_co | None: ...
156+
157+
158+
class CompiledArray(Protocol[_I_contra, _O_co]):
159+
def __call__(self, resource: _I_contra, context: ContextType = ...) -> Sequence[_O_co]: ...
147160

148161

149162
def compile_as_array(
150163
expression: str, input_type: type[InputType], output_type: type[OutputType]
151-
) -> Callable[[InputType, ContextType], list[OutputType]]:
164+
) -> CompiledArray[InputType, OutputType]:
152165
path_fn = compile(expression)
153166

154167
def fn(resource: Any, context: ContextType = None) -> Any:
@@ -158,12 +171,12 @@ def fn(resource: Any, context: ContextType = None) -> Any:
158171
is_array=True,
159172
)
160173

161-
return cast(Callable[[InputType, ContextType], list[OutputType]], fn)
174+
return cast(CompiledArray[InputType, OutputType], fn)
162175

163176

164177
def compile_as_first(
165178
expression: str, input_type: type[InputType], output_type: type[OutputType]
166-
) -> Callable[[InputType, ContextType], OutputType | None]:
179+
) -> CompiledFirst[InputType, OutputType]:
167180
path_fn = compile(expression)
168181

169182
def fn(resource: Any, context: ContextType = None) -> Any:
@@ -173,7 +186,7 @@ def fn(resource: Any, context: ContextType = None) -> Any:
173186
is_array=False,
174187
)
175188

176-
return cast(Callable[[InputType, ContextType], OutputType | None], fn)
189+
return cast(CompiledFirst[InputType, OutputType], fn)
177190

178191

179192
def _prepare_data(resource: Any, input_type: type[InputType]) -> ResourceType:

0 commit comments

Comments
 (0)