Skip to content

Commit dda5538

Browse files
committed
fix(streaming): raise on thread error events
1 parent 43e324e commit dda5538

File tree

2 files changed

+72
-52
lines changed

2 files changed

+72
-52
lines changed

src/openai/_streaming.py

Lines changed: 20 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@
1919
_T = TypeVar("_T")
2020

2121

22+
def _raise_if_stream_error(data: object, response: httpx.Response) -> None:
23+
if is_mapping(data) and data.get("error"):
24+
message = None
25+
error = data.get("error")
26+
if is_mapping(error):
27+
message = error.get("message")
28+
if not message or not isinstance(message, str):
29+
message = "An error occurred during streaming"
30+
31+
raise APIError(
32+
message=message,
33+
request=response.request,
34+
body=data["error"],
35+
)
36+
37+
2238
class Stream(Generic[_T]):
2339
"""Provides the core interface to iterate over a synchronous stream response."""
2440

@@ -64,36 +80,12 @@ def __stream__(self) -> Iterator[_T]:
6480
if sse.event and sse.event.startswith("thread."):
6581
data = sse.json()
6682

67-
if sse.event == "error" and is_mapping(data) and data.get("error"):
68-
message = None
69-
error = data.get("error")
70-
if is_mapping(error):
71-
message = error.get("message")
72-
if not message or not isinstance(message, str):
73-
message = "An error occurred during streaming"
74-
75-
raise APIError(
76-
message=message,
77-
request=self.response.request,
78-
body=data["error"],
79-
)
83+
_raise_if_stream_error(data, response)
8084

8185
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
8286
else:
8387
data = sse.json()
84-
if is_mapping(data) and data.get("error"):
85-
message = None
86-
error = data.get("error")
87-
if is_mapping(error):
88-
message = error.get("message")
89-
if not message or not isinstance(message, str):
90-
message = "An error occurred during streaming"
91-
92-
raise APIError(
93-
message=message,
94-
request=self.response.request,
95-
body=data["error"],
96-
)
88+
_raise_if_stream_error(data, response)
9789

9890
yield process_data(data=data, cast_to=cast_to, response=response)
9991

@@ -167,36 +159,12 @@ async def __stream__(self) -> AsyncIterator[_T]:
167159
if sse.event and sse.event.startswith("thread."):
168160
data = sse.json()
169161

170-
if sse.event == "error" and is_mapping(data) and data.get("error"):
171-
message = None
172-
error = data.get("error")
173-
if is_mapping(error):
174-
message = error.get("message")
175-
if not message or not isinstance(message, str):
176-
message = "An error occurred during streaming"
177-
178-
raise APIError(
179-
message=message,
180-
request=self.response.request,
181-
body=data["error"],
182-
)
162+
_raise_if_stream_error(data, response)
183163

184164
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
185165
else:
186166
data = sse.json()
187-
if is_mapping(data) and data.get("error"):
188-
message = None
189-
error = data.get("error")
190-
if is_mapping(error):
191-
message = error.get("message")
192-
if not message or not isinstance(message, str):
193-
message = "An error occurred during streaming"
194-
195-
raise APIError(
196-
message=message,
197-
request=self.response.request,
198-
body=data["error"],
199-
)
167+
_raise_if_stream_error(data, response)
200168

201169
yield process_data(data=data, cast_to=cast_to, response=response)
202170

tests/test_streaming_errors.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from __future__ import annotations
2+
3+
from typing import Iterator, AsyncIterator
4+
5+
import httpx
6+
import pytest
7+
8+
from openai import OpenAI, AsyncOpenAI
9+
from openai._exceptions import APIError
10+
from openai._streaming import Stream, AsyncStream
11+
12+
13+
@pytest.mark.asyncio
14+
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
15+
async def test_thread_event_error_raises(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
16+
def body() -> Iterator[bytes]:
17+
yield b"event: thread.error\n"
18+
yield b'data: {"error": {"message": "boom"}}\n'
19+
yield b"\n"
20+
21+
iterator = make_stream_iterator(content=body(), sync=sync, client=client, async_client=async_client)
22+
23+
with pytest.raises(APIError, match="boom"):
24+
await iter_next(iterator)
25+
26+
27+
async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
28+
for chunk in iter:
29+
yield chunk
30+
31+
32+
async def iter_next(iter: Iterator[object] | AsyncIterator[object]) -> object:
33+
if isinstance(iter, AsyncIterator):
34+
return await iter.__anext__()
35+
36+
return next(iter)
37+
38+
39+
def make_stream_iterator(
40+
content: Iterator[bytes],
41+
*,
42+
sync: bool,
43+
client: OpenAI,
44+
async_client: AsyncOpenAI,
45+
) -> Iterator[object] | AsyncIterator[object]:
46+
request = httpx.Request("GET", "http://test")
47+
if sync:
48+
response = httpx.Response(200, request=request, content=content)
49+
return iter(Stream(cast_to=object, client=client, response=response))
50+
51+
response = httpx.Response(200, request=request, content=to_aiter(content))
52+
return AsyncStream(cast_to=object, client=async_client, response=response).__aiter__()

0 commit comments

Comments
 (0)