Skip to content

Commit 39a11d0

Browse files
committed
chore: remove async + add transaction handling
1 parent 95b6cd6 commit 39a11d0

File tree

4 files changed

+80
-68
lines changed

4 files changed

+80
-68
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import google.auth.credentials
2626
from google.api_core.retry import Retry
2727
from google.api_core.retry import if_exception_type
28-
from google.auth.aio.credentials import AnonymousCredentials
2928
from google.cloud.exceptions import NotFound
3029
from google.api_core.exceptions import Aborted
3130
from google.api_core import gapic_v1
@@ -42,7 +41,7 @@
4241
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
4342
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
4443
from google.cloud.spanner_v1.transaction import BatchTransactionId
45-
from google.cloud.spanner_v1 import ExecuteSqlRequest, SpannerAsyncClient
44+
from google.cloud.spanner_v1 import ExecuteSqlRequest
4645
from google.cloud.spanner_v1 import Type
4746
from google.cloud.spanner_v1 import TypeCode
4847
from google.cloud.spanner_v1 import TransactionSelector
@@ -144,7 +143,6 @@ class Database(object):
144143
"""
145144

146145
_spanner_api: SpannerClient = None
147-
_spanner_async_api: SpannerAsyncClient = None
148146

149147
def __init__(
150148
self,
@@ -440,28 +438,6 @@ def spanner_api(self):
440438
)
441439
return self._spanner_api
442440

443-
@property
444-
def spanner_async_api(self):
445-
if self._spanner_async_api is None:
446-
client_info = self._instance._client._client_info
447-
client_options = self._instance._client._client_options
448-
if self._instance.emulator_host is not None:
449-
channel=grpc.aio.insecure_channel(target=self._instance.emulator_host)
450-
transport = SpannerGrpcTransport(channel=channel)
451-
self._spanner_async_api = SpannerAsyncClient(
452-
client_info=client_info, transport=transport
453-
)
454-
return self._spanner_async_api
455-
credentials = self._instance._client.credentials
456-
if isinstance(credentials, google.auth.credentials.Scoped):
457-
credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,))
458-
self._spanner_async_api = SpannerAsyncClient(
459-
credentials=credentials,
460-
client_info=client_info,
461-
client_options=client_options,
462-
)
463-
return self._spanner_async_api
464-
465441
def __eq__(self, other):
466442
if not isinstance(other, self.__class__):
467443
return NotImplemented

google/cloud/spanner_v1/services/spanner/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
if client_cert_source:
128128
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
129129

130-
if isinstance(channel, grpc.Channel) or isinstance(channel, grpc.aio.Channel):
130+
if isinstance(channel, grpc.Channel):
131131
# Ignore credentials if a channel was passed.
132132
credentials = None
133133
self._ignore_credentials = True

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import base64
1515
from concurrent import futures
1616

1717
from google.protobuf import empty_pb2
@@ -31,22 +31,32 @@ def __init__(self):
3131
self.results = {}
3232

3333
def add_result(self, sql: str, result: result_set.ResultSet):
34-
self.results[sql.lower()] = result
34+
self.results[sql.lower().strip()] = result
35+
36+
def get_result(self, sql: str) -> result_set.ResultSet:
37+
result = self.results.get(sql.lower().strip())
38+
if result is None:
39+
raise ValueError(f"No result found for {sql}")
40+
return result
3541

3642
def get_result_as_partial_result_sets(
3743
self, sql: str
3844
) -> [result_set.PartialResultSet]:
39-
result: result_set.ResultSet = self.results.get(sql.lower())
40-
if result is None:
41-
return []
45+
result: result_set.ResultSet = self.get_result(sql)
4246
partials = []
4347
first = True
44-
for row in result.rows:
48+
if len(result.rows) == 0:
4549
partial = result_set.PartialResultSet()
46-
if first:
47-
partial.metadata = result.metadata
48-
partial.values.extend(row)
50+
partial.metadata = result.metadata
4951
partials.append(partial)
52+
else:
53+
for row in result.rows:
54+
partial = result_set.PartialResultSet()
55+
if first:
56+
partial.metadata = result.metadata
57+
partial.values.extend(row)
58+
partials.append(partial)
59+
partials[len(partials) - 1].stats = result.stats
5060
return partials
5161

5262

@@ -56,6 +66,8 @@ def __init__(self):
5666
self._requests = []
5767
self.session_counter = 0
5868
self.sessions = {}
69+
self.transaction_counter = 0
70+
self.transactions = {}
5971
self._mock_spanner = MockSpanner()
6072

6173
@property
@@ -93,18 +105,20 @@ def __create_session(self, database: str, session_template: spanner.Session):
93105
return session
94106

95107
def GetSession(self, request, context):
108+
self._requests.append(request)
96109
return spanner.Session()
97110

98111
def ListSessions(self, request, context):
112+
self._requests.append(request)
99113
return [spanner.Session()]
100114

101115
def DeleteSession(self, request, context):
116+
self._requests.append(request)
102117
return empty_pb2.Empty()
103118

104119
def ExecuteSql(self, request, context):
105120
self._requests.append(request)
106-
result: result_set.ResultSet = self.mock_spanner.results.get(request.sql.lower())
107-
return result
121+
return result_set.ResultSet()
108122

109123
def ExecuteStreamingSql(self, request, context):
110124
self._requests.append(request)
@@ -113,31 +127,74 @@ def ExecuteStreamingSql(self, request, context):
113127
yield result
114128

115129
def ExecuteBatchDml(self, request, context):
116-
return spanner.ExecuteBatchDmlResponse()
130+
self._requests.append(request)
131+
response = spanner.ExecuteBatchDmlResponse()
132+
started_transaction = None
133+
if not request.transaction.begin == transaction.TransactionOptions():
134+
started_transaction = self.__create_transaction(
135+
request.session, request.transaction.begin
136+
)
137+
first = True
138+
for statement in request.statements:
139+
result = self.mock_spanner.get_result(statement.sql)
140+
if first and started_transaction is not None:
141+
result = result_set.ResultSet(
142+
self.mock_spanner.get_result(statement.sql)
143+
)
144+
result.metadata = result_set.ResultSetMetadata(result.metadata)
145+
result.metadata.transaction = started_transaction
146+
response.result_sets.append(result)
147+
return response
117148

118149
def Read(self, request, context):
150+
self._requests.append(request)
119151
return result_set.ResultSet()
120152

121153
def StreamingRead(self, request, context):
154+
self._requests.append(request)
122155
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]:
123156
yield result
124157

125158
def BeginTransaction(self, request, context):
126-
return transaction.Transaction()
159+
self._requests.append(request)
160+
return self.__create_transaction(request.session, request.options)
161+
162+
def __create_transaction(
163+
self, session: str, options: transaction.TransactionOptions
164+
) -> transaction.Transaction:
165+
session = self.sessions[session]
166+
if session is None:
167+
raise ValueError(f"Session not found: {session}")
168+
self.transaction_counter += 1
169+
id_bytes = bytes(
170+
f"{session.name}/transactions/{self.transaction_counter}", "UTF-8"
171+
)
172+
transaction_id = base64.urlsafe_b64encode(id_bytes)
173+
self.transactions[transaction_id] = options
174+
return transaction.Transaction(dict(id=transaction_id))
127175

128176
def Commit(self, request, context):
177+
self._requests.append(request)
178+
tx = self.transactions[request.transaction_id]
179+
if tx is None:
180+
raise ValueError(f"Transaction not found: {request.transaction_id}")
181+
del self.transactions[request.transaction_id]
129182
return commit.CommitResponse()
130183

131184
def Rollback(self, request, context):
185+
self._requests.append(request)
132186
return empty_pb2.Empty()
133187

134188
def PartitionQuery(self, request, context):
189+
self._requests.append(request)
135190
return spanner.PartitionResponse()
136191

137192
def PartitionRead(self, request, context):
193+
self._requests.append(request)
138194
return spanner.PartitionResponse()
139195

140196
def BatchWrite(self, request, context):
197+
self._requests.append(request)
141198
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]:
142199
yield result
143200

tests/mockserver_tests/test_basics.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import asyncio
14+
1515
import unittest
1616

1717
from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
@@ -28,7 +28,8 @@
2828
Client,
2929
FixedSizePool,
3030
BatchCreateSessionsRequest,
31-
ExecuteSqlRequest, CreateSessionRequest,
31+
ExecuteSqlRequest,
32+
GetSessionRequest,
3233
)
3334
from google.cloud.spanner_v1.database import Database
3435
from google.cloud.spanner_v1.instance import Instance
@@ -124,9 +125,12 @@ def test_select1(self):
124125
self.assertEqual(1, row[0])
125126
self.assertEqual(1, len(result_list))
126127
requests = self.spanner_service.requests
127-
self.assertEqual(2, len(requests))
128+
self.assertEqual(3, len(requests))
128129
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
129-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
130+
# TODO: Optimize FixedSizePool so this GetSessionRequest is not executed
131+
# every time a session is fetched.
132+
self.assertTrue(isinstance(requests[1], GetSessionRequest))
133+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
130134

131135
def test_create_table(self):
132136
database_admin_api = self.client.database_admin_api
@@ -145,28 +149,3 @@ def test_create_table(self):
145149
)
146150
operation = database_admin_api.update_database_ddl(request)
147151
operation.result(1)
148-
149-
150-
def test_async_select1(self):
151-
self._add_select1_result()
152-
results = asyncio.run(self._async_select1())
153-
result_list = []
154-
for row in results.rows:
155-
result_list.append(row)
156-
self.assertEqual("1", row[0])
157-
self.assertEqual(1, len(result_list))
158-
requests = self.spanner_service.requests
159-
self.assertEqual(3, len(requests))
160-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
161-
self.assertTrue(isinstance(requests[1], CreateSessionRequest))
162-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
163-
164-
async def _async_select1(self):
165-
client = self.database.spanner_async_api
166-
create_session_request = CreateSessionRequest(database=self._database.name)
167-
session = await client.create_session(create_session_request)
168-
execute_request = ExecuteSqlRequest(dict(
169-
session=session.name,
170-
sql="select 1",
171-
))
172-
return await client.execute_sql(execute_request)

0 commit comments

Comments
 (0)