Skip to content

Commit d84a5ad

Browse files
authored
dev (#49)
* change wrong order of lontitude & latitude in comment of command GEOADD & GEOPOS * sync optimization from redis-py made by bgreenberg * update release log * [fix bug] pipeline callback executed * sync bug fixed of `geopos` from redis-py made by categulario * fix error caused by byte decode issues in sentinel * fix comment of connection pool * fix sentinel document bug * add basic transaction support for single node in cluster * [fix bug] fix bug of get_random_connection reported by myrfy001 * optimize parse_response of cluster transaction * update cluster scripting test case
1 parent d1e6cab commit d84a5ad

9 files changed

Lines changed: 215 additions & 88 deletions

File tree

aredis/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,23 +465,24 @@ async def _execute_command_on_nodes(self, nodes, *args, **kwargs):
465465
self.connection_pool.release(connection)
466466
return self._merge_result(command, res, **kwargs)
467467

468-
async def pipeline(self, transaction=None, shard_hint=None):
468+
async def pipeline(self, transaction=None, shard_hint=None, watches=None):
469469
"""
470470
Cluster impl:
471471
Pipelines do not work in cluster mode the same way they do in normal mode.
472472
Create a clone of this object so that simulating pipelines will work correctly.
473473
Each command will be called directly when used and when calling execute() will only return the result stack.
474+
cluster transaction can only be run with commands in the same node, otherwise error will be raised.
474475
"""
475476
await self.connection_pool.initialize()
476477
if shard_hint:
477478
raise RedisClusterException("shard_hint is deprecated in cluster mode")
478479

479-
if transaction:
480-
raise RedisClusterException("transaction is deprecated in cluster mode")
481480
from aredis.pipeline import StrictClusterPipeline
482481
return StrictClusterPipeline(
483482
connection_pool=self.connection_pool,
484483
startup_nodes=self.connection_pool.nodes.startup_nodes,
485484
result_callbacks=self.result_callbacks,
486485
response_callbacks=self.response_callbacks,
486+
transaction=transaction,
487+
watches=watches
487488
)

aredis/commands/transaction.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,28 @@ async def unwatch(self):
5454

5555
class ClusterTransactionCommandMixin(TransactionCommandMixin):
5656

57-
async def transaction(self, *args, **kwargs):
57+
async def transaction(self, func, *watches, **kwargs):
5858
"""
59-
Transaction is not implemented in cluster mode yet.
59+
Convenience method for executing the callable `func` as a transaction
60+
while watching all keys specified in `watches`. The 'func' callable
61+
should expect a single argument which is a Pipeline object.
62+
63+
cluster transaction can only be run with commands in the same node,
64+
otherwise error will be raised.
6065
"""
61-
raise RedisClusterException("method StrictRedisCluster.transaction() is not implemented")
66+
shard_hint = kwargs.pop('shard_hint', None)
67+
value_from_callable = kwargs.pop('value_from_callable', False)
68+
watch_delay = kwargs.pop('watch_delay', None)
69+
async with await self.pipeline(True, shard_hint, watches=watches) as pipe:
70+
while True:
71+
try:
72+
func_value = await func(pipe)
73+
exec_value = await pipe.execute()
74+
return func_value if value_from_callable else exec_value
75+
except WatchError:
76+
if watch_delay is not None and watch_delay > 0:
77+
await asyncio.sleep(
78+
watch_delay,
79+
loop=self.connection_pool.loop
80+
)
81+
continue

aredis/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def __init__(self, resp):
109109
self.message = resp
110110

111111

112+
class ClusterTransactionError(ClusterError):
113+
114+
def __init__(self, msg):
115+
self.msg = msg
116+
117+
112118
class AskError(ResponseError):
113119
"""
114120
src node: MIGRATING to dst node

aredis/nodemanager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def get_redis_link(self, host, port):
100100
'loop'
101101
)
102102
connection_kwargs = {k: v for k, v in self.connection_kwargs.items() if k in allowed_keys}
103-
return StrictRedis(host=host, port=port, **connection_kwargs)
103+
return StrictRedis(host=host, port=port, decode_responses=True, **connection_kwargs)
104104

105105
async def initialize(self):
106106
"""

aredis/pipeline.py

Lines changed: 109 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
1-
import asyncio
21
import sys
32
import typing
43
from itertools import chain
54

65
from aredis.client import (StrictRedisCluster, StrictRedis)
7-
from aredis.exceptions import (RedisError,
8-
ConnectionError,
9-
TimeoutError,
10-
ResponseError,
11-
WatchError,
12-
ExecAbortError,
13-
MovedError,
14-
AskError,
15-
TryAgainError,
16-
RedisClusterException)
6+
from aredis.exceptions import (RedisError, ConnectionError,
7+
TimeoutError, ResponseError,
8+
WatchError, ExecAbortError,
9+
MovedError, AskError,
10+
TryAgainError, RedisClusterException,
11+
ClusterTransactionError)
1712
from aredis.utils import (dict_merge,
1813
clusterdown_wrapper)
1914

@@ -276,9 +271,9 @@ async def execute(self, raise_on_error=True):
276271
if self.scripts:
277272
await self.load_scripts()
278273
if self.transaction or self.explicit_transaction:
279-
execute = self._execute_transaction
274+
exec = self._execute_transaction
280275
else:
281-
execute = self._execute_pipeline
276+
exec = self._execute_pipeline
282277

283278
conn = self.connection
284279
if not conn:
@@ -288,7 +283,7 @@ async def execute(self, raise_on_error=True):
288283
self.connection = conn
289284

290285
try:
291-
return await execute(conn, stack, raise_on_error)
286+
return await exec(conn, stack, raise_on_error)
292287
except (ConnectionError, TimeoutError) as e:
293288
conn.disconnect()
294289
if not conn.retry_on_timeout and isinstance(e, TimeoutError):
@@ -303,7 +298,7 @@ async def execute(self, raise_on_error=True):
303298
"one or more keys")
304299
# otherwise, it's safe to retry since the transaction isn't
305300
# predicated on any state
306-
return await execute(conn, stack, raise_on_error)
301+
return await exec(conn, stack, raise_on_error)
307302
finally:
308303
await self.reset()
309304

@@ -325,32 +320,25 @@ class StrictPipeline(BasePipeline, StrictRedis):
325320

326321
class StrictClusterPipeline(StrictRedisCluster):
327322
def __init__(self, connection_pool, result_callbacks=None,
328-
response_callbacks=None, startup_nodes=None):
329-
"""
330-
"""
323+
response_callbacks=None, startup_nodes=None,
324+
transaction=False, watches=None):
331325
self.command_stack = []
332326
self.refresh_table_asap = False
333327
self.connection_pool = connection_pool
334328
self.result_callbacks = result_callbacks or self.__class__.RESULT_CALLBACKS.copy()
335329
self.startup_nodes = startup_nodes if startup_nodes else []
336330
self.nodes_flags = self.__class__.NODES_FLAGS.copy()
337331
self.response_callbacks = dict_merge(response_callbacks or self.__class__.RESPONSE_CALLBACKS.copy())
332+
self.transaction = transaction
333+
self.watches = watches or None
334+
self.watching = False
335+
self.explicit_transaction = False
338336

339337
def __repr__(self):
340338
"""
341339
"""
342340
return "{0}".format(type(self).__name__)
343341

344-
def __enter__(self):
345-
"""
346-
"""
347-
return self
348-
349-
def __exit__(self, exc_type, exc_value, traceback):
350-
"""
351-
"""
352-
self.reset()
353-
354342
def __del__(self):
355343
"""
356344
"""
@@ -361,6 +349,12 @@ def __len__(self):
361349
"""
362350
return len(self.command_stack)
363351

352+
async def __aenter__(self):
353+
return self
354+
355+
async def __aexit__(self, exc_type, exc_val, exc_tb):
356+
self.reset()
357+
364358
def _determine_slot(self, *args):
365359
"""
366360
figure out what slot based on command and args
@@ -407,9 +401,12 @@ async def execute(self, raise_on_error=True):
407401

408402
if not stack:
409403
return []
410-
404+
if self.transaction:
405+
execute = self.send_cluster_transaction
406+
else:
407+
execute = self.send_cluster_commands
411408
try:
412-
return await self.send_cluster_commands(stack, raise_on_error)
409+
return await execute(stack, raise_on_error)
413410
finally:
414411
self.reset()
415412

@@ -420,30 +417,61 @@ def reset(self):
420417
self.command_stack = []
421418

422419
self.scripts = set()
423-
424-
# TODO: Implement
425-
# make sure to reset the connection state in the event that we were
426-
# watching something
427-
# if self.watching and self.connection:
428-
# try:
429-
# # call this manually since our unwatch or
430-
# # immediate_execute_command methods can call reset()
431-
# self.connection.send_command('UNWATCH')
432-
# self.connection.read_response()
433-
# except ConnectionError:
434-
# # disconnect will also remove any previous WATCHes
435-
# self.connection.disconnect()
436-
420+
self.watches = []
437421
# clean up the other instance attributes
438422
self.watching = False
439423
self.explicit_transaction = False
440424

441-
# TODO: Implement
442-
# we can safely return the connection to the pool here since we're
443-
# sure we're no longer WATCHing anything
444-
# if self.connection:
445-
# self.connection_pool.release(self.connection)
446-
# self.connection = None
425+
@clusterdown_wrapper
426+
async def send_cluster_transaction(self, stack, raise_on_error=True):
427+
# the first time sending the commands we send all of the commands that were queued up.
428+
# if we have to run through it again, we only retry the commands that failed.
429+
attempt = sorted(stack, key=lambda x: x.position)
430+
node = {}
431+
432+
# as we move through each command that still needs to be processed,
433+
# we figure out the slot number that command maps to, then from the slot determine the node.
434+
for c in attempt:
435+
# refer to our internal node -> slot table that tells us where a given
436+
# command should route to.
437+
slot = self._determine_slot(*c.args)
438+
hashed_node = self.connection_pool.get_node_by_slot(slot)
439+
440+
# now that we know the name of the node ( it's just a string in the form of host:port )
441+
# we can build a list of commands for each node.
442+
if node.get('name') != hashed_node['name']:
443+
# raise error if commands in a transaction can not hash to same node
444+
if len(node) > 0:
445+
raise ClusterTransactionError("Keys in request don't hash to the same node")
446+
node = hashed_node
447+
conn = self.connection_pool.get_connection_by_node(node)
448+
if self.watches:
449+
await self._watch(node, conn, self.watches)
450+
node_commands = NodeCommands(self.parse_response, conn, in_transaction=True)
451+
node_commands.append(PipelineCommand(('MULTI',)))
452+
node_commands.extend(attempt)
453+
node_commands.append(PipelineCommand(('EXEC',)))
454+
self.explicit_transaction = True
455+
await node_commands.write()
456+
# todo: make this place clear
457+
try:
458+
await node_commands.read()
459+
except ExecAbortError:
460+
if self.explicit_transaction:
461+
await conn.send_command('DISCARD')
462+
await conn.read_response()
463+
464+
# If at least one watched key is modified before the EXEC command,
465+
# the whole transaction aborts,
466+
# and EXEC returns a Null reply to notify that the transaction failed.
467+
if node_commands.commands[-1].result is None:
468+
raise WatchError
469+
self.connection_pool.release(conn)
470+
if self.watching:
471+
self._unwatch(conn)
472+
if raise_on_error:
473+
self.raise_first_error(stack)
474+
447475

448476
@clusterdown_wrapper
449477
async def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True):
@@ -550,35 +578,38 @@ def _fail_on_redirect(self, allow_redirections):
550578
if not allow_redirections:
551579
raise RedisClusterException("ASK & MOVED redirection not allowed in this pipeline")
552580

553-
def multi(self):
581+
def _multi(self):
554582
"""
555583
"""
556584
raise RedisClusterException("method multi() is not implemented")
557585

558586
def immediate_execute_command(self, *args, **options):
559-
"""
560-
"""
561587
raise RedisClusterException("method immediate_execute_command() is not implemented")
562588

563-
def _execute_transaction(self, *args, **kwargs):
564-
"""
565-
"""
566-
raise RedisClusterException("method _execute_transaction() is not implemented")
567-
568589
def load_scripts(self):
569590
"""
570591
"""
571592
raise RedisClusterException("method load_scripts() is not implemented")
572593

573-
def watch(self, *names):
574-
"""
575-
"""
576-
raise RedisClusterException("method watch() is not implemented")
594+
async def _watch(self, node, conn, names):
595+
"Watches the values at keys ``names``"
596+
for name in names:
597+
slot = self._determine_slot('WATCH', name)
598+
dist_node = self.connection_pool.get_node_by_slot(slot)
599+
if node.get('name') != dist_node['name']:
600+
# raise error if commands in a transaction can not hash to same node
601+
if len(node) > 0:
602+
raise ClusterTransactionError("Keys in request don't hash to the same node")
603+
if self.explicit_transaction:
604+
raise RedisError('Cannot issue a WATCH after a MULTI')
605+
await conn.send_command('WATCH', *names)
606+
return await conn.read_response()
577607

578-
def unwatch(self):
579-
"""
580-
"""
581-
raise RedisClusterException("method unwatch() is not implemented")
608+
async def _unwatch(self, conn):
609+
"Unwatches all previously specified keys"
610+
await conn.send_command('UNWATCH')
611+
res = await conn.read_response()
612+
return self.watching and res or True
582613

583614
def script_load_for_pipeline(self, *args, **kwargs):
584615
"""
@@ -691,12 +722,16 @@ class NodeCommands(object):
691722
"""
692723
"""
693724

694-
def __init__(self, parse_response, connection):
725+
def __init__(self, parse_response, connection, in_transaction=False):
695726
"""
696727
"""
697728
self.parse_response = parse_response
698729
self.connection = connection
699730
self.commands = []
731+
self.in_transaction = in_transaction
732+
733+
def extend(self, c):
734+
self.commands.extend(c)
700735

701736
def append(self, c):
702737
"""
@@ -724,8 +759,6 @@ async def write(self):
724759
c.result = e
725760

726761
async def read(self):
727-
"""
728-
"""
729762
connection = self.connection
730763
for c in self.commands:
731764

@@ -741,7 +774,13 @@ async def read(self):
741774
# explicitly open the connection and all will be well.
742775
if c.result is None:
743776
try:
744-
c.result = await self.parse_response(connection, c.args[0], **c.options)
777+
if self.in_transaction:
778+
cmd = '_'
779+
else:
780+
cmd = c.args[0]
781+
c.result = await self.parse_response(connection, cmd, **c.options)
782+
except ExecAbortError:
783+
raise
745784
except (ConnectionError, TimeoutError) as e:
746785
for c in self.commands:
747786
c.result = e

0 commit comments

Comments
 (0)