Skip to content

Commit 67c92a0

Browse files
committed
Add the option to set the client name
1 parent b46e671 commit 67c92a0

File tree

5 files changed

+16
-7
lines changed

5 files changed

+16
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ vagrant/.vagrant
1515
.vscode/
1616
*.iml
1717
.pytest_cache/
18+
*.so

aredis/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self, host='localhost', port=6379,
100100
ssl_cert_reqs=None, ssl_ca_certs=None,
101101
max_connections=None, retry_on_timeout=False,
102102
max_idle_time=0, idle_check_interval=1,
103+
client_name=None,
103104
loop=None, **kwargs):
104105
if not connection_pool:
105106
kwargs = {
@@ -113,6 +114,7 @@ def __init__(self, host='localhost', port=6379,
113114
'decode_responses': decode_responses,
114115
'max_idle_time': max_idle_time,
115116
'idle_check_interval': idle_check_interval,
117+
'client_name': client_name,
116118
'loop': loop
117119
}
118120
# based on input, setup appropriate connection args

aredis/connection.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ class BaseConnection:
367367
def __init__(self, retry_on_timeout=False, stream_timeout=None,
368368
parser_class=DefaultParser, reader_read_size=65535,
369369
encoding='utf-8', decode_responses=False,
370-
*, loop=None):
370+
*, client_name=None, loop=None):
371371
self._parser = parser_class(reader_read_size)
372372
self._stream_timeout = stream_timeout
373373
self._reader = None
@@ -381,6 +381,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None,
381381
self.encoding = encoding
382382
self.decode_responses = decode_responses
383383
self.loop = loop
384+
self.client_name = client_name
384385
# flag to show if a connection is waiting for response
385386
self.awaiting_response = False
386387
self.last_active_at = time.time()
@@ -444,6 +445,11 @@ async def on_connect(self):
444445
await self.send_command('SELECT', self.db)
445446
if nativestr(await self.read_response()) != 'OK':
446447
raise ConnectionError('Invalid Database')
448+
449+
if self.client_name is not None:
450+
await self.send_command('CLIENT SETNAME', self.client_name)
451+
if nativestr(await self.read_response()) != 'OK':
452+
raise ConnectionError('Failed to set client name: {}'.format(self.client_name))
447453
self.last_active_at = time.time()
448454

449455
async def read_response(self):
@@ -573,11 +579,11 @@ def __init__(self, host='127.0.0.1', port=6379, password=None,
573579
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
574580
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
575581
encoding='utf-8', decode_responses=False, socket_keepalive=None,
576-
socket_keepalive_options=None, *, loop=None):
582+
socket_keepalive_options=None, *, client_name=None, loop=None):
577583
super(Connection, self).__init__(retry_on_timeout, stream_timeout,
578584
parser_class, reader_read_size,
579585
encoding, decode_responses,
580-
loop=loop)
586+
client_name=client_name, loop=loop)
581587
self.host = host
582588
self.port = port
583589
self.password = password
@@ -626,11 +632,11 @@ class UnixDomainSocketConnection(BaseConnection):
626632
def __init__(self, path='', password=None,
627633
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
628634
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
629-
encoding='utf-8', decode_responses=False, *, loop=None):
635+
encoding='utf-8', decode_responses=False, *, client_name=None, loop=None):
630636
super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout,
631637
parser_class, reader_read_size,
632638
encoding, decode_responses,
633-
loop=loop)
639+
client_name=client_name, loop=loop)
634640
self.path = path
635641
self.db = db
636642
self.password = password

tests/client/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def skip_python_vsersion_lt(min_version):
3737

3838
@pytest.fixture()
3939
def r(event_loop):
40-
return aredis.StrictRedis(loop=event_loop)
40+
return aredis.StrictRedis(client_name='test', loop=event_loop)
4141

4242

4343
class AsyncMock(Mock):

tests/client/test_commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def test_client_list_after_client_setname(self, r):
6464
@skip_if_server_version_lt('2.6.9')
6565
@pytest.mark.asyncio(forbid_global_loop=True)
6666
async def test_client_getname(self, r):
67-
assert await r.client_getname() is None
67+
assert await r.client_getname() == 'test'
6868

6969
@skip_if_server_version_lt('2.6.9')
7070
@pytest.mark.asyncio(forbid_global_loop=True)

0 commit comments

Comments
 (0)