1
1
import asyncio
2
- from typing import Union
3
2
4
3
from asynctest .mock import MagicMock , CoroutineMock
5
4
import pytest
15
14
RetryingConnection ,
16
15
_MIN_BACKOFF_SECS ,
17
16
)
17
+ from google .cloud .pubsublite .testing .test_utils import wire_queues
18
18
19
19
# All test coroutines will be treated as marked.
20
20
pytestmark = pytest .mark .asyncio
@@ -41,7 +41,7 @@ def connection_factory(default_connection):
41
41
42
42
@pytest .fixture ()
43
43
def retrying_connection (connection_factory , reinitializer ):
44
- return RetryingConnection [ int , int ] (connection_factory , reinitializer )
44
+ return RetryingConnection (connection_factory , reinitializer )
45
45
46
46
47
47
@pytest .fixture
@@ -55,40 +55,27 @@ def asyncio_sleep(monkeypatch):
55
55
async def test_permanent_error_on_reinitializer (
56
56
retrying_connection : Connection [int , int ], reinitializer , default_connection
57
57
):
58
- fut = asyncio .Future ()
59
- reinitialize_called = asyncio .Future ()
60
-
61
58
async def reinit_action (conn ):
62
59
assert conn == default_connection
63
- reinitialize_called .set_result (None )
64
- return await fut
60
+ raise InvalidArgument ("abc" )
65
61
66
62
reinitializer .reinitialize .side_effect = reinit_action
67
- async with retrying_connection as _ :
68
- await reinitialize_called
69
- reinitializer .reinitialize .assert_called_once ()
70
- fut .set_exception (InvalidArgument ("abc" ))
71
- with pytest .raises (InvalidArgument ):
72
- await retrying_connection .read ()
63
+ with pytest .raises (InvalidArgument ):
64
+ async with retrying_connection as _ :
65
+ pass
73
66
74
67
75
68
async def test_successful_reinitialize (
76
69
retrying_connection : Connection [int , int ], reinitializer , default_connection
77
70
):
78
- fut = asyncio .Future ()
79
- reinitialize_called = asyncio .Future ()
80
-
81
71
async def reinit_action (conn ):
82
72
assert conn == default_connection
83
- reinitialize_called .set_result (None )
84
- return await fut
73
+ return None
74
+
75
+ default_connection .read .return_value = 1
85
76
86
77
reinitializer .reinitialize .side_effect = reinit_action
87
78
async with retrying_connection as _ :
88
- await reinitialize_called
89
- reinitializer .reinitialize .assert_called_once ()
90
- fut .set_result (None )
91
- default_connection .read .return_value = 1
92
79
assert await retrying_connection .read () == 1
93
80
assert (
94
81
default_connection .read .call_count == 2
@@ -111,26 +98,15 @@ async def test_reinitialize_after_retryable(
111
98
default_connection ,
112
99
asyncio_sleep ,
113
100
):
114
- reinit_called = asyncio .Queue ()
115
- reinit_results : "asyncio.Queue[Union[None, Exception]]" = asyncio .Queue ()
101
+ reinit_queues = wire_queues (reinitializer .reinitialize )
116
102
117
- async def reinit_action (conn ):
118
- assert conn == default_connection
119
- await reinit_called .put (None )
120
- result = await reinit_results .get ()
121
- if isinstance (result , Exception ):
122
- raise result
103
+ default_connection .read .return_value = 1
123
104
124
- reinitializer .reinitialize .side_effect = reinit_action
105
+ await reinit_queues .results .put (InternalServerError ("abc" ))
106
+ await reinit_queues .results .put (None )
125
107
async with retrying_connection as _ :
126
- await reinit_called .get ()
127
- reinitializer .reinitialize .assert_called_once ()
128
- await reinit_results .put (InternalServerError ("abc" ))
129
- await reinit_called .get ()
130
108
asyncio_sleep .assert_called_once_with (_MIN_BACKOFF_SECS )
131
109
assert reinitializer .reinitialize .call_count == 2
132
- await reinit_results .put (None )
133
- default_connection .read .return_value = 1
134
110
assert await retrying_connection .read () == 1
135
111
assert (
136
112
default_connection .read .call_count == 2
0 commit comments