1- import asyncio
21import sys
32import typing
43from itertools import chain
54
65from aredis .client import (StrictRedisCluster , StrictRedis )
7- from aredis .exceptions import (RedisError ,
8- ConnectionError ,
9- TimeoutError ,
10- ResponseError ,
11- WatchError ,
12- ExecAbortError ,
13- MovedError ,
14- AskError ,
15- TryAgainError ,
16- RedisClusterException )
6+ from aredis .exceptions import (RedisError , ConnectionError ,
7+ TimeoutError , ResponseError ,
8+ WatchError , ExecAbortError ,
9+ MovedError , AskError ,
10+ TryAgainError , RedisClusterException ,
11+ ClusterTransactionError )
1712from aredis .utils import (dict_merge ,
1813 clusterdown_wrapper )
1914
@@ -276,9 +271,9 @@ async def execute(self, raise_on_error=True):
276271 if self .scripts :
277272 await self .load_scripts ()
278273 if self .transaction or self .explicit_transaction :
279- execute = self ._execute_transaction
274+ exec = self ._execute_transaction
280275 else :
281- execute = self ._execute_pipeline
276+ exec = self ._execute_pipeline
282277
283278 conn = self .connection
284279 if not conn :
@@ -288,7 +283,7 @@ async def execute(self, raise_on_error=True):
288283 self .connection = conn
289284
290285 try :
291- return await execute (conn , stack , raise_on_error )
286+ return await exec (conn , stack , raise_on_error )
292287 except (ConnectionError , TimeoutError ) as e :
293288 conn .disconnect ()
294289 if not conn .retry_on_timeout and isinstance (e , TimeoutError ):
@@ -303,7 +298,7 @@ async def execute(self, raise_on_error=True):
303298 "one or more keys" )
304299 # otherwise, it's safe to retry since the transaction isn't
305300 # predicated on any state
306- return await execute (conn , stack , raise_on_error )
301+ return await exec (conn , stack , raise_on_error )
307302 finally :
308303 await self .reset ()
309304
@@ -325,32 +320,25 @@ class StrictPipeline(BasePipeline, StrictRedis):
325320
326321class StrictClusterPipeline (StrictRedisCluster ):
327322 def __init__ (self , connection_pool , result_callbacks = None ,
328- response_callbacks = None , startup_nodes = None ):
329- """
330- """
323+ response_callbacks = None , startup_nodes = None ,
324+ transaction = False , watches = None ):
331325 self .command_stack = []
332326 self .refresh_table_asap = False
333327 self .connection_pool = connection_pool
334328 self .result_callbacks = result_callbacks or self .__class__ .RESULT_CALLBACKS .copy ()
335329 self .startup_nodes = startup_nodes if startup_nodes else []
336330 self .nodes_flags = self .__class__ .NODES_FLAGS .copy ()
337331 self .response_callbacks = dict_merge (response_callbacks or self .__class__ .RESPONSE_CALLBACKS .copy ())
332+ self .transaction = transaction
333+ self .watches = watches or None
334+ self .watching = False
335+ self .explicit_transaction = False
338336
339337 def __repr__ (self ):
340338 """
341339 """
342340 return "{0}" .format (type (self ).__name__ )
343341
344- def __enter__ (self ):
345- """
346- """
347- return self
348-
349- def __exit__ (self , exc_type , exc_value , traceback ):
350- """
351- """
352- self .reset ()
353-
354342 def __del__ (self ):
355343 """
356344 """
@@ -361,6 +349,12 @@ def __len__(self):
361349 """
362350 return len (self .command_stack )
363351
352+ async def __aenter__ (self ):
353+ return self
354+
355+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
356+ self .reset ()
357+
364358 def _determine_slot (self , * args ):
365359 """
366360 figure out what slot based on command and args
@@ -407,9 +401,12 @@ async def execute(self, raise_on_error=True):
407401
408402 if not stack :
409403 return []
410-
404+ if self .transaction :
405+ execute = self .send_cluster_transaction
406+ else :
407+ execute = self .send_cluster_commands
411408 try :
412- return await self . send_cluster_commands (stack , raise_on_error )
409+ return await execute (stack , raise_on_error )
413410 finally :
414411 self .reset ()
415412
@@ -420,30 +417,61 @@ def reset(self):
420417 self .command_stack = []
421418
422419 self .scripts = set ()
423-
424- # TODO: Implement
425- # make sure to reset the connection state in the event that we were
426- # watching something
427- # if self.watching and self.connection:
428- # try:
429- # # call this manually since our unwatch or
430- # # immediate_execute_command methods can call reset()
431- # self.connection.send_command('UNWATCH')
432- # self.connection.read_response()
433- # except ConnectionError:
434- # # disconnect will also remove any previous WATCHes
435- # self.connection.disconnect()
436-
420+ self .watches = []
437421 # clean up the other instance attributes
438422 self .watching = False
439423 self .explicit_transaction = False
440424
441- # TODO: Implement
442- # we can safely return the connection to the pool here since we're
443- # sure we're no longer WATCHing anything
444- # if self.connection:
445- # self.connection_pool.release(self.connection)
446- # self.connection = None
425+ @clusterdown_wrapper
426+ async def send_cluster_transaction (self , stack , raise_on_error = True ):
427+ # the first time sending the commands we send all of the commands that were queued up.
428+ # if we have to run through it again, we only retry the commands that failed.
429+ attempt = sorted (stack , key = lambda x : x .position )
430+ node = {}
431+
432+ # as we move through each command that still needs to be processed,
433+ # we figure out the slot number that command maps to, then from the slot determine the node.
434+ for c in attempt :
435+ # refer to our internal node -> slot table that tells us where a given
436+ # command should route to.
437+ slot = self ._determine_slot (* c .args )
438+ hashed_node = self .connection_pool .get_node_by_slot (slot )
439+
440+ # now that we know the name of the node ( it's just a string in the form of host:port )
441+ # we can build a list of commands for each node.
442+ if node .get ('name' ) != hashed_node ['name' ]:
443+ # raise error if commands in a transaction can not hash to same node
444+ if len (node ) > 0 :
445+ raise ClusterTransactionError ("Keys in request don't hash to the same node" )
446+ node = hashed_node
447+ conn = self .connection_pool .get_connection_by_node (node )
448+ if self .watches :
449+ await self ._watch (node , conn , self .watches )
450+ node_commands = NodeCommands (self .parse_response , conn , in_transaction = True )
451+ node_commands .append (PipelineCommand (('MULTI' ,)))
452+ node_commands .extend (attempt )
453+ node_commands .append (PipelineCommand (('EXEC' ,)))
454+ self .explicit_transaction = True
455+ await node_commands .write ()
456+ # todo: make this place clear
457+ try :
458+ await node_commands .read ()
459+ except ExecAbortError :
460+ if self .explicit_transaction :
461+ await conn .send_command ('DISCARD' )
462+ await conn .read_response ()
463+
464+ # If at least one watched key is modified before the EXEC command,
465+ # the whole transaction aborts,
466+ # and EXEC returns a Null reply to notify that the transaction failed.
467+ if node_commands .commands [- 1 ].result is None :
468+ raise WatchError
469+ self .connection_pool .release (conn )
470+ if self .watching :
471+ self ._unwatch (conn )
472+ if raise_on_error :
473+ self .raise_first_error (stack )
474+
447475
448476 @clusterdown_wrapper
449477 async def send_cluster_commands (self , stack , raise_on_error = True , allow_redirections = True ):
@@ -550,35 +578,38 @@ def _fail_on_redirect(self, allow_redirections):
550578 if not allow_redirections :
551579 raise RedisClusterException ("ASK & MOVED redirection not allowed in this pipeline" )
552580
553- def multi (self ):
581+ def _multi (self ):
554582 """
555583 """
556584 raise RedisClusterException ("method multi() is not implemented" )
557585
558586 def immediate_execute_command (self , * args , ** options ):
559- """
560- """
561587 raise RedisClusterException ("method immediate_execute_command() is not implemented" )
562588
563- def _execute_transaction (self , * args , ** kwargs ):
564- """
565- """
566- raise RedisClusterException ("method _execute_transaction() is not implemented" )
567-
568589 def load_scripts (self ):
569590 """
570591 """
571592 raise RedisClusterException ("method load_scripts() is not implemented" )
572593
573- def watch (self , * names ):
574- """
575- """
576- raise RedisClusterException ("method watch() is not implemented" )
594+ async def _watch (self , node , conn , names ):
595+ "Watches the values at keys ``names``"
596+ for name in names :
597+ slot = self ._determine_slot ('WATCH' , name )
598+ dist_node = self .connection_pool .get_node_by_slot (slot )
599+ if node .get ('name' ) != dist_node ['name' ]:
600+ # raise error if commands in a transaction can not hash to same node
601+ if len (node ) > 0 :
602+ raise ClusterTransactionError ("Keys in request don't hash to the same node" )
603+ if self .explicit_transaction :
604+ raise RedisError ('Cannot issue a WATCH after a MULTI' )
605+ await conn .send_command ('WATCH' , * names )
606+ return await conn .read_response ()
577607
578- def unwatch (self ):
579- """
580- """
581- raise RedisClusterException ("method unwatch() is not implemented" )
608+ async def _unwatch (self , conn ):
609+ "Unwatches all previously specified keys"
610+ await conn .send_command ('UNWATCH' )
611+ res = await conn .read_response ()
612+ return self .watching and res or True
582613
583614 def script_load_for_pipeline (self , * args , ** kwargs ):
584615 """
@@ -691,12 +722,16 @@ class NodeCommands(object):
691722 """
692723 """
693724
694- def __init__ (self , parse_response , connection ):
725+ def __init__ (self , parse_response , connection , in_transaction = False ):
695726 """
696727 """
697728 self .parse_response = parse_response
698729 self .connection = connection
699730 self .commands = []
731+ self .in_transaction = in_transaction
732+
733+ def extend (self , c ):
734+ self .commands .extend (c )
700735
701736 def append (self , c ):
702737 """
@@ -724,8 +759,6 @@ async def write(self):
724759 c .result = e
725760
726761 async def read (self ):
727- """
728- """
729762 connection = self .connection
730763 for c in self .commands :
731764
@@ -741,7 +774,13 @@ async def read(self):
741774 # explicitly open the connection and all will be well.
742775 if c .result is None :
743776 try :
744- c .result = await self .parse_response (connection , c .args [0 ], ** c .options )
777+ if self .in_transaction :
778+ cmd = '_'
779+ else :
780+ cmd = c .args [0 ]
781+ c .result = await self .parse_response (connection , cmd , ** c .options )
782+ except ExecAbortError :
783+ raise
745784 except (ConnectionError , TimeoutError ) as e :
746785 for c in self .commands :
747786 c .result = e
0 commit comments