1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14-
14+ import base64
1515from concurrent import futures
1616
1717from google .protobuf import empty_pb2
@@ -31,22 +31,32 @@ def __init__(self):
3131 self .results = {}
3232
3333 def add_result (self , sql : str , result : result_set .ResultSet ):
34- self .results [sql .lower ()] = result
34+ self .results [sql .lower ().strip ()] = result
35+
36+ def get_result (self , sql : str ) -> result_set .ResultSet :
37+ result = self .results .get (sql .lower ().strip ())
38+ if result is None :
39+ raise ValueError (f"No result found for { sql } " )
40+ return result
3541
3642 def get_result_as_partial_result_sets (
3743 self , sql : str
3844 ) -> [result_set .PartialResultSet ]:
39- result : result_set .ResultSet = self .results .get (sql .lower ())
40- if result is None :
41- return []
45+ result : result_set .ResultSet = self .get_result (sql )
4246 partials = []
4347 first = True
44- for row in result .rows :
48+ if len ( result .rows ) == 0 :
4549 partial = result_set .PartialResultSet ()
46- if first :
47- partial .metadata = result .metadata
48- partial .values .extend (row )
50+ partial .metadata = result .metadata
4951 partials .append (partial )
52+ else :
53+ for row in result .rows :
54+ partial = result_set .PartialResultSet ()
55+ if first :
56+ partial .metadata = result .metadata
57+ partial .values .extend (row )
58+ partials .append (partial )
59+ partials [len (partials ) - 1 ].stats = result .stats
5060 return partials
5161
5262
@@ -56,6 +66,8 @@ def __init__(self):
5666 self ._requests = []
5767 self .session_counter = 0
5868 self .sessions = {}
69+ self .transaction_counter = 0
70+ self .transactions = {}
5971 self ._mock_spanner = MockSpanner ()
6072
6173 @property
@@ -93,18 +105,20 @@ def __create_session(self, database: str, session_template: spanner.Session):
93105 return session
94106
95107 def GetSession (self , request , context ):
108+ self ._requests .append (request )
96109 return spanner .Session ()
97110
98111 def ListSessions (self , request , context ):
112+ self ._requests .append (request )
99113 return [spanner .Session ()]
100114
101115 def DeleteSession (self , request , context ):
116+ self ._requests .append (request )
102117 return empty_pb2 .Empty ()
103118
104119 def ExecuteSql (self , request , context ):
105120 self ._requests .append (request )
106- result : result_set .ResultSet = self .mock_spanner .results .get (request .sql .lower ())
107- return result
121+ return result_set .ResultSet ()
108122
109123 def ExecuteStreamingSql (self , request , context ):
110124 self ._requests .append (request )
@@ -113,31 +127,74 @@ def ExecuteStreamingSql(self, request, context):
113127 yield result
114128
115129 def ExecuteBatchDml (self , request , context ):
116- return spanner .ExecuteBatchDmlResponse ()
130+ self ._requests .append (request )
131+ response = spanner .ExecuteBatchDmlResponse ()
132+ started_transaction = None
133+ if not request .transaction .begin == transaction .TransactionOptions ():
134+ started_transaction = self .__create_transaction (
135+ request .session , request .transaction .begin
136+ )
137+ first = True
138+ for statement in request .statements :
139+ result = self .mock_spanner .get_result (statement .sql )
140+ if first and started_transaction is not None :
141+ result = result_set .ResultSet (
142+ self .mock_spanner .get_result (statement .sql )
143+ )
144+ result .metadata = result_set .ResultSetMetadata (result .metadata )
145+ result .metadata .transaction = started_transaction
146+ response .result_sets .append (result )
147+ return response
117148
118149 def Read (self , request , context ):
150+ self ._requests .append (request )
119151 return result_set .ResultSet ()
120152
121153 def StreamingRead (self , request , context ):
154+ self ._requests .append (request )
122155 for result in [result_set .PartialResultSet (), result_set .PartialResultSet ()]:
123156 yield result
124157
125158 def BeginTransaction (self , request , context ):
126- return transaction .Transaction ()
159+ self ._requests .append (request )
160+ return self .__create_transaction (request .session , request .options )
161+
162+ def __create_transaction (
163+ self , session : str , options : transaction .TransactionOptions
164+ ) -> transaction .Transaction :
165+ session = self .sessions [session ]
166+ if session is None :
167+ raise ValueError (f"Session not found: { session } " )
168+ self .transaction_counter += 1
169+ id_bytes = bytes (
170+ f"{ session .name } /transactions/{ self .transaction_counter } " , "UTF-8"
171+ )
172+ transaction_id = base64 .urlsafe_b64encode (id_bytes )
173+ self .transactions [transaction_id ] = options
174+ return transaction .Transaction (dict (id = transaction_id ))
127175
128176 def Commit (self , request , context ):
177+ self ._requests .append (request )
178+ tx = self .transactions [request .transaction_id ]
179+ if tx is None :
180+ raise ValueError (f"Transaction not found: { request .transaction_id } " )
181+ del self .transactions [request .transaction_id ]
129182 return commit .CommitResponse ()
130183
131184 def Rollback (self , request , context ):
185+ self ._requests .append (request )
132186 return empty_pb2 .Empty ()
133187
134188 def PartitionQuery (self , request , context ):
189+ self ._requests .append (request )
135190 return spanner .PartitionResponse ()
136191
137192 def PartitionRead (self , request , context ):
193+ self ._requests .append (request )
138194 return spanner .PartitionResponse ()
139195
140196 def BatchWrite (self , request , context ):
197+ self ._requests .append (request )
141198 for result in [spanner .BatchWriteResponse (), spanner .BatchWriteResponse ()]:
142199 yield result
143200
0 commit comments