Skip to content

Commit 2eb031b

Browse files
1 parent 280e436 commit 2eb031b

File tree

3 files changed

+31
-23
lines changed

3 files changed

+31
-23
lines changed

‎google/cloud/spanner_v1/testing/mock_spanner.py‎

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def __init__(self):
4141
def add_result(self, sql: str, result: result_set.ResultSet):
4242
self.results[sql.lower().strip()] = result
4343

44-
def add_execute_streaming_sql_results(self, sql: str,
45-
partial_result_sets: list[result_set.PartialResultSet]):
46-
self.execute_streaming_sql_results[
47-
sql.lower().strip()] = partial_result_sets
44+
def add_execute_streaming_sql_results(
45+
self, sql: str, partial_result_sets: list[result_set.PartialResultSet]
46+
):
47+
self.execute_streaming_sql_results[sql.lower().strip()] = partial_result_sets
4848

4949
def get_result(self, sql: str) -> result_set.ResultSet:
5050
result = self.results.get(sql.lower().strip())
@@ -61,9 +61,9 @@ def pop_error(self, context):
6161
if error:
6262
context.abort_with_status(error)
6363

64-
def get_execute_streaming_sql_results(self, sql: str,
65-
started_transaction: transaction.Transaction) -> list[
66-
result_set.PartialResultSet]:
64+
def get_execute_streaming_sql_results(
65+
self, sql: str, started_transaction: transaction.Transaction
66+
) -> list[result_set.PartialResultSet]:
6767
if self.execute_streaming_sql_results[sql.lower().strip()]:
6868
partials = self.execute_streaming_sql_results[sql.lower().strip()]
6969
else:

‎tests/mockserver_tests/mock_server_test_base.py‎

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@
3939
from google.cloud.spanner_v1 import TypeCode
4040
from google.cloud.spanner_v1.database import Database
4141
from google.cloud.spanner_v1.instance import Instance
42-
from google.cloud.spanner_v1.testing.mock_database_admin import \
43-
DatabaseAdminServicer
42+
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
4443
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
4544
from google.cloud.spanner_v1.testing.mock_spanner import start_mock_server
4645

@@ -66,8 +65,9 @@ def aborted_status() -> _Status:
6665
return status
6766

6867

69-
def _make_partial_result_sets(fields: list[tuple[str, TypeCode]],
70-
results: list[dict]) -> list[result_set.PartialResultSet]:
68+
def _make_partial_result_sets(
69+
fields: list[tuple[str, TypeCode]], results: list[dict]
70+
) -> list[result_set.PartialResultSet]:
7171
partial_result_sets = []
7272
for result in results:
7373
partial_result_set = PartialResultSet()
@@ -76,14 +76,16 @@ def _make_partial_result_sets(fields: list[tuple[str, TypeCode]],
7676
metadata = ResultSetMetadata(row_type=StructType(fields=[]))
7777
for field in fields:
7878
metadata.row_type.fields.append(
79-
StructType.Field(name=field[0], type_=Type(code=field[1])))
79+
StructType.Field(name=field[0], type_=Type(code=field[1]))
80+
)
8081
partial_result_set.metadata = metadata
8182
for value in result["values"]:
8283
partial_result_set.values.append(_make_value_pb(value))
83-
partial_result_set.last = result.get('last') or False
84+
partial_result_set.last = result.get("last") or False
8485
partial_result_sets.append(partial_result_set)
8586
return partial_result_sets
8687

88+
8789
# Creates an UNAVAILABLE status with the smallest possible retry delay.
8890
def unavailable_status() -> _Status:
8991
error = status_pb2.Status(
@@ -128,10 +130,13 @@ def add_select1_result():
128130
add_single_result("select 1", "c", TypeCode.INT64, [("1",)])
129131

130132

131-
def add_execute_streaming_sql_results(sql: str,
132-
partial_result_sets: list[result_set.PartialResultSet]):
133+
def add_execute_streaming_sql_results(
134+
sql: str, partial_result_sets: list[result_set.PartialResultSet]
135+
):
133136
MockServerTestBase.spanner_service.mock_spanner.add_execute_streaming_sql_results(
134-
sql, partial_result_sets)
137+
sql, partial_result_sets
138+
)
139+
135140

136141
def add_single_result(
137142
sql: str, column_name: str, type_code: spanner_type.TypeCode, row

‎tests/mockserver_tests/test_basics.py‎

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,18 @@
2424
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
2525
from google.cloud.spanner_v1.transaction import Transaction
2626
from tests.mockserver_tests.mock_server_test_base import MockServerTestBase
27-
from tests.mockserver_tests.mock_server_test_base import \
28-
_make_partial_result_sets
27+
from tests.mockserver_tests.mock_server_test_base import _make_partial_result_sets
2928
from tests.mockserver_tests.mock_server_test_base import add_error
30-
from tests.mockserver_tests.mock_server_test_base import \
31-
add_execute_streaming_sql_results
29+
from tests.mockserver_tests.mock_server_test_base import (
30+
add_execute_streaming_sql_results,
31+
)
3232
from tests.mockserver_tests.mock_server_test_base import add_select1_result
3333
from tests.mockserver_tests.mock_server_test_base import add_single_result
3434
from tests.mockserver_tests.mock_server_test_base import add_update_count
3535
from tests.mockserver_tests.mock_server_test_base import unavailable_status
3636

3737

3838
class TestBasics(MockServerTestBase):
39-
4039
def setUp(self):
4140
super().setUp()
4241
super().setup_class()
@@ -183,8 +182,11 @@ def test_last_statement_query(self):
183182
def test_execute_streaming_sql_last_field(self):
184183
partial_result_sets = _make_partial_result_sets(
185184
[("ID", TypeCode.INT64), ("NAME", TypeCode.STRING)],
186-
[{"values": ["1", "ABC", "2", "DEF"]},
187-
{"values": ["3", "GHI"], "last": True}])
185+
[
186+
{"values": ["1", "ABC", "2", "DEF"]},
187+
{"values": ["3", "GHI"], "last": True},
188+
],
189+
)
188190

189191
sql = "select * from my_table"
190192
add_execute_streaming_sql_results(sql, partial_result_sets)
@@ -202,6 +204,7 @@ def test_execute_streaming_sql_last_field(self):
202204
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
203205
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
204206

207+
205208
def _execute_query(transaction: Transaction, sql: str):
206209
rows = transaction.execute_sql(sql, last_statement=True)
207210
for _ in rows:

0 commit comments

Comments
 (0)