summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/12042.misc1
-rw-r--r--stubs/txredisapi.pyi9
-rw-r--r--synapse/replication/tcp/external_cache.py4
-rw-r--r--synapse/replication/tcp/redis.py6
-rw-r--r--synapse/server.py4
5 files changed, 14 insertions, 10 deletions
diff --git a/changelog.d/12042.misc b/changelog.d/12042.misc
new file mode 100644
index 0000000000..6ecdc96021
--- /dev/null
+++ b/changelog.d/12042.misc
@@ -0,0 +1 @@
+Correct type hints for txredis.
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 429234d7ae..2d8ca018fb 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -20,7 +20,7 @@ from twisted.internet import protocol
 from twisted.internet.defer import Deferred
 
 class RedisProtocol(protocol.Protocol):
-    def publish(self, channel: str, message: bytes): ...
+    def publish(self, channel: str, message: bytes) -> "Deferred[None]": ...
     def ping(self) -> "Deferred[None]": ...
     def set(
         self,
@@ -52,11 +52,14 @@ def lazyConnection(
     convertNumbers: bool = ...,
 ) -> RedisProtocol: ...
 
-class ConnectionHandler: ...
+# ConnectionHandler doesn't actually inherit from RedisProtocol, but it proxies
+# most methods to it via ConnectionHandler.__getattr__.
+class ConnectionHandler(RedisProtocol):
+    def disconnect(self) -> "Deferred[None]": ...
 
 class RedisFactory(protocol.ReconnectingClientFactory):
     continueTrying: bool
-    handler: RedisProtocol
+    handler: ConnectionHandler
     pool: List[RedisProtocol]
     replyTimeout: Optional[int]
     def __init__(
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index aaf91e5e02..bf7d017968 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -21,7 +21,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.util import json_decoder, json_encoder
 
 if TYPE_CHECKING:
-    from txredisapi import RedisProtocol
+    from txredisapi import ConnectionHandler
 
     from synapse.server import HomeServer
 
@@ -63,7 +63,7 @@ class ExternalCache:
     def __init__(self, hs: "HomeServer"):
         if hs.config.redis.redis_enabled:
             self._redis_connection: Optional[
-                "RedisProtocol"
+                "ConnectionHandler"
             ] = hs.get_outbound_redis_connection()
         else:
             self._redis_connection = None
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 3170f7c59b..b84e572da1 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -93,7 +93,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
 
     synapse_handler: "ReplicationCommandHandler"
     synapse_stream_name: str
-    synapse_outbound_redis_connection: txredisapi.RedisProtocol
+    synapse_outbound_redis_connection: txredisapi.ConnectionHandler
 
     def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
@@ -313,7 +313,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     protocol = RedisSubscriber
 
     def __init__(
-        self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
+        self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
     ):
 
         super().__init__(
@@ -353,7 +353,7 @@ def lazyConnection(
     reconnect: bool = True,
     password: Optional[str] = None,
     replyTimeout: int = 30,
-) -> txredisapi.RedisProtocol:
+) -> txredisapi.ConnectionHandler:
     """Creates a connection to Redis that is lazily set up and reconnects if the
     connections is lost.
     """
diff --git a/synapse/server.py b/synapse/server.py
index b5e2a319bc..46a64418ea 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -145,7 +145,7 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from txredisapi import RedisProtocol
+    from txredisapi import ConnectionHandler
 
     from synapse.handlers.oidc import OidcHandler
     from synapse.handlers.saml import SamlHandler
@@ -807,7 +807,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return AccountHandler(self)
 
     @cache_in_self
-    def get_outbound_redis_connection(self) -> "RedisProtocol":
+    def get_outbound_redis_connection(self) -> "ConnectionHandler":
         """
         The Redis connection used for replication.