@@ -367,7 +367,7 @@ class BaseConnection:
367367 def __init__ (self , retry_on_timeout = False , stream_timeout = None ,
368368 parser_class = DefaultParser , reader_read_size = 65535 ,
369369 encoding = 'utf-8' , decode_responses = False ,
370- * , loop = None ):
370+ * , client_name = None , loop = None ):
371371 self ._parser = parser_class (reader_read_size )
372372 self ._stream_timeout = stream_timeout
373373 self ._reader = None
@@ -381,6 +381,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None,
381381 self .encoding = encoding
382382 self .decode_responses = decode_responses
383383 self .loop = loop
384+ self .client_name = client_name
384385 # flag to show if a connection is waiting for response
385386 self .awaiting_response = False
386387 self .last_active_at = time .time ()
@@ -444,6 +445,11 @@ async def on_connect(self):
444445 await self .send_command ('SELECT' , self .db )
445446 if nativestr (await self .read_response ()) != 'OK' :
446447 raise ConnectionError ('Invalid Database' )
448+
449+ if self .client_name is not None :
450+ await self .send_command ('CLIENT SETNAME' , self .client_name )
451+ if nativestr (await self .read_response ()) != 'OK' :
452+ raise ConnectionError ('Failed to set client name: {}' .format (self .client_name ))
447453 self .last_active_at = time .time ()
448454
449455 async def read_response (self ):
@@ -573,11 +579,11 @@ def __init__(self, host='127.0.0.1', port=6379, password=None,
573579 db = 0 , retry_on_timeout = False , stream_timeout = None , connect_timeout = None ,
574580 ssl_context = None , parser_class = DefaultParser , reader_read_size = 65535 ,
575581 encoding = 'utf-8' , decode_responses = False , socket_keepalive = None ,
576- socket_keepalive_options = None , * , loop = None ):
582+ socket_keepalive_options = None , * , client_name = None , loop = None ):
577583 super (Connection , self ).__init__ (retry_on_timeout , stream_timeout ,
578584 parser_class , reader_read_size ,
579585 encoding , decode_responses ,
580- loop = loop )
586+ client_name = client_name , loop = loop )
581587 self .host = host
582588 self .port = port
583589 self .password = password
@@ -626,11 +632,11 @@ class UnixDomainSocketConnection(BaseConnection):
626632 def __init__ (self , path = '' , password = None ,
627633 db = 0 , retry_on_timeout = False , stream_timeout = None , connect_timeout = None ,
628634 ssl_context = None , parser_class = DefaultParser , reader_read_size = 65535 ,
629- encoding = 'utf-8' , decode_responses = False , * , loop = None ):
635+ encoding = 'utf-8' , decode_responses = False , * , client_name = None , loop = None ):
630636 super (UnixDomainSocketConnection , self ).__init__ (retry_on_timeout , stream_timeout ,
631637 parser_class , reader_read_size ,
632638 encoding , decode_responses ,
633- loop = loop )
639+ client_name = client_name , loop = loop )
634640 self .path = path
635641 self .db = db
636642 self .password = password
0 commit comments