diff --git a/runpod/serverless/core.py b/runpod/serverless/core.py index 657dbe64..e951ce27 100644 --- a/runpod/serverless/core.py +++ b/runpod/serverless/core.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, List, Optional from runpod.serverless.modules import rp_job +from runpod.serverless.modules.rp_handler import is_generator from runpod.serverless.modules.rp_logger import RunPodLogger from runpod.version import __version__ as runpod_version @@ -232,7 +233,7 @@ async def _process_job( result = {} try: - if inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler): + if is_generator(handler): log.debug("SLS Core | Running job as a generator.") generator_output = rp_job.run_job_generator(handler, job) aggregated_output: dict[str, typing.Any] = {"output": []} diff --git a/runpod/serverless/modules/rp_handler.py b/runpod/serverless/modules/rp_handler.py index 2a442f3a..11f04c15 100644 --- a/runpod/serverless/modules/rp_handler.py +++ b/runpod/serverless/modules/rp_handler.py @@ -6,4 +6,15 @@ def is_generator(handler: Callable) -> bool: """Check if handler is a generator function.""" + # handler could be an object that has a __call__ method + if not inspect.isfunction(handler): + handler = getattr(handler, "__call__", lambda: None) return inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler) + + +def is_async_generator(handler: Callable) -> bool: + """Check if handler is an async generator function.""" + # handler could be an object that has a __call__ method + if not inspect.isfunction(handler): + handler = getattr(handler, "__call__", lambda: None) + return inspect.isasyncgenfunction(handler) diff --git a/runpod/serverless/modules/rp_job.py b/runpod/serverless/modules/rp_job.py index 22f377c2..3f891628 100644 --- a/runpod/serverless/modules/rp_job.py +++ b/runpod/serverless/modules/rp_job.py @@ -2,16 +2,16 @@ Job related helpers. """ +import aiohttp import inspect import json import os import traceback from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List -import aiohttp - from runpod.http_client import ClientSession, TooManyRequests from runpod.serverless.modules.rp_logger import RunPodLogger +from runpod.serverless.modules.rp_handler import is_async_generator from ...version import __version__ as runpod_version from ..utils import rp_debugger @@ -230,7 +230,7 @@ async def run_job_generator( Run generator job used to stream output. Yields output partials from the generator. """ - is_async_gen = inspect.isasyncgenfunction(handler) + is_async_gen = is_async_generator(handler) log.debug( "Using Async Generator" if is_async_gen else "Using Standard Generator", job["id"], diff --git a/tests/test_serverless/test_modules/test_handler.py b/tests/test_serverless/test_modules/test_handler.py index 2c68779e..c065d7b4 100644 --- a/tests/test_serverless/test_modules/test_handler.py +++ b/tests/test_serverless/test_modules/test_handler.py @@ -9,6 +9,46 @@ class TestIsGenerator(unittest.TestCase): """Tests for the is_generator function.""" + def test_callable_non_generator_object(self): + """Test that a callable object is not a generator.""" + + class CallableObject: + def __call__(self): + return "I'm a callable object!" + + callable_obj = CallableObject() + self.assertFalse(is_generator(callable_obj)) + + def test_callable_object_generator_object(self): + """Test that a callable object with a generator method is a generator.""" + + class CallableGeneratorObject: + def __call__(self): + yield "I'm a callable object with a generator method!" + + callable_obj = CallableGeneratorObject() + self.assertTrue(is_generator(callable_obj)) + + def test_async_callable_non_generator_object(self): + """Test that an async callable object is not a generator.""" + + class AsyncCallableNonGeneratorObject: + async def __call__(self): + return "I'm an async callable object!" + + async_callable_obj = AsyncCallableNonGeneratorObject() + self.assertFalse(is_generator(async_callable_obj)) + + def test_async_callable_generator_object(self): + """Test that an async callable object with a generator method is a generator.""" + + class AsyncCallableGeneratorObject: + async def __call__(self): + yield "I'm an async callable object with a generator method!" + + async_callable_obj = AsyncCallableGeneratorObject() + self.assertTrue(is_generator(async_callable_obj)) + def test_regular_function(self): """Test that a regular function is not a generator."""