Skip to content

Commit e3d89e9

Browse files
committed
fix: Secure ClientContext merging and improve type safety
- Fix critical security vulnerability in where in-place modification of the base object could lead to context leakage across requests. - Replace with throughout , , , and for better robustness and subclass support. - Simplify construction logic in for better readability. Based on suggestions from Gemini Code Assist.
1 parent 3611aec commit e3d89e9

File tree

4 files changed

+30
-18
lines changed

4 files changed

+30
-18
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,15 @@ def _merge_query_options(base, merge):
174174
If the resultant object only has empty fields, returns None.
175175
"""
176176
combined = base or ExecuteSqlRequest.QueryOptions()
177-
if type(combined) is dict:
177+
if isinstance(combined, dict):
178178
combined = ExecuteSqlRequest.QueryOptions(
179179
optimizer_version=combined.get("optimizer_version", ""),
180180
optimizer_statistics_package=combined.get(
181181
"optimizer_statistics_package", ""
182182
),
183183
)
184184
merge = merge or ExecuteSqlRequest.QueryOptions()
185-
if type(merge) is dict:
185+
if isinstance(merge, dict):
186186
merge = ExecuteSqlRequest.QueryOptions(
187187
optimizer_version=merge.get("optimizer_version", ""),
188188
optimizer_statistics_package=merge.get("optimizer_statistics_package", ""),
@@ -215,14 +215,28 @@ def _merge_client_context(base, merge):
215215
return None
216216

217217
combined = base or ClientContext()
218-
if type(combined) is dict:
218+
if isinstance(combined, dict):
219219
combined = ClientContext(combined)
220220

221221
merge = merge or ClientContext()
222-
if type(merge) is dict:
222+
if isinstance(merge, dict):
223223
merge = ClientContext(merge)
224224

225-
type(combined).pb(combined).MergeFrom(type(merge).pb(merge))
225+
# Avoid in-place modification of base
226+
combined_pb = ClientContext()._pb
227+
if base:
228+
base_pb = (
229+
ClientContext(base)._pb if isinstance(base, dict) else base._pb
230+
)
231+
combined_pb.MergeFrom(base_pb)
232+
if merge:
233+
merge_pb = (
234+
ClientContext(merge)._pb if isinstance(merge, dict) else merge._pb
235+
)
236+
combined_pb.MergeFrom(merge_pb)
237+
238+
combined = ClientContext(combined_pb)
239+
226240
if not combined.secure_context:
227241
return None
228242
return combined
@@ -250,7 +264,7 @@ def _merge_request_options(request_options, client_context):
250264

251265
if request_options is None:
252266
request_options = RequestOptions()
253-
elif type(request_options) is dict:
267+
elif isinstance(request_options, dict):
254268
request_options = RequestOptions(request_options)
255269

256270
if client_context:

google/cloud/spanner_v1/batch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, session, client_context=None):
6868
self.commit_stats: Optional[CommitResponse.CommitStats] = None
6969

7070
if client_context is not None:
71-
if type(client_context) is dict:
71+
if isinstance(client_context, dict):
7272
client_context = ClientContext(client_context)
7373
elif not isinstance(client_context, ClientContext):
7474
raise TypeError("client_context must be a ClientContext or a dict")
@@ -349,7 +349,7 @@ def __init__(self, session, client_context=None):
349349
self.committed: bool = False
350350

351351
if client_context is not None:
352-
if type(client_context) is dict:
352+
if isinstance(client_context, dict):
353353
client_context = ClientContext(client_context)
354354
elif not isinstance(client_context, ClientContext):
355355
raise TypeError("client_context must be a ClientContext or a dict")

google/cloud/spanner_v1/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __init__(
295295
self._query_options = _merge_query_options(query_options, env_query_options)
296296

297297
if client_context is not None:
298-
if type(client_context) is dict:
298+
if isinstance(client_context, dict):
299299
client_context = ClientContext(client_context)
300300
elif not isinstance(client_context, ClientContext):
301301
raise TypeError("client_context must be a ClientContext or a dict")

google/cloud/spanner_v1/snapshot.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(self, session, client_context=None):
213213
super().__init__(session)
214214

215215
if client_context is not None:
216-
if type(client_context) is dict:
216+
if isinstance(client_context, dict):
217217
client_context = ClientContext(client_context)
218218
elif not isinstance(client_context, ClientContext):
219219
raise TypeError("client_context must be a ClientContext or a dict")
@@ -949,20 +949,18 @@ def _begin_transaction(
949949
"mutation_key": mutation,
950950
}
951951

952+
request_options = begin_request_kwargs.get("request_options")
952953
client_context = _merge_client_context(
953954
database._instance._client._client_context, self._client_context
954955
)
955-
if client_context:
956-
begin_request_kwargs["request_options"] = _merge_request_options(
957-
begin_request_kwargs.get("request_options"), client_context
958-
)
956+
request_options = _merge_request_options(request_options, client_context)
959957

960958
if transaction_tag:
961-
request_options = begin_request_kwargs.get("request_options")
962959
if request_options is None:
963-
request_options = RequestOptions(transaction_tag=transaction_tag)
964-
else:
965-
request_options.transaction_tag = transaction_tag
960+
request_options = RequestOptions()
961+
request_options.transaction_tag = transaction_tag
962+
963+
if request_options:
966964
begin_request_kwargs["request_options"] = request_options
967965

968966
with trace_call(

0 commit comments

Comments
 (0)