Skip to content

Commit ae6a313

Browse files
kitagry, claude, and Copilot
authored
fix: replace multiprocessing.set_start_method('fork') with get_context('fork') (#478)
* fix: replace multiprocessing.set_start_method('fork') with get_context('fork') - Use multiprocessing.get_context('fork') instead of set_start_method('fork') to avoid RuntimeError when other libraries set the start method first - Route all multiprocessing calls through _fork_context in worker.py - Default to gokart's WorkerSchedulerFactory in build() and run() Closes #469 Co-Authored-By: Claude Opus 4.6 <[email protected]> * docs: update fork comment Co-authored-by: Copilot <[email protected]> --------- Co-authored-by: Claude Opus 4.6 <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent a4c33d3 commit ae6a313

File tree

3 files changed

+58
-10
lines changed

3 files changed

+58
-10
lines changed

gokart/build.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def _build_task():
211211
task_lock_exception_raised.flag = False
212212
result = luigi.build(
213213
[task],
214+
worker_scheduler_factory=WorkerSchedulerFactory(),
214215
local_scheduler=True,
215216
detailed_summary=True,
216217
log_level=logging.getLevelName(log_level),

gokart/run.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
import logging
34
import os
45
import sys
56
from logging import getLogger
67

78
import luigi
89
import luigi.cmdline
10+
import luigi.cmdline_parser
11+
import luigi.execution_summary
12+
import luigi.interface
913
import luigi.retcodes
14+
import luigi.setup_logging
1015
from luigi.cmdline_parser import CmdlineParser
1116

1217
import gokart
1318
import gokart.slack
19+
from gokart.build import WorkerSchedulerFactory
1420
from gokart.object_storage import ObjectStorage
1521

1622
logger = getLogger(__name__)
@@ -80,6 +86,47 @@ def _try_to_send_event_summary_to_slack(slack_api: gokart.slack.SlackAPI | None,
8086
slack_api.send_snippet(comment=comment, title='event.txt', content=content)
8187

8288

89+
def _run_with_retcodes(argv):
90+
"""run_with_retcodes equivalent that uses gokart's WorkerSchedulerFactory."""
91+
retcode_logger = logging.getLogger('luigi-interface')
92+
with luigi.cmdline_parser.CmdlineParser.global_instance(argv):
93+
retcodes = luigi.retcodes.retcode()
94+
95+
worker = None
96+
try:
97+
worker = luigi.interface._run(argv, worker_scheduler_factory=WorkerSchedulerFactory()).worker
98+
except luigi.interface.PidLockAlreadyTakenExit:
99+
sys.exit(retcodes.already_running)
100+
except Exception:
101+
env_params = luigi.interface.core()
102+
luigi.setup_logging.InterfaceLogging.setup(env_params)
103+
retcode_logger.exception('Uncaught exception in luigi')
104+
sys.exit(retcodes.unhandled_exception)
105+
106+
with luigi.cmdline_parser.CmdlineParser.global_instance(argv):
107+
task_sets = luigi.execution_summary._summary_dict(worker)
108+
root_task = luigi.execution_summary._root_task(worker)
109+
non_empty_categories = {k: v for k, v in task_sets.items() if v}.keys()
110+
111+
def has(status):
112+
assert status in luigi.execution_summary._ORDERED_STATUSES
113+
return status in non_empty_categories
114+
115+
codes_and_conds = (
116+
(retcodes.missing_data, has('still_pending_ext')),
117+
(retcodes.task_failed, has('failed')),
118+
(retcodes.already_running, has('run_by_other_worker')),
119+
(retcodes.scheduling_error, has('scheduling_error')),
120+
(retcodes.not_run, has('not_run')),
121+
)
122+
expected_ret_code = max(code * (1 if cond else 0) for code, cond in codes_and_conds)
123+
124+
if expected_ret_code == 0 and root_task not in task_sets['completed'] and root_task not in task_sets['already_done']:
125+
sys.exit(retcodes.not_run)
126+
else:
127+
sys.exit(expected_ret_code)
128+
129+
83130
def run(cmdline_args=None, set_retcode=True):
84131
cmdline_args = cmdline_args or sys.argv[1:]
85132

@@ -98,7 +145,7 @@ def run(cmdline_args=None, set_retcode=True):
98145
event_aggregator = gokart.slack.EventAggregator()
99146
try:
100147
event_aggregator.set_handlers()
101-
luigi.cmdline.luigi_run(cmdline_args)
148+
_run_with_retcodes(cmdline_args)
102149
except SystemExit as e:
103150
_try_to_send_event_summary_to_slack(slack_api, event_aggregator, cmdline_args)
104151
sys.exit(e.code)

gokart/worker.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@
6767

6868
logger = logging.getLogger(__name__)
6969

70-
# Set the start method to fork, which is the default on Unix systems.
71-
# This is necessary because the default start method on macOS is spawn, which is not compatible with the multiprocessing
72-
multiprocessing.set_start_method('fork')
70+
# Use fork context instead of the default (spawn on macOS), which ensures compatibility with gokart's multiprocessing requirements.
71+
_fork_context = multiprocessing.get_context('fork')
72+
_ForkProcess = _fork_context.Process
7373

7474
# Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex,
7575
# that may not be unlocked in child process, resulting in the process being locked indefinitely.
@@ -106,7 +106,7 @@ def _get_retry_policy_dict(task: Task) -> dict[str, Any]:
106106
)
107107

108108

109-
class TaskProcess(multiprocessing.Process):
109+
class TaskProcess(_ForkProcess): # type: ignore[valid-type, misc]
110110
"""Wrap all task execution in this class.
111111
112112
Mainly for convenience since this is run in a separate process."""
@@ -447,14 +447,14 @@ def __init__(
447447
pass
448448

449449
# Keep info about what tasks are running (could be in other processes)
450-
self._task_result_queue: multiprocessing.Queue = multiprocessing.Queue()
450+
self._task_result_queue: multiprocessing.Queue = _fork_context.Queue()
451451
self._running_tasks: dict[str, TaskProcess] = {}
452452
self._idle_since: datetime.datetime | None = None
453453

454454
# mp-safe dictionary for caching completation checks across task processes
455455
self._task_completion_cache = None
456456
if self._config.cache_task_completion:
457-
self._task_completion_cache = multiprocessing.Manager().dict()
457+
self._task_completion_cache = _fork_context.Manager().dict()
458458

459459
# Stuff for execution_summary
460460
self._add_task_history: list[Any] = []
@@ -641,8 +641,8 @@ def add(self, task: Task, multiprocess: bool = False, processes: int = 0) -> boo
641641
self._first_task = task.task_id
642642
self.add_succeeded = True
643643
if multiprocess:
644-
queue: Any = multiprocessing.Manager().Queue()
645-
pool: Any = multiprocessing.Pool(processes=processes if processes > 0 else None)
644+
queue: Any = _fork_context.Manager().Queue()
645+
pool: Any = _fork_context.Pool(processes=processes if processes > 0 else None)
646646
else:
647647
queue = luigi.worker.DequeQueue()
648648
pool = luigi.worker.SingleProcessPool()
@@ -905,7 +905,7 @@ def _run_task(self, task_id: str) -> None:
905905
task_process.run()
906906

907907
def _create_task_process(self, task):
908-
message_queue: Any = multiprocessing.Queue() if task.accepts_messages else None
908+
message_queue: Any = _fork_context.Queue() if task.accepts_messages else None
909909
reporter = luigi.worker.TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
910910
use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1)
911911
return ContextManagedTaskProcess(

0 commit comments

Comments (0)