Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion runpod/serverless/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Callable, Dict, List, Optional

from runpod.serverless.modules import rp_job
from runpod.serverless.modules.rp_handler import is_generator
from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.version import __version__ as runpod_version

Expand Down Expand Up @@ -232,7 +233,7 @@ async def _process_job(

result = {}
try:
if inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler):
if is_generator(handler):
log.debug("SLS Core | Running job as a generator.")
generator_output = rp_job.run_job_generator(handler, job)
aggregated_output: dict[str, typing.Any] = {"output": []}
Expand Down
11 changes: 11 additions & 0 deletions runpod/serverless/modules/rp_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,15 @@

def is_generator(handler: Callable) -> bool:
"""Check if handler is a generator function."""
# handler could be an object that has a __call__ method
if not inspect.isfunction(handler):
handler = getattr(handler, "__call__", lambda: None)
return inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler)


def is_async_generator(handler: Callable) -> bool:
"""Check if handler is an async generator function."""
# handler could be an object that has a __call__ method
if not inspect.isfunction(handler):
handler = getattr(handler, "__call__", lambda: None)
return inspect.isasyncgenfunction(handler)
6 changes: 3 additions & 3 deletions runpod/serverless/modules/rp_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
Job related helpers.
"""

import aiohttp
import inspect
import json
import os
import traceback
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List

import aiohttp

from runpod.http_client import ClientSession, TooManyRequests
from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.rp_handler import is_async_generator

from ...version import __version__ as runpod_version
from ..utils import rp_debugger
Expand Down Expand Up @@ -230,7 +230,7 @@ async def run_job_generator(
Run generator job used to stream output.
Yields output partials from the generator.
"""
is_async_gen = inspect.isasyncgenfunction(handler)
is_async_gen = is_async_generator(handler)
log.debug(
"Using Async Generator" if is_async_gen else "Using Standard Generator",
job["id"],
Expand Down
40 changes: 40 additions & 0 deletions tests/test_serverless/test_modules/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,46 @@
class TestIsGenerator(unittest.TestCase):
"""Tests for the is_generator function."""

def test_callable_non_generator_object(self):
"""Test that a callable object is not a generator."""

class CallableObject:
def __call__(self):
return "I'm a callable object!"

callable_obj = CallableObject()
self.assertFalse(is_generator(callable_obj))

def test_callable_object_generator_object(self):
"""Test that a callable object with a generator method is a generator."""

class CallableGeneratorObject:
def __call__(self):
yield "I'm a callable object with a generator method!"

callable_obj = CallableGeneratorObject()
self.assertTrue(is_generator(callable_obj))

def test_async_callable_non_generator_object(self):
"""Test that an async callable object is not a generator."""

class AsyncCallableNonGeneratorObject:
async def __call__(self):
return "I'm an async callable object!"

async_callable_obj = AsyncCallableNonGeneratorObject()
self.assertFalse(is_generator(async_callable_obj))

def test_async_callable_generator_object(self):
"""Test that an async callable object with a generator method is a generator."""

class AsyncCallableGeneratorObject:
async def __call__(self):
yield "I'm an async callable object with a generator method!"

async_callable_obj = AsyncCallableGeneratorObject()
self.assertTrue(is_generator(async_callable_obj))

def test_regular_function(self):
"""Test that a regular function is not a generator."""

Expand Down
Loading