Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
gi.require_version("Json", "1.0")
from gi.repository import GLib, Json # noqa: E402

T = t.TypeVar("T")

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -588,3 +590,29 @@ def init_logging(level=logging.DEBUG):
logging.basicConfig(level=level, format="%(levelname)-7s %(name)s: %(message)s")
if level == logging.DEBUG:
logging.getLogger("github.Requester").setLevel(logging.INFO)


async def asyncio_gather_failfast(
awaitables: t.Iterable[t.Awaitable[T]],
) -> t.List[T]:
tasks: t.List[asyncio.Task[T]] = [asyncio.ensure_future(a) for a in awaitables]
if not tasks:
return []
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
exc: t.Optional[BaseException] = None
for task in done:
if task.cancelled():
continue
try:
e = task.exception()
except asyncio.CancelledError:
continue
if e is not None:
exc = e
break
if exc:
for task in pending:
task.cancel()
await asyncio.gather(*pending, return_exceptions=True)
raise exc
return [task.result() for task in tasks]
5 changes: 2 additions & 3 deletions src/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import datetime
import dataclasses
import typing as t
import asyncio
from enum import IntEnum
import logging
import os
Expand All @@ -36,7 +35,7 @@
BuilderModule,
ExternalBase,
)
from .lib.utils import read_manifest, dump_manifest
from .lib.utils import asyncio_gather_failfast, read_manifest, dump_manifest
from .lib.errors import (
CheckerError,
AppdataError,
Expand Down Expand Up @@ -414,7 +413,7 @@ async def check(self, filter_type=None) -> t.List[ExternalBase]:
check_tasks.append(self._check_data(counter, http_session, data))

log.info("Checking %s external data items", counter.total)
ext_data_checked = await asyncio.gather(*check_tasks)
ext_data_checked = await asyncio_gather_failfast(check_tasks)

return ext_data_checked

Expand Down
16 changes: 16 additions & 0 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ async def check(self, external_data):
)


class RaisesAssertionChecker(DummyChecker, register=False):
async def check(self, external_data):
raise AssertionError("test error")


class TestFailFast(unittest.IsolatedAsyncioTestCase):
def setUp(self):
init_logging()

async def test_unexpected_exception_propagates(self):
checker = manifest.ManifestChecker(TEST_MANIFEST)
checker._checkers = [RaisesAssertionChecker]
with self.assertRaises(AssertionError):
await checker.check()


class _TestWithInlineManifest(unittest.IsolatedAsyncioTestCase):
_DUMMY_CHECKER_CLS: t.Type[Checker]
maxDiff = None
Expand Down
78 changes: 77 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from textwrap import dedent

import aiohttp

import asyncio
from src.lib.errors import CheckerFetchError
from src.lib.utils import (
parse_github_url,
Expand All @@ -40,6 +40,7 @@
get_extra_data_info_from_url,
Command,
dump_manifest,
asyncio_gather_failfast,
)


Expand Down Expand Up @@ -305,3 +306,78 @@ def test_editorconfig(self):

if __name__ == "__main__":
unittest.main()


class TestGatherFailfast(unittest.IsolatedAsyncioTestCase):
async def test_all_succeed(self):
async def succeed(val):
return val

result = await asyncio_gather_failfast([succeed(1), succeed(2), succeed(3)])
self.assertEqual(result, [1, 2, 3])

async def test_empty(self):
result = await asyncio_gather_failfast([])
self.assertEqual(result, [])

async def test_exception_cancels_pending(self):
cancelled = False

async def hang():
nonlocal cancelled
try:
await asyncio.sleep(999)
except asyncio.CancelledError:
cancelled = True
raise

async def fail():
raise AssertionError("test error")

with self.assertRaises(AssertionError):
await asyncio_gather_failfast([hang(), fail()])

self.assertTrue(cancelled)

async def test_exception_propagates(self):
async def fail():
raise ValueError("test error")

with self.assertRaises(ValueError, msg="test error"):
await asyncio_gather_failfast([fail()])

async def test_parent_child_hang(self):
event = asyncio.Event()
child_cancelled = False

async def parent():
raise AssertionError("test error")

async def child():
nonlocal child_cancelled
try:
await event.wait()
except asyncio.CancelledError:
child_cancelled = True
raise

with self.assertRaises(AssertionError):
await asyncio_gather_failfast([parent(), child()])

self.assertTrue(child_cancelled)

async def test_sibling_cancelled_on_exception(self):
async def raises_assert():
raise AssertionError("test error")

async def slow():
await asyncio.sleep(999)

with self.assertRaises((AssertionError, asyncio.TimeoutError)):
await asyncio.wait_for(
asyncio.gather(raises_assert(), slow()),
timeout=2.0,
)

with self.assertRaises(AssertionError):
await asyncio_gather_failfast([raises_assert(), slow()])