Skip to content

Commit 136be53

Browse files
authored
feat: support type-safe Parameter annotations without mypy plugin (#496)
* feat: support generic type for TaskInstanceParameter and ListTaskInstanceParameter * fix: unsupport non typed Task parameter * fix: support new type hint * docs: add next type hint * chore: import from typing_extensions if Python version is less than 3.11 * fix: use StrParameter instead of Parameter
1 parent 905db38 commit 136be53

29 files changed

+289
-269
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ class TwoDumpTask(gokart.TaskOnKart[int]):
107107

108108
class AddTask(gokart.TaskOnKart[int]):
109109
# `a` requires a task to dump `int`
110-
a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
110+
a: gokart.TaskInstanceParameter[gokart.TaskOnKart[int]] = gokart.TaskInstanceParameter()
111111
# `b` requires a task to dump `int`
112-
b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
112+
b: gokart.TaskInstanceParameter[gokart.TaskOnKart[int]] = gokart.TaskInstanceParameter()
113113

114114
def requires(self):
115115
return dict(a=self.a, b=self.b)

docs/intro_to_gokart.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ The :func:`~gokart.build` is inline code.
141141
import luigi
142142
143143
class SampleTask(gokart.TaskOnKart[str]):
144-
param = luigi.Parameter()
144+
param: luigi.Parameter = luigi.Parameter()
145145
146146
def run(self):
147147
self.dump(self.param)

docs/setting_task_parameters.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Setting Task Parameters
33
============================
44

5-
There are several ways to set task parameters.
5+
There are several ways to set task parameters.
66

77
- Set parameter from command line
88
- Set parameter at config file
@@ -26,7 +26,7 @@ Set parameter at config file
2626
[sample.SomeTask]
2727
param = Hello
2828

29-
Above config file (``config.ini``) must be read before ``gokart.run()`` as the following code:
29+
Above config file (``config.ini``) must be read before ``gokart.run()`` as the following code:
3030

3131
.. code:: python
3232
@@ -68,12 +68,12 @@ Parameter values can be inherited from other task using ``@inherits_config_param
6868
.. code:: python
6969
7070
class MasterConfig(luigi.Config):
71-
param: str = luigi.Parameter()
72-
param2: str = luigi.Parameter()
71+
param: luigi.Parameter = luigi.Parameter()
72+
param2: luigi.Parameter = luigi.Parameter()
7373
7474
@inherits_config_params(MasterConfig)
7575
class SomeTask(gokart.TaskOnKart):
76-
param: str = luigi.Parameter()
76+
param: luigi.Parameter = luigi.Parameter()
7777
7878
7979
This is useful when multiple tasks has the same parameter. In the above example, parameter settings of ``MasterConfig`` will be inherited to all tasks decorated with ``@inherits_config_params(MasterConfig)`` as ``SomeTask``.
@@ -84,12 +84,12 @@ In the above example, ``param2`` will not be available in ``SomeTask``, since ``
8484
.. code:: python
8585
8686
class MasterConfig(luigi.Config):
87-
param: str = luigi.Parameter()
88-
param2: str = luigi.Parameter()
87+
param: luigi.Parameter = luigi.Parameter()
88+
param2: luigi.Parameter = luigi.Parameter()
8989
9090
@inherits_config_params(MasterConfig, parameter_alias={'param2': 'param3'})
9191
class SomeTask(gokart.TaskOnKart):
92-
param3: str = luigi.Parameter()
92+
param3: luigi.Parameter = luigi.Parameter()
9393
9494
9595
You may also set a parameter name alias by setting ``parameter_alias``.

docs/task_on_kart.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ How ``TaskOnKart`` helps to define a task looks like:
1717
1818
1919
class TaskA(gokart.TaskOnKart[str]):
20-
param = luigi.Parameter()
20+
param: luigi.Parameter = luigi.Parameter()
2121
2222
def output(self):
2323
return self.make_target('output_of_task_a.pkl')
@@ -28,7 +28,7 @@ How ``TaskOnKart`` helps to define a task looks like:
2828
2929
3030
class TaskB(gokart.TaskOnKart[str]):
31-
param = luigi.Parameter()
31+
param: luigi.Parameter = luigi.Parameter()
3232
3333
def requires(self):
3434
return TaskA(param='world')

docs/task_parameters.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ Also please refer to :doc:`task_settings` section.
1111
.. code:: python
1212
1313
class Task(gokart.TaskOnKart):
14-
param_a = luigi.Parameter()
15-
param_c = luigi.ListParameter()
16-
param_d = luigi.IntParameter(default=1)
14+
param_a: luigi.Parameter = luigi.Parameter()
15+
param_c: luigi.ListParameter = luigi.ListParameter()
16+
param_d: luigi.IntParameter = luigi.IntParameter(default=1)
1717
1818
Please refer to `luigi document <https://luigi.readthedocs.io/en/stable/api/luigi.parameter.html>`_ for a list of parameter types.
1919

@@ -42,7 +42,7 @@ The :func:`~gokart.parameter.TaskInstanceParameter` executes a task using the re
4242
4343
4444
class TaskB(gokart.TaskOnKart[str]):
45-
require_task = gokart.TaskInstanceParameter()
45+
require_task: gokart.TaskInstanceParameter = gokart.TaskInstanceParameter()
4646
4747
def requires(self):
4848
return self.require_task
@@ -120,7 +120,7 @@ Example
120120
return cls(**json.loads(s))
121121
122122
class DummyTask(gokart.TaskOnKart):
123-
config: Config = gokart.SerializableParameter(object_type=Config)
123+
config: gokart.SerializableParameter[Config] = gokart.SerializableParameter(object_type=Config)
124124
125125
def run(self):
126126
# Save the `config` object as part of the task result.

docs/task_settings.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ If you want to change the parameter of TaskA and rerun TaskB and TaskC, recommen
8383
.. code:: python
8484
8585
class TaskA(gokart.TaskOnKart):
86-
__version = luigi.IntParameter(default=1)
86+
__version: luigi.IntParameter = luigi.IntParameter(default=1)
8787
8888
If the hash value of TaskA will change, the dependent tasks (in this case, TaskB and TaskC) will rerun.
8989

docs/tutorial.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ Modify ``example/gokart_example/model/sample.py`` as follows:
147147
148148
class StringToSplit(GokartTask):
149149
"""Like the function to divide received data by spaces."""
150-
task = gokart.TaskInstanceParameter()
150+
task: gokart.TaskInstanceParameter = gokart.TaskInstanceParameter()
151151
152152
def run(self):
153153
sample = self.load('task')
@@ -240,7 +240,7 @@ Add new parameter on dependent tasks like following:
240240
.. code:: python
241241
242242
class Sample(GokartTask):
243-
version = luigi.IntParameter(default=1)
243+
version: luigi.IntParameter = luigi.IntParameter(default=1)
244244
245245
def run(self):
246246
self.dump('sample output version {self.version}')

gokart/gcs_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class GCSConfig(luigi.Config):
13-
gcs_credential_name: str = luigi.Parameter(default='GCS_CREDENTIAL', description='GCS credential environment variable.')
13+
gcs_credential_name: luigi.StrParameter = luigi.StrParameter(default='GCS_CREDENTIAL', description='GCS credential environment variable.')
1414
_client = None
1515

1616
def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:

gokart/info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def make_tree_info(
4545

4646

4747
class tree_info(TaskOnKart[Any]):
48-
mode: str = luigi.Parameter(default='', description='This must be in ["simple", "all"].')
49-
output_path: str = luigi.Parameter(default='tree.txt', description='Output file path.')
48+
mode: luigi.StrParameter = luigi.StrParameter(default='', description='This must be in ["simple", "all"].')
49+
output_path: luigi.StrParameter = luigi.StrParameter(default='tree.txt', description='Output file path.')
5050

5151
def output(self):
5252
return self.make_target(self.output_path, use_unique_id=False)

gokart/parameter.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,48 @@
33
import bz2
44
import datetime
55
import json
6+
import sys
67
from logging import getLogger
78
from typing import Any, Generic, Protocol, TypeVar
9+
10+
if sys.version_info >= (3, 11):
11+
from typing import Unpack
12+
else:
13+
from typing_extensions import Unpack
814
from warnings import warn
915

1016
import luigi
1117
from luigi import task_register
1218

19+
try:
20+
from luigi.parameter import _no_value, _NoValueType, _ParameterKwargs
21+
except ImportError:
22+
_no_value = None # type: ignore[assignment]
23+
_NoValueType = type(None) # type: ignore[assignment,misc]
24+
_ParameterKwargs = dict # type: ignore[assignment,misc]
25+
1326
import gokart
1427

1528
logger = getLogger(__name__)
1629

1730

18-
class TaskInstanceParameter(luigi.Parameter):
19-
def __init__(self, expected_type=None, *args, **kwargs):
31+
TASK_ON_KART_TYPE = TypeVar('TASK_ON_KART_TYPE', bound='gokart.TaskOnKart') # type: ignore
32+
33+
34+
class TaskInstanceParameter(luigi.Parameter[TASK_ON_KART_TYPE], Generic[TASK_ON_KART_TYPE]):
35+
def __init__(
36+
self,
37+
expected_type: type[TASK_ON_KART_TYPE] | None = None,
38+
default: TASK_ON_KART_TYPE | _NoValueType = _no_value,
39+
**kwargs: Unpack[_ParameterKwargs],
40+
):
2041
if expected_type is None:
2142
self.expected_type: type = gokart.TaskOnKart
2243
elif isinstance(expected_type, type):
2344
self.expected_type = expected_type
2445
else:
2546
raise TypeError(f'expected_type must be a type, not {type(expected_type)}')
26-
super().__init__(*args, **kwargs)
47+
super().__init__(default=default, **kwargs)
2748

2849
@staticmethod
2950
def _recursive(param_dict):
@@ -64,15 +85,20 @@ def default(self, obj):
6485
return json.JSONEncoder.default(self, obj)
6586

6687

67-
class ListTaskInstanceParameter(luigi.Parameter):
68-
def __init__(self, expected_elements_type=None, *args, **kwargs):
88+
class ListTaskInstanceParameter(luigi.Parameter[list[TASK_ON_KART_TYPE]], Generic[TASK_ON_KART_TYPE]):
89+
def __init__(
90+
self,
91+
expected_elements_type: type[TASK_ON_KART_TYPE] | None = None,
92+
default: list[TASK_ON_KART_TYPE] | _NoValueType = _no_value,
93+
**kwargs: Unpack[_ParameterKwargs],
94+
):
6995
if expected_elements_type is None:
7096
self.expected_elements_type: type = gokart.TaskOnKart
7197
elif isinstance(expected_elements_type, type):
7298
self.expected_elements_type = expected_elements_type
7399
else:
74100
raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}')
75-
super().__init__(*args, **kwargs)
101+
super().__init__(default=default, **kwargs)
76102

77103
def parse(self, s):
78104
return [TaskInstanceParameter().parse(x) for x in list(json.loads(s))]
@@ -113,7 +139,7 @@ def gokart_deserialize(cls: type[T], s: str) -> T:
113139
S = TypeVar('S', bound=Serializable)
114140

115141

116-
class SerializableParameter(luigi.Parameter, Generic[S]):
142+
class SerializableParameter(luigi.Parameter[S], Generic[S]):
117143
def __init__(self, object_type: type[S], *args: Any, **kwargs: Any) -> None:
118144
self._object_type = object_type
119145
super().__init__(*args, **kwargs)
@@ -125,7 +151,7 @@ def serialize(self, x: S) -> str:
125151
return x.gokart_serialize()
126152

127153

128-
class ZonedDateSecondParameter(luigi.Parameter):
154+
class ZonedDateSecondParameter(luigi.Parameter[datetime.datetime]):
129155
"""
130156
ZonedDateSecondParameter supports a datetime.datetime object with timezone information.
131157

0 commit comments

Comments
 (0)