Skip to content

Commit 4712ec3

Browse files
committed
feat: support Partitioned DML
Adds tests and samples for executing Partitioned DML using SQLAlchemy. Fixes #496
1 parent a633c23 commit 4712ec3

File tree

3 files changed

+66
-7
lines changed

3 files changed

+66
-7
lines changed

test/mockserver_tests/mock_server_test_base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ def add_result(sql: str, result: ResultSet):
3535
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)
3636

3737

38+
def add_update_count(sql: str, count: int):
39+
result = result_set.ResultSet(
40+
dict(
41+
stats=result_set.ResultSetStats(
42+
dict(
43+
row_count_exact=count,
44+
)
45+
),
46+
)
47+
)
48+
add_result(sql, result)
49+
50+
3851
def add_select1_result():
3952
result = result_set.ResultSet(
4053
dict(

test/mockserver_tests/mock_spanner.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata
15+
from google.cloud.spanner_v1 import (
16+
TransactionOptions,
17+
ResultSetMetadata,
18+
ExecuteSqlRequest,
19+
)
1620
from google.protobuf import empty_pb2
1721
import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc
1822
import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc
@@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet:
4044
return result
4145

4246
def get_result_as_partial_result_sets(
43-
self, sql: str
47+
self, sql: str, started_transaction: transaction.Transaction
4448
) -> [result_set.PartialResultSet]:
4549
result: result_set.ResultSet = self.get_result(sql)
4650
partials = []
4751
first = True
4852
if len(result.rows) == 0:
4953
partial = result_set.PartialResultSet()
50-
partial.metadata = result.metadata
54+
partial.metadata = ResultSetMetadata(result.metadata)
5155
partials.append(partial)
5256
else:
5357
for row in result.rows:
5458
partial = result_set.PartialResultSet()
5559
if first:
56-
partial.metadata = result.metadata
60+
partial.metadata = ResultSetMetadata(result.metadata)
5761
partial.values.extend(row)
5862
partials.append(partial)
5963
partials[len(partials) - 1].stats = result.stats
64+
if started_transaction:
65+
partials[0].metadata.transaction = started_transaction
6066
return partials
6167

6268

@@ -120,9 +126,16 @@ def ExecuteSql(self, request, context):
120126
self._requests.append(request)
121127
return result_set.ResultSet()
122128

123-
def ExecuteStreamingSql(self, request, context):
129+
def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context):
124130
self._requests.append(request)
125-
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
131+
started_transaction = None
132+
if not request.transaction.begin == TransactionOptions():
133+
started_transaction = self.__create_transaction(
134+
request.session, request.transaction.begin
135+
)
136+
partials = self.mock_spanner.get_result_as_partial_result_sets(
137+
request.sql, started_transaction
138+
)
126139
for result in partials:
127140
yield result
128141

test/mockserver_tests/test_basics.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,19 @@
1313
# limitations under the License.
1414

1515
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
16-
from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String
16+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
17+
from sqlalchemy import (
18+
create_engine,
19+
select,
20+
MetaData,
21+
Table,
22+
Column,
23+
Integer,
24+
String,
25+
update,
26+
table,
27+
column,
28+
)
1729
from sqlalchemy.testing import eq_, is_instance_of
1830
from google.cloud.spanner_v1 import (
1931
FixedSizePool,
@@ -26,6 +38,7 @@
2638
MockServerTestBase,
2739
add_select1_result,
2840
add_result,
41+
add_update_count,
2942
)
3043

3144

@@ -127,3 +140,23 @@ def test_create_multiple_tables(self):
127140
"\n) PRIMARY KEY (id)",
128141
requests[0].statements[i],
129142
)
143+
144+
def test_partitioned_dml(self):
145+
sql = "UPDATE singers SET WHERE active = true"
146+
add_update_count(sql, 100)
147+
engine = create_engine(
148+
"spanner:///projects/p/instances/i/databases/d",
149+
connect_args={"client": self.client, "pool": PingingPool(size=10)},
150+
)
151+
# TODO: Support autocommit_dml_mode as a connection variable in execution
152+
# options.
153+
with engine.connect().execution_options(
154+
isolation_level="AUTOCOMMIT"
155+
) as connection:
156+
connection.connection.set_autocommit_dml_mode(
157+
AutocommitDmlMode.PARTITIONED_NON_ATOMIC
158+
)
159+
results = connection.execute(
160+
update(table("singers")).where(column("active") is True)
161+
).rowcount
162+
eq_(100, results)

0 commit comments

Comments
 (0)