Skip to content

Commit 89f67fd

Browse files
committed
chore: updates thread handling to avoid race condition in tests
1 parent d4bff6e commit 89f67fd

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

packages/google-cloud-spanner/google/cloud/spanner_v1/snapshot.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def __init__(self, session, client_context=None):
220220
# respectively.
221221
self._execute_sql_request_count: int = 0
222222
self._read_request_count: int = 0
223+
self._begin_request_sent: bool = False
223224

224225
# Identifier for the transaction.
225226
self._transaction_id: Optional[bytes] = None
@@ -229,10 +230,14 @@ def __init__(self, session, client_context=None):
229230
# highest sequence number is included in the commit request.
230231
self._precommit_token: Optional[MultiplexedSessionPrecommitToken] = None
231232

232-
# Operations within a transaction can be performed using multiple
233+
# Operation within a transaction can be performed using multiple
233234
# threads, so we need to use a lock when updating the transaction.
234235
self._lock: threading.Lock = threading.Lock()
235236

237+
# Event to coordinate concurrent requests beginning the transaction.
238+
# This is used to prevent the "Transaction has not begun" race condition.
239+
self._transaction_begin_event = threading.Event()
240+
236241
def begin(self) -> bytes:
237242
"""Begins a transaction on the database.
238243
@@ -341,11 +346,27 @@ def read(
341346
read request, but is not a multi-use transaction or has not begun.
342347
"""
343348

344-
if self._read_request_count > 0:
345-
if not self._multi_use:
346-
raise ValueError("Cannot re-use single-use snapshot.")
347-
if self._transaction_id is None:
348-
raise ValueError("Transaction has not begun.")
349+
with self._lock:
350+
# Check if this request is beginning the transaction.
351+
# If a request is already in progress, other requests must wait
352+
# until the transaction ID is available.
353+
if self._begin_request_sent or self._read_request_count > 0:
354+
if not self._multi_use:
355+
raise ValueError("Cannot re-use single-use snapshot.")
356+
if self._transaction_id is None:
357+
wait_needed = True
358+
else:
359+
wait_needed = False
360+
else:
361+
wait_needed = False
362+
self._begin_request_sent = True
363+
364+
if wait_needed:
365+
# Wait for the transaction to begin (set by another concurrent request).
366+
# This prevents the race condition where concurrent requests think
367+
# the transaction hasn't begun.
368+
if not self._transaction_begin_event.wait(timeout=30.0):
369+
raise ValueError("Timed out waiting for transaction to begin.")
349370

350371
session = self._session
351372
database = session._database
@@ -527,11 +548,27 @@ def execute_sql(
527548
read request, but is not a multi-use transaction or has not begun.
528549
"""
529550

530-
if self._read_request_count > 0:
531-
if not self._multi_use:
532-
raise ValueError("Cannot re-use single-use snapshot.")
533-
if self._transaction_id is None:
534-
raise ValueError("Transaction has not begun.")
551+
with self._lock:
552+
# Check if this request is beginning the transaction.
553+
# If a request is already in progress, other requests must wait
554+
# until the transaction ID is available.
555+
if self._begin_request_sent or self._read_request_count > 0:
556+
if not self._multi_use:
557+
raise ValueError("Cannot re-use single-use snapshot.")
558+
if self._transaction_id is None:
559+
wait_needed = True
560+
else:
561+
wait_needed = False
562+
else:
563+
wait_needed = False
564+
self._begin_request_sent = True
565+
566+
if wait_needed:
567+
# Wait for the transaction to begin (set by another concurrent request).
568+
# This prevents the race condition where concurrent requests think
569+
# the transaction hasn't begun.
570+
if not self._transaction_begin_event.wait(timeout=30.0):
571+
raise ValueError("Timed out waiting for transaction to begin.")
535572

536573
if params is not None:
537574
params_pb = Struct(
@@ -1058,6 +1095,8 @@ def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None:
10581095
# caller is responsible for locking until the transaction ID is updated.
10591096
if self._transaction_id is None and transaction_pb.id:
10601097
self._transaction_id = transaction_pb.id
1098+
# Notify waiting threads that the transaction has begun.
1099+
self._transaction_begin_event.set()
10611100

10621101
if transaction_pb._pb.HasField("precommit_token"):
10631102
self._update_for_precommit_token_pb_unsafe(transaction_pb.precommit_token)

packages/google-cloud-spanner/tests/unit/test_spanner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,7 @@ def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_
11001100
def test_transaction_for_concurrent_statement_should_begin_one_transaction_with_read(
11011101
self,
11021102
):
1103+
self.maxDiff = None
11031104
database = _Database()
11041105
api = database.spanner_api = self._make_spanner_api()
11051106
session = _Session(database)

0 commit comments

Comments
 (0)