Skip to content
This repository was archived by the owner on Nov 29, 2023. It is now read-only.

Commit bf9cc26

Browse files
feat: add context manager support in client (#125)
- [ ] Regenerate this pull request now. chore: fix docstring for first attribute of protos committer: @busunkim96 PiperOrigin-RevId: 401271153 Source-Link: googleapis/googleapis@787f8c9 Source-Link: https://github.com/googleapis/googleapis-gen/commit/81decffe9fc72396a8153e756d1d67a6eecfd620 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODFkZWNmZmU5ZmM3MjM5NmE4MTUzZTc1NmQxZDY3YTZlZWNmZDYyMCJ9
1 parent 397a7f2 commit bf9cc26

File tree

7 files changed

+90
-4
lines changed

7 files changed

+90
-4
lines changed

google/cloud/bigquery_connection_v1/services/connection_service/async_client.py

+6
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,12 @@ async def test_iam_permissions(
943943
# Done; return the response.
944944
return response
945945

946+
async def __aenter__(self):
947+
return self
948+
949+
async def __aexit__(self, exc_type, exc, tb):
950+
await self.transport.close()
951+
946952

947953
try:
948954
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/bigquery_connection_v1/services/connection_service/client.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,7 @@ def __init__(
350350
client_cert_source_for_mtls=client_cert_source_func,
351351
quota_project_id=client_options.quota_project_id,
352352
client_info=client_info,
353-
always_use_jwt_access=(
354-
Transport == type(self).get_transport_class("grpc")
355-
or Transport == type(self).get_transport_class("grpc_asyncio")
356-
),
353+
always_use_jwt_access=True,
357354
)
358355

359356
def create_connection(
@@ -1101,6 +1098,19 @@ def test_iam_permissions(
11011098
# Done; return the response.
11021099
return response
11031100

1101+
def __enter__(self):
1102+
return self
1103+
1104+
def __exit__(self, type, value, traceback):
1105+
"""Releases underlying transport's resources.
1106+
1107+
.. warning::
1108+
ONLY use as a context manager if the transport is NOT shared
1109+
with other clients! Exiting the with block will CLOSE the transport
1110+
and may cause errors in other clients!
1111+
"""
1112+
self.transport.close()
1113+
11041114

11051115
try:
11061116
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/bigquery_connection_v1/services/connection_service/transports/base.py

+9
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,15 @@ def _prep_wrapped_messages(self, client_info):
225225
),
226226
}
227227

228+
def close(self):
229+
"""Closes resources associated with the transport.
230+
231+
.. warning::
232+
Only call this method if the transport is NOT shared
233+
with other clients - this may cause errors in other clients!
234+
"""
235+
raise NotImplementedError()
236+
228237
@property
229238
def create_connection(
230239
self,

google/cloud/bigquery_connection_v1/services/connection_service/transports/grpc.py

+3
Original file line numberDiff line numberDiff line change
@@ -461,5 +461,8 @@ def test_iam_permissions(
461461
)
462462
return self._stubs["test_iam_permissions"]
463463

464+
def close(self):
465+
self.grpc_channel.close()
466+
464467

465468
__all__ = ("ConnectionServiceGrpcTransport",)

google/cloud/bigquery_connection_v1/services/connection_service/transports/grpc_asyncio.py

+3
Original file line numberDiff line numberDiff line change
@@ -465,5 +465,8 @@ def test_iam_permissions(
465465
)
466466
return self._stubs["test_iam_permissions"]
467467

468+
def close(self):
469+
return self.grpc_channel.close()
470+
468471

469472
__all__ = ("ConnectionServiceGrpcAsyncIOTransport",)

google/cloud/bigquery_connection_v1/types/connection.py

+5
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class UpdateConnectionRequest(proto.Message):
134134

135135
class DeleteConnectionRequest(proto.Message):
136136
r"""The request for [ConnectionService.DeleteConnectionRequest][].
137+
137138
Attributes:
138139
name (str):
139140
Required. Name of the deleted connection, for example:
@@ -192,6 +193,7 @@ class Connection(proto.Message):
192193

193194
class CloudSqlProperties(proto.Message):
194195
r"""Connection properties specific to the Cloud SQL.
196+
195197
Attributes:
196198
instance_id (str):
197199
Cloud SQL instance ID in the form
@@ -218,6 +220,7 @@ class DatabaseType(proto.Enum):
218220

219221
class CloudSqlCredential(proto.Message):
220222
r"""Credential info for the Cloud SQL.
223+
221224
Attributes:
222225
username (str):
223226
The username for the credential.
@@ -231,6 +234,7 @@ class CloudSqlCredential(proto.Message):
231234

232235
class CloudSpannerProperties(proto.Message):
233236
r"""Connection properties specific to Cloud Spanner.
237+
234238
Attributes:
235239
database (str):
236240
Cloud Spanner database in the form
@@ -246,6 +250,7 @@ class CloudSpannerProperties(proto.Message):
246250

247251
class AwsProperties(proto.Message):
248252
r"""Connection properties specific to Amazon Web Services (AWS).
253+
249254
Attributes:
250255
cross_account_role (google.cloud.bigquery_connection_v1.types.AwsCrossAccountRole):
251256
Authentication using Google owned AWS IAM

tests/unit/gapic/bigquery_connection_v1/test_connection_service.py

+50
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from google.api_core import gapic_v1
3030
from google.api_core import grpc_helpers
3131
from google.api_core import grpc_helpers_async
32+
from google.api_core import path_template
3233
from google.auth import credentials as ga_credentials
3334
from google.auth.exceptions import MutualTLSChannelError
3435
from google.cloud.bigquery_connection_v1.services.connection_service import (
@@ -2608,6 +2609,9 @@ def test_connection_service_base_transport():
26082609
with pytest.raises(NotImplementedError):
26092610
getattr(transport, method)(request=object())
26102611

2612+
with pytest.raises(NotImplementedError):
2613+
transport.close()
2614+
26112615

26122616
@requires_google_auth_gte_1_25_0
26132617
def test_connection_service_base_transport_with_credentials_file():
@@ -3111,3 +3115,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
31113115
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
31123116
)
31133117
prep.assert_called_once_with(client_info)
3118+
3119+
3120+
@pytest.mark.asyncio
3121+
async def test_transport_close_async():
3122+
client = ConnectionServiceAsyncClient(
3123+
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
3124+
)
3125+
with mock.patch.object(
3126+
type(getattr(client.transport, "grpc_channel")), "close"
3127+
) as close:
3128+
async with client:
3129+
close.assert_not_called()
3130+
close.assert_called_once()
3131+
3132+
3133+
def test_transport_close():
3134+
transports = {
3135+
"grpc": "_grpc_channel",
3136+
}
3137+
3138+
for transport, close_name in transports.items():
3139+
client = ConnectionServiceClient(
3140+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
3141+
)
3142+
with mock.patch.object(
3143+
type(getattr(client.transport, close_name)), "close"
3144+
) as close:
3145+
with client:
3146+
close.assert_not_called()
3147+
close.assert_called_once()
3148+
3149+
3150+
def test_client_ctx():
3151+
transports = [
3152+
"grpc",
3153+
]
3154+
for transport in transports:
3155+
client = ConnectionServiceClient(
3156+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
3157+
)
3158+
# Test client calls underlying transport.
3159+
with mock.patch.object(type(client.transport), "close") as close:
3160+
close.assert_not_called()
3161+
with client:
3162+
pass
3163+
close.assert_called()

0 commit comments

Comments
 (0)