
Commit 4b7c31b

Fix race between server task setup and cancel
We used to handle this with a timeout, and we had a test to make sure the timeout never actually had to fire. But the timeout could in fact need to fire, and the test was flaky. I tried to get Anthropic Claude to solve this; it noticed the race, but it couldn't come up with a design I liked for fixing it, so I did it myself. I don't *really* like my design either, but at least it's mine now.

This moves responsibility for marking a WES workflow as CANCELED from the Celery task alone to the Celery task when it can, and to the ToilWorkflow get_state() method otherwise. When somebody asks for the state of a workflow, we ask Celery whether the task has actually stopped. If it has stopped without error while the workflow is supposedly CANCELING, we declare the workflow CANCELED.

I'm removing the timeout-based way to go from CANCELING to CANCELED, because if the task *is* still there and doing stuff, it can't really be canceled yet. It would still be nicer to have the responsibility in one place, but at least this way I'm reducing, not increasing, the number of weird methods.

To test this, I added a sleep that can make the cancel attempt win the race. That involved adding an ugly argument to the fake-Celery code, because we can't monkey-patch at class scope and expect a multiprocessing child process to see the change.
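The get_state() half of that hand-off isn't part of this diff (it lives in ToilWorkflow), so here is a minimal sketch of the idea, built from the TaskRunner and WorkflowStateMachine methods that do appear below; the free-function shape and parameter names are illustrative, not the real code:

    # Illustrative sketch only, not the actual ToilWorkflow.get_state();
    # task_runner and state_machine stand in for the real server objects.
    def get_state(task_id: str, task_runner, state_machine) -> str:
        state = state_machine.get_current_state()
        if state == "CANCELING" and not task_runner.is_live(task_id):
            # The task has stopped, so it can never confirm the cancel itself.
            if task_runner.is_ok(task_id):
                # It stopped without error while supposedly CANCELING, so the
                # cancellation is complete and we can record that.
                state_machine.send_canceled()
                return "CANCELED"
        return state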
1 parent b5c5080 commit 4b7c31b

5 files changed

Lines changed: 159 additions & 118 deletions


src/toil/server/cli/wes_cwl_runner.py

Lines changed: 3 additions & 1 deletion
@@ -419,7 +419,9 @@ def submit_run(
 
 
 def poll_run(client: WESClientWithWorkflowEngineParameters, run_id: str) -> bool:
-    """Return True if the given workflow run is in a finished state."""
+    """
+    Return True if the given workflow run is COMPLETE or never will be.
+    """
     status_result = client.get_run_status(run_id)
     state = status_result.get("state")
 
src/toil/server/utils.py

Lines changed: 2 additions & 37 deletions
@@ -507,8 +507,6 @@ class WorkflowStateMachine:
     clients never see e.g. CANCELED -> COMPLETE or COMPLETE -> SYSTEM_ERROR, we
     can implement a real distributed state machine here.
 
-    We do handle making sure that tasks don't get stuck in CANCELING.
-
     State can be:
 
     "UNKNOWN"
@@ -569,22 +567,11 @@ def send_cancel(self) -> None:
         non-terminal state.
         """
 
-        state = self.get_current_state()
-        if state != "CANCELING" and state not in TERMINAL_STATES:
-            # If it's not obvious we shouldn't cancel, cancel.
-
-            # If we end up in CANCELING but the workflow runner task isn't around,
-            # or we signal it at the wrong time, we will stay there forever,
-            # because it's responsible for setting the state to anything else.
-            # So, we save a timestamp, and if we see a CANCELING status and an old
-            # timestamp, we move on.
-            self._store.set("cancel_time", get_iso_time())
-            # Set state after time, because having the state but no time is an error.
-            self._store.set("state", "CANCELING")
+        self._set_state("CANCELING")
 
     def send_canceled(self) -> None:
         """
-        Send a canceled message that would move to CANCELED from CANCELLING.
+        Send a canceled message that would move from CANCELING to CANCELED.
         """
         self._set_state("CANCELED")
 
@@ -621,28 +608,6 @@ def get_current_state(self) -> str:
         # Otherwise do an actual read from backing storage.
         state = self._store.get("state")
 
-        if state == "CANCELING":
-            # Make sure it hasn't been CANCELING for too long.
-            # We can get stuck in CANCELING if the workflow-running task goes
-            # away or is stopped while reporting back, because it is
-            # repsonsible for posting back that it has been successfully
-            # canceled.
-            canceled_at = self._store.get("cancel_time")
-            if canceled_at is None:
-                # If there's no timestamp but it's supposedly canceling, put it
-                # into SYSTEM_ERROR, because we didn't move to CANCELING properly.
-                state = "SYSTEM_ERROR"
-                self._store.set("state", state)
-            else:
-                # See if it has been stuck canceling for too long
-                canceled_at = datetime.fromisoformat(canceled_at)
-                canceling_seconds = (datetime.now() - canceled_at).total_seconds()
-                if canceling_seconds > MAX_CANCELING_SECONDS:
-                    # If it has, go to CANCELED instead, because the task is
-                    # nonresponsive and thus not running.
-                    state = "CANCELED"
-                    self._store.set("state", state)
-
         if state in TERMINAL_STATES:
             # We can cache this state forever
             self._store.write_cache("state", state)
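With the CANCELING branch gone, get_current_state() reduces to a cached read. A sketch of what that read path plausibly looks like after this change, pieced together from the surviving context lines above (the cache-read call and the UNKNOWN fallback are assumptions, not shown in this diff):

    # Sketch only; just the lines that appear as diff context above are
    # known to be real.
    def get_current_state(self) -> str:
        # Terminal states never change, so serve them from the cache.
        state = self._store.read_cache("state")  # assumed counterpart to write_cache
        if state is None:
            # Otherwise do an actual read from backing storage.
            state = self._store.get("state")
        if state in TERMINAL_STATES:
            # We can cache this state forever
            self._store.write_cache("state", state)
        return state if state is not None else "UNKNOWN"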

src/toil/server/wes/tasks.py

Lines changed: 66 additions & 17 deletions
@@ -19,18 +19,21 @@
 import subprocess
 import sys
 import tempfile
+import time
 import zipfile
 from typing import Any
 from urllib.parse import urldefrag
 
 from celery.exceptions import SoftTimeLimitExceeded  # type: ignore
+import celery.states  # type: ignore
 
 import toil.server.wes.amazon_wes_utils as amazon_wes_utils
 from toil.common import Toil
 from toil.jobStores.utils import generate_locator
 from toil.server.celery_app import celery
 from toil.server.utils import (
     MAX_CANCELING_SECONDS,
+    TERMINAL_STATES,
     WorkflowStateMachine,
     connect_to_workflow_state_store,
     download_file_from_internet,
@@ -44,9 +47,8 @@
 
 # How many seconds should we give a Toil workflow to gracefully shut down
 # before we kill it?
-# Ought to be long enough to let it clean up its job store, but shorter than
-# our patience for CANCELING WES workflows to time out to CANCELED.
-WAIT_FOR_DEATH_TIMEOUT = MAX_CANCELING_SECONDS - 15
+# Ought to be long enough to let it clean up its job store.
+WAIT_FOR_DEATH_TIMEOUT = 20
 
 
 class ToilWorkflowRunner:
@@ -456,9 +458,13 @@ def run_wes_task(
         logger.info(f"Fetching output files.")
         runner.write_output_files()
     except (KeyboardInterrupt, SystemExit, SoftTimeLimitExceeded):
-        # We canceled the workflow run
-        logger.info("Canceling the workflow")
-        runner.state_machine.send_canceled()
+        # We canceled the workflow run after setup.
+        # We can confirm cancellation, but only if we haven't already declared
+        # completion.
+        state = runner.get_state()
+        if state not in TERMINAL_STATES:
+            logger.info("Canceling the workflow")
+            runner.state_machine.send_canceled()
     except Exception:
         # The workflow run broke. We still count as the executor here.
         logger.exception("Running Toil produced an exception.")
@@ -474,8 +480,9 @@ def run_wes_task(
 
 def cancel_run(task_id: str) -> None:
     """
-    Send a SIGTERM signal to the process that is running task_id.
+    Send a signal to the process that is running Celery task task_id.
     """
+    # Celery uses SIGUSR1 for raising SoftTimeLimitExceeded.
    celery.control.terminate(task_id, signal="SIGUSR1")
 
 
@@ -484,6 +491,9 @@ class TaskRunner:
     Abstraction over the Celery API. Runs our run_wes task and allows canceling it.
 
     We can swap this out in the server to allow testing without Celery.
+
+    Note that this is not responsible for acting on or having events for things
+    like failed Celery tasks.
     """
 
     @staticmethod
@@ -505,12 +515,27 @@ def cancel(task_id: str) -> None:
     @staticmethod
     def is_ok(task_id: str) -> bool:
         """
-        Make sure that the task running system is working for the given task.
-        If the task system has detected an internal failure, return False.
+        Returns True if the task has not yet failed, and False otherwise.
+
+        Returns True if the task was successfully canceled.
+
+        If False, the task is also not live.
         """
-        # Nothing to do for Celery
+        # Poll Celery about the task. See <https://stackoverflow.com/a/38287835>
+        result = celery.result.AsyncResult(task_id)
+        if result.status == celery.states.FAILURE:
+            return False
         return True
 
+    @staticmethod
+    def is_live(task_id: str) -> bool:
+        """
+        Returns True if the task has not yet stopped, and False otherwise.
+        """
+        result = celery.result.AsyncResult(task_id)
+        # Celery "ready" means the result is as available as it is getting
+        return result.status not in celery.states.READY_STATES
+
 
 # If Celery can't be set up, we can just use this fake version instead.
 
@@ -520,16 +545,23 @@ class MultiprocessingTaskRunner(TaskRunner):
     Version of TaskRunner that just runs tasks with Multiprocessing.
 
     Can't use threading because there's no way to send a cancel signal or
-    exception to a Python thread, if loops in the task (i.e.
-    ToilWorkflowRunner) don't poll for it.
+    exception to a Python thread, if loops and the task (i.e.
+    ToilWorkflowRunner) doesn't poll for it.
     """
 
     _id_to_process: dict[str, multiprocessing.Process] = {}
     _id_to_log: dict[str, str] = {}
 
+    # For testing, we can delay task setup by this many seconds.
+    # This needs to be smuggled into the multiprocessing child process because
+    # it won't inherit any replacements and has its own globals/class scopes
+    setup_delay = 0
+
     @staticmethod
     def set_up_and_run_task(
-        output_path: str, args: tuple[str, str, str, dict[str, Any], list[str]]
+        output_path: str,
+        args: tuple[str, str, str, dict[str, Any], list[str]],
+        setup_delay: int
     ) -> None:
         """
         Set up logging for the process into the given file and then call
@@ -539,6 +571,8 @@ def set_up_and_run_task(
         the process crashes, the caller must clean up the log.
         """
 
+        time.sleep(setup_delay)
+
         # Multiprocessing and the server manage to hide actual task output from
         # the tests. Logging messages will appear in pytest's "live" log but
         # not in the captured log. And replacing sys.stdout and sys.stderr
@@ -604,7 +638,7 @@ def run(
         )
 
         cls._id_to_process[task_id] = multiprocessing.Process(
-            target=cls.set_up_and_run_task, args=(path, args)
+            target=cls.set_up_and_run_task, args=(path, args, cls.setup_delay)
         )
         cls._id_to_process[task_id].start()
 
@@ -622,8 +656,11 @@ def cancel(cls, task_id: str) -> None:
     @classmethod
     def is_ok(cls, task_id: str) -> bool:
         """
-        Make sure that the task running system is working for the given task.
-        If the task system has detected an internal failure, return False.
+        Return True if the task has not yet failed, and False otherwise.
+
+        Returns True if the task was successfully canceled.
+
+        If False, the task is also not live.
         """
 
         process = cls._id_to_process.get(task_id)
@@ -639,7 +676,7 @@ def is_ok(cls, task_id: str) -> bool:
             process.exitcode is not None
             and process.exitcode not in ACCEPTABLE_EXIT_CODES
         ):
-            # Something went wring in the task and it couldn't handle it.
+            # Something went wrong in the task and it couldn't handle it.
             logger.error(
                 "Process for running %s failed with code %s", task_id, process.exitcode
             )
@@ -655,3 +692,15 @@ def is_ok(cls, task_id: str) -> bool:
             return False
 
         return True
+
+    @classmethod
+    def is_live(cls, task_id: str) -> bool:
+        """
+        Returns True if the task has not yet stopped, and False otherwise.
+        """
+        process = cls._id_to_process.get(task_id)
+        if process is None:
+            # Never heard of this task, so it's probably in the process of
+            # getting made
+            return True
+        return process.exitcode is None
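The "sleep that can make the cancel attempt win the race" from the commit message is this setup_delay hook. A sketch of how a test might use it to force the race (only MultiprocessingTaskRunner.setup_delay is from this diff; the harness calls are hypothetical stand-ins for the real test):

    # Force the child process to sleep before task setup, so a cancel
    # request can arrive first.
    MultiprocessingTaskRunner.setup_delay = 5

    run_id = submit_workflow_via_wes()    # hypothetical test helper
    request_cancel(run_id)                # lands before setup finishes

    # With the timeout gone, the run must still reach CANCELED, via
    # get_state() noticing that the stopped task exited cleanly.
    assert wait_for_terminal_state(run_id) == "CANCELED"  # hypothetical helper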
