Skip to content

Commit 95b6cd6

Browse files
committed
Merge branch 'main' into mock-server-tests
2 parents 21557e0 + 5e8ca94 commit 95b6cd6

File tree

6 files changed

+297
-87
lines changed

6 files changed

+297
-87
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ class Connection:
8989
committed by other transactions since the start of the read-only transaction. Commit or rolling back
9090
the read-only transaction is semantically the same, and only indicates that the read-only transaction
9191
should end a that a new one should be started when the next statement is executed.
92+
93+
**kwargs: Initial value for connection variables.
9294
"""
9395

94-
def __init__(self, instance, database=None, read_only=False):
96+
def __init__(self, instance, database=None, read_only=False, **kwargs):
9597
self._instance = instance
9698
self._database = database
9799
self._ddl_statements = []
@@ -117,6 +119,7 @@ def __init__(self, instance, database=None, read_only=False):
117119
self._batch_dml_executor: BatchDmlExecutor = None
118120
self._transaction_helper = TransactionRetryHelper(self)
119121
self._autocommit_dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
122+
self._connection_variables = kwargs
120123

121124
@property
122125
def spanner_client(self):
@@ -206,6 +209,10 @@ def _client_transaction_started(self):
206209
"""
207210
return (not self._autocommit) or self._transaction_begin_marked
208211

212+
@property
213+
def _ignore_transaction_warnings(self):
214+
return self._connection_variables.get("ignore_transaction_warnings", False)
215+
209216
@property
210217
def instance(self):
211218
"""Instance to which this connection relates.
@@ -232,7 +239,7 @@ def read_only(self, value):
232239
Args:
233240
value (bool): True for ReadOnly mode, False for ReadWrite.
234241
"""
235-
if self._spanner_transaction_started:
242+
if self._read_only != value and self._spanner_transaction_started:
236243
raise ValueError(
237244
"Connection read/write mode can't be changed while a transaction is in progress. "
238245
"Commit or rollback the current transaction and try again."
@@ -398,9 +405,10 @@ def commit(self):
398405
if self.database is None:
399406
raise ValueError("Database needs to be passed for this operation")
400407
if not self._client_transaction_started:
401-
warnings.warn(
402-
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
403-
)
408+
if not self._ignore_transaction_warnings:
409+
warnings.warn(
410+
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
411+
)
404412
return
405413

406414
self.run_prior_DDL_statements()
@@ -418,9 +426,10 @@ def rollback(self):
418426
This is a no-op if there is no active client transaction.
419427
"""
420428
if not self._client_transaction_started:
421-
warnings.warn(
422-
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
423-
)
429+
if not self._ignore_transaction_warnings:
430+
warnings.warn(
431+
CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
432+
)
424433
return
425434
try:
426435
if self._spanner_transaction_started and not self._read_only:
@@ -654,6 +663,7 @@ def connect(
654663
user_agent=None,
655664
client=None,
656665
route_to_leader_enabled=True,
666+
**kwargs,
657667
):
658668
"""Creates a connection to a Google Cloud Spanner database.
659669
@@ -696,6 +706,8 @@ def connect(
696706
disable leader aware routing. Disabling leader aware routing would
697707
route all requests in RW/PDML transactions to the closest region.
698708
709+
**kwargs: Initial value for connection variables.
710+
699711
700712
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
701713
:returns: Connection object associated with the given Google Cloud Spanner

google/cloud/spanner_v1/_helpers.py

Lines changed: 131 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -266,66 +266,69 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None):
266266
:returns: value extracted from value_pb
267267
:raises ValueError: if unknown type is passed
268268
"""
269+
decoder = _get_type_decoder(field_type, field_name, column_info)
270+
return _parse_nullable(value_pb, decoder)
271+
272+
273+
def _get_type_decoder(field_type, field_name, column_info=None):
274+
"""Returns a function that converts a Value protobuf to cell data.
275+
276+
:type field_type: :class:`~google.cloud.spanner_v1.types.Type`
277+
:param field_type: type code for the value
278+
279+
:type field_name: str
280+
:param field_name: column name
281+
282+
:type column_info: dict
283+
:param column_info: (Optional) dict of column name and column information.
284+
An object where column names as keys and custom objects as corresponding
285+
values for deserialization. It's specifically useful for data types like
286+
protobuf where deserialization logic is on user-specific code. When provided,
287+
the custom object enables deserialization of backend-received column data.
288+
If not provided, data remains serialized as bytes for Proto Messages and
289+
integer for Proto Enums.
290+
291+
:rtype: a function that takes a single protobuf value as an input argument
292+
:returns: a function that can be used to extract a value from a protobuf value
293+
:raises ValueError: if unknown type is passed
294+
"""
295+
269296
type_code = field_type.code
270-
if value_pb.HasField("null_value"):
271-
return None
272297
if type_code == TypeCode.STRING:
273-
return value_pb.string_value
298+
return _parse_string
274299
elif type_code == TypeCode.BYTES:
275-
return value_pb.string_value.encode("utf8")
300+
return _parse_bytes
276301
elif type_code == TypeCode.BOOL:
277-
return value_pb.bool_value
302+
return _parse_bool
278303
elif type_code == TypeCode.INT64:
279-
return int(value_pb.string_value)
304+
return _parse_int64
280305
elif type_code == TypeCode.FLOAT64:
281-
if value_pb.HasField("string_value"):
282-
return float(value_pb.string_value)
283-
else:
284-
return value_pb.number_value
306+
return _parse_float
285307
elif type_code == TypeCode.FLOAT32:
286-
if value_pb.HasField("string_value"):
287-
return float(value_pb.string_value)
288-
else:
289-
return value_pb.number_value
308+
return _parse_float
290309
elif type_code == TypeCode.DATE:
291-
return _date_from_iso8601_date(value_pb.string_value)
310+
return _parse_date
292311
elif type_code == TypeCode.TIMESTAMP:
293-
DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
294-
return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
295-
elif type_code == TypeCode.ARRAY:
296-
return [
297-
_parse_value_pb(
298-
item_pb, field_type.array_element_type, field_name, column_info
299-
)
300-
for item_pb in value_pb.list_value.values
301-
]
302-
elif type_code == TypeCode.STRUCT:
303-
return [
304-
_parse_value_pb(
305-
item_pb, field_type.struct_type.fields[i].type_, field_name, column_info
306-
)
307-
for (i, item_pb) in enumerate(value_pb.list_value.values)
308-
]
312+
return _parse_timestamp
309313
elif type_code == TypeCode.NUMERIC:
310-
return decimal.Decimal(value_pb.string_value)
314+
return _parse_numeric
311315
elif type_code == TypeCode.JSON:
312-
return JsonObject.from_str(value_pb.string_value)
316+
return _parse_json
313317
elif type_code == TypeCode.PROTO:
314-
bytes_value = base64.b64decode(value_pb.string_value)
315-
if column_info is not None and column_info.get(field_name) is not None:
316-
default_proto_message = column_info.get(field_name)
317-
if isinstance(default_proto_message, Message):
318-
proto_message = type(default_proto_message)()
319-
proto_message.ParseFromString(bytes_value)
320-
return proto_message
321-
return bytes_value
318+
return lambda value_pb: _parse_proto(value_pb, column_info, field_name)
322319
elif type_code == TypeCode.ENUM:
323-
int_value = int(value_pb.string_value)
324-
if column_info is not None and column_info.get(field_name) is not None:
325-
proto_enum = column_info.get(field_name)
326-
if isinstance(proto_enum, EnumTypeWrapper):
327-
return proto_enum.Name(int_value)
328-
return int_value
320+
return lambda value_pb: _parse_proto_enum(value_pb, column_info, field_name)
321+
elif type_code == TypeCode.ARRAY:
322+
element_decoder = _get_type_decoder(
323+
field_type.array_element_type, field_name, column_info
324+
)
325+
return lambda value_pb: _parse_array(value_pb, element_decoder)
326+
elif type_code == TypeCode.STRUCT:
327+
element_decoders = [
328+
_get_type_decoder(item_field.type_, field_name, column_info)
329+
for item_field in field_type.struct_type.fields
330+
]
331+
return lambda value_pb: _parse_struct(value_pb, element_decoders)
329332
else:
330333
raise ValueError("Unknown type: %s" % (field_type,))
331334

@@ -351,6 +354,87 @@ def _parse_list_value_pbs(rows, row_type):
351354
return result
352355

353356

357+
def _parse_string(value_pb) -> str:
358+
return value_pb.string_value
359+
360+
361+
def _parse_bytes(value_pb):
362+
return value_pb.string_value.encode("utf8")
363+
364+
365+
def _parse_bool(value_pb) -> bool:
366+
return value_pb.bool_value
367+
368+
369+
def _parse_int64(value_pb) -> int:
370+
return int(value_pb.string_value)
371+
372+
373+
def _parse_float(value_pb) -> float:
374+
if value_pb.HasField("string_value"):
375+
return float(value_pb.string_value)
376+
else:
377+
return value_pb.number_value
378+
379+
380+
def _parse_date(value_pb):
381+
return _date_from_iso8601_date(value_pb.string_value)
382+
383+
384+
def _parse_timestamp(value_pb):
385+
DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
386+
return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
387+
388+
389+
def _parse_numeric(value_pb):
390+
return decimal.Decimal(value_pb.string_value)
391+
392+
393+
def _parse_json(value_pb):
394+
return JsonObject.from_str(value_pb.string_value)
395+
396+
397+
def _parse_proto(value_pb, column_info, field_name):
398+
bytes_value = base64.b64decode(value_pb.string_value)
399+
if column_info is not None and column_info.get(field_name) is not None:
400+
default_proto_message = column_info.get(field_name)
401+
if isinstance(default_proto_message, Message):
402+
proto_message = type(default_proto_message)()
403+
proto_message.ParseFromString(bytes_value)
404+
return proto_message
405+
return bytes_value
406+
407+
408+
def _parse_proto_enum(value_pb, column_info, field_name):
409+
int_value = int(value_pb.string_value)
410+
if column_info is not None and column_info.get(field_name) is not None:
411+
proto_enum = column_info.get(field_name)
412+
if isinstance(proto_enum, EnumTypeWrapper):
413+
return proto_enum.Name(int_value)
414+
return int_value
415+
416+
417+
def _parse_array(value_pb, element_decoder) -> []:
418+
return [
419+
_parse_nullable(item_pb, element_decoder)
420+
for item_pb in value_pb.list_value.values
421+
]
422+
423+
424+
def _parse_struct(value_pb, element_decoders):
425+
return [
426+
_parse_nullable(item_pb, element_decoders[i])
427+
for (i, item_pb) in enumerate(value_pb.list_value.values)
428+
]
429+
430+
431+
def _parse_nullable(value_pb, decoder):
432+
if value_pb.HasField("null_value"):
433+
return None
434+
else:
435+
return decoder(value_pb)
436+
437+
354438
class _SessionWrapper(object):
355439
"""Base class for objects wrapping a session.
356440

0 commit comments

Comments
 (0)