11import asyncio
2+ import datetime
23import random
34import sys
5+ import time
46import typing
57
68import pytest
1012from zero .protocols .tcp import AsyncTCPClient
1113
1214
15+ def get_async_client ():
16+ from . import tcp_server
17+
18+ return AsyncZeroClient (
19+ tcp_server .HOST ,
20+ tcp_server .PORT ,
21+ protocol = AsyncTCPClient ,
22+ default_timeout = 5000 , # github runners can be slow
23+ pool_size = 5 ,
24+ )
25+
26+
1327@pytest .mark .skipif (
1428 sys .platform == "win32" , reason = "TCP tests not supported on Windows"
1529)
1630@pytest .mark .asyncio
1731async def test_concurrent_divide ():
18- from . import tcp_server
19-
20- async_client = AsyncZeroClient (
21- tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient
22- )
32+ client = get_async_client ()
2333
2434 req_resp = {
2535 (10 , 2 ): 5 ,
@@ -43,9 +53,7 @@ async def test_concurrent_divide():
4353 async def divide (semaphore , req ):
4454 async with semaphore :
4555 try :
46- assert (
47- await async_client .call ("divide" , req , timeout = 500 ) == req_resp [req ]
48- )
56+ assert await client .call ("divide" , req , timeout = 500 ) == req_resp [req ]
4957 nonlocal total_pass
5058 total_pass += 1
5159 except zero .error .TimeoutException :
@@ -64,9 +72,8 @@ async def divide(semaphore, req):
6472)
6573@pytest .mark .asyncio
6674async def test_server_error ():
67- from . import tcp_server
75+ client = get_async_client ()
6876
69- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
7077 try :
7178 await client .call ("error" , "some error" )
7279 raise AssertionError ("Should have thrown an Exception" )
@@ -79,9 +86,7 @@ async def test_server_error():
7986)
8087@pytest .mark .asyncio
8188async def test_timeout_all_async ():
82- from . import tcp_server
83-
84- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
89+ client = get_async_client ()
8590
8691 with pytest .raises (zero .error .TimeoutException ):
8792 await client .call ("sleep" , 1000 , timeout = 10 )
@@ -95,9 +100,7 @@ async def test_timeout_all_async():
95100)
96101@pytest .mark .asyncio
97102async def test_random_timeout_async ():
98- from . import tcp_server
99-
100- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
103+ client = get_async_client ()
101104
102105 fails = 0
103106 should_fail = 0
@@ -122,32 +125,24 @@ async def test_random_timeout_async():
122125)
123126@pytest .mark .asyncio
124127async def test_return_type_parameter ():
125- """Test that return_type parameter is used for proper decoding."""
126- from . import tcp_server
127-
128- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
128+ client = get_async_client ()
129129
130- # Test with int return type
131130 result = await client .call ("echo_int" , 42 , return_type = int )
132131 assert result == 42
133132 assert isinstance (result , int )
134133
135- # Test with str return type
136134 result = await client .call ("echo_str" , "hello" , return_type = str )
137135 assert result == "hello"
138136 assert isinstance (result , str )
139137
140- # Test with float return type
141138 result = await client .call ("echo_float" , 3.14 , return_type = float )
142139 assert result == 3.14
143140 assert isinstance (result , float )
144141
145- # Test with bool return type
146142 result = await client .call ("echo_bool" , True , return_type = bool )
147143 assert result is True
148144 assert isinstance (result , bool )
149145
150- # Test with list return type
151146 result = await client .call ("echo_list" , [1 , 2 , 3 ], return_type = list [int ])
152147 assert result == [1 , 2 , 3 ]
153148 assert isinstance (result , list )
@@ -158,19 +153,14 @@ async def test_return_type_parameter():
158153)
159154@pytest .mark .asyncio
160155async def test_complex_return_types_union ():
161- """Test Union return types with proper decoding."""
162- from . import tcp_server
163-
164- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
156+ client = get_async_client ()
165157
166- # Test Union[int, str] with int value
167158 result = await client .call (
168159 "echo_typing_union" , 42 , return_type = typing .Union [int , str ]
169160 )
170161 assert result == 42
171162 assert isinstance (result , int )
172163
173- # Test Union[int, str] with str value
174164 result = await client .call (
175165 "echo_typing_union" , "hello" , return_type = typing .Union [int , str ]
176166 )
@@ -183,10 +173,7 @@ async def test_complex_return_types_union():
183173)
184174@pytest .mark .asyncio
185175async def test_complex_return_types_tuple ():
186- """Test Tuple return types with proper decoding."""
187- from . import tcp_server
188-
189- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
176+ client = get_async_client ()
190177
191178 # Test Tuple[int, str]
192179 result = await client .call (
@@ -209,10 +196,7 @@ async def test_complex_return_types_tuple():
209196)
210197@pytest .mark .asyncio
211198async def test_complex_return_types_nested_dict ():
212- """Test nested Dict return types with proper decoding."""
213- from . import tcp_server
214-
215- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
199+ client = get_async_client ()
216200
217201 # Test Dict[int, str] - basic dict
218202 result = await client .call (
@@ -229,10 +213,9 @@ async def test_complex_return_types_nested_dict():
229213)
230214@pytest .mark .asyncio
231215async def test_complex_return_types_pydantic ():
232- """Test Pydantic model return types with proper decoding."""
233216 from . import tcp_server
234217
235- client = AsyncZeroClient ( tcp_server . HOST , tcp_server . PORT , protocol = AsyncTCPClient )
218+ client = get_async_client ( )
236219
237220 # Create a pydantic model instance
238221 model = tcp_server .PydanticModel (name = "Alice" , age = 30 )
@@ -253,12 +236,9 @@ async def test_complex_return_types_pydantic():
253236)
254237@pytest .mark .asyncio
255238async def test_complex_return_types_msgspec_struct ():
256- """Test msgspec Struct return types with proper decoding."""
257- import datetime
258-
259239 from . import tcp_server
260240
261- client = AsyncZeroClient ( tcp_server . HOST , tcp_server . PORT , protocol = AsyncTCPClient )
241+ client = get_async_client ( )
262242
263243 # Create a message struct
264244 now = datetime .datetime .now ()
@@ -280,10 +260,7 @@ async def test_complex_return_types_msgspec_struct():
280260)
281261@pytest .mark .asyncio
282262async def test_complex_return_types_optional ():
283- """Test Optional return types with proper decoding."""
284- from . import tcp_server
285-
286- client = AsyncZeroClient (tcp_server .HOST , tcp_server .PORT , protocol = AsyncTCPClient )
263+ client = get_async_client ()
287264
288265 # Test Optional[int] with value
289266 result = await client .call ("echo_typing_optional" , 42 , return_type = int )
@@ -301,10 +278,9 @@ async def test_complex_return_types_optional():
301278)
302279@pytest .mark .asyncio
303280async def test_complex_return_types_dataclass ():
304- """Test dataclass return types with proper decoding."""
305281 from . import tcp_server
306282
307- client = AsyncZeroClient ( tcp_server . HOST , tcp_server . PORT , protocol = AsyncTCPClient )
283+ client = get_async_client ( )
308284
309285 # Test dataclass return type
310286 result = await client .call (
@@ -320,10 +296,9 @@ async def test_complex_return_types_dataclass():
320296)
321297@pytest .mark .asyncio
322298async def test_complex_return_types_enum ():
323- """Test enum return types with proper decoding."""
324299 from . import tcp_server
325300
326- client = AsyncZeroClient ( tcp_server . HOST , tcp_server . PORT , protocol = AsyncTCPClient )
301+ client = get_async_client ( )
327302
328303 # Test enum return type
329304 result = await client .call (
@@ -342,26 +317,21 @@ async def test_complex_return_types_enum():
342317 assert result .value == 2
343318
344319
345- # For some reason this is failing in MacOS
346- # @pytest.mark.skipif(
347- # sys.platform == "win32", reason="TCP tests not supported on Windows"
348- # )
349- # @pytest.mark.asyncio
350- # async def test_async_sleep():
351- # from . import tcp_server
352-
353- # client = AsyncZeroClient(
354- # tcp_server.HOST, tcp_server.PORT, protocol=AsyncTCPClient, pool_size=5
355- # )
320+ @pytest .mark .skipif (
321+ sys .platform == "win32" , reason = "TCP tests not supported on Windows"
322+ )
323+ @pytest .mark .asyncio
324+ async def test_async_sleep ():
325+ client = get_async_client ()
356326
357- # async def task(sleep_time):
358- # res = await client.call("sleep_async", sleep_time)
359- # assert res == f"slept for {sleep_time} msecs"
327+ async def task (sleep_time ):
328+ res = await client .call ("sleep_async" , sleep_time )
329+ assert res == f"slept for { sleep_time } msecs"
360330
361- # tasks = [task(200) for _ in range(5)]
331+ tasks = [task (200 ) for _ in range (5 )]
362332
363- # start = time.perf_counter()
364- # await asyncio.gather(*tasks)
365- # time_taken_ms = (time.perf_counter() - start) * 1000
333+ start = time .perf_counter ()
334+ await asyncio .gather (* tasks )
335+ time_taken_ms = (time .perf_counter () - start ) * 1000
366336
367- # assert time_taken_ms < 1000
337+ assert time_taken_ms < 1000
0 commit comments