summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py44
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/federation.py8
-rw-r--r--synapse/replication/http/login.py2
-rw-r--r--synapse/replication/http/membership.py6
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/http/send_event.py2
-rw-r--r--synapse/replication/slave/storage/_base.py4
-rw-r--r--synapse/replication/slave/storage/account_data.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py2
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py2
-rw-r--r--synapse/replication/slave/storage/devices.py2
-rw-r--r--synapse/replication/slave/storage/events.py2
-rw-r--r--synapse/replication/slave/storage/filtering.py2
-rw-r--r--synapse/replication/slave/storage/groups.py2
-rw-r--r--synapse/replication/slave/storage/presence.py2
-rw-r--r--synapse/replication/slave/storage/pushers.py2
-rw-r--r--synapse/replication/slave/storage/receipts.py2
-rw-r--r--synapse/replication/slave/storage/room.py2
-rw-r--r--synapse/replication/tcp/client.py10
-rw-r--r--synapse/replication/tcp/handler.py6
-rw-r--r--synapse/replication/tcp/protocol.py10
-rw-r--r--synapse/replication/tcp/redis.py40
-rw-r--r--synapse/replication/tcp/streams/_base.py2
24 files changed, 111 insertions, 51 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index ba16f22c91..64edadb624 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -20,20 +20,30 @@ import urllib
 from inspect import signature
 from typing import Dict, List, Tuple
 
-from synapse.api.errors import (
-    CodeMessageException,
-    HttpResponseException,
-    RequestSendFailed,
-    SynapseError,
-)
+from prometheus_client import Counter, Gauge
+
+from synapse.api.errors import HttpResponseException, SynapseError
+from synapse.http import RequestTimedOutError
 from synapse.logging.opentracing import inject_active_span_byte_dict, trace
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
 
+_pending_outgoing_requests = Gauge(
+    "synapse_pending_outgoing_replication_requests",
+    "Number of active outgoing replication requests, by replication method name",
+    ["name"],
+)
+
+_outgoing_request_counter = Counter(
+    "synapse_outgoing_replication_requests",
+    "Number of outgoing replication requests, by replication method name and result",
+    ["name", "code"],
+)
+
 
-class ReplicationEndpoint:
+class ReplicationEndpoint(metaclass=abc.ABCMeta):
     """Helper base class for defining new replication HTTP endpoints.
 
     This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -72,8 +82,6 @@ class ReplicationEndpoint:
             is received.
     """
 
-    __metaclass__ = abc.ABCMeta
-
     NAME = abc.abstractproperty()  # type: str  # type: ignore
     PATH_ARGS = abc.abstractproperty()  # type: Tuple[str, ...]  # type: ignore
     METHOD = "POST"
@@ -140,7 +148,10 @@ class ReplicationEndpoint:
 
         instance_map = hs.config.worker.instance_map
 
+        outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
+
         @trace(opname="outgoing_replication_request")
+        @outgoing_gauge.track_inprogress()
         async def send_request(instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
@@ -195,23 +206,26 @@ class ReplicationEndpoint:
                     try:
                         result = await request_func(uri, data, headers=headers)
                         break
-                    except CodeMessageException as e:
-                        if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
+                    except RequestTimedOutError:
+                        if not cls.RETRY_ON_TIMEOUT:
                             raise
 
-                    logger.warning("%s request timed out", cls.NAME)
+                    logger.warning("%s request timed out; retrying", cls.NAME)
 
                     # If we timed out we probably don't need to worry about backing
                     # off too much, but lets just wait a little anyway.
                     await clock.sleep(1)
             except HttpResponseException as e:
                 # We convert to SynapseError as we know that it was a SynapseError
-                # on the master process that we should send to the client. (And
+                # on the main process that we should send to the client. (And
                 # importantly, not stack traces everywhere)
+                _outgoing_request_counter.labels(cls.NAME, e.code).inc()
                 raise e.to_synapse_error()
-            except RequestSendFailed as e:
-                raise SynapseError(502, "Failed to talk to master") from e
+            except Exception as e:
+                _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
+                raise SynapseError(502, "Failed to talk to main process") from e
 
+            _outgoing_request_counter.labels(cls.NAME, 200).inc()
             return result
 
         return send_request
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 20f3ba76c0..807b85d2e1 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
     CACHE = False
 
     def __init__(self, hs):
-        super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.device_list_updater = hs.get_device_handler().device_list_updater
         self.store = hs.get_datastore()
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5c8be747e1..5393b9a9e7 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     PATH_ARGS = ()
 
     def __init__(self, hs):
-        super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
@@ -150,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("edu_type",)
 
     def __init__(self, hs):
-        super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -193,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
     CACHE = False
 
     def __init__(self, hs):
-        super(ReplicationGetQueryRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -236,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id",)
 
     def __init__(self, hs):
-        super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
 
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index fb326bb869..4c81e2d784 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(RegisterDeviceReplicationServlet, self).__init__(hs)
+        super().__init__(hs)
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 08095fdf7d..30680baee8 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("room_id", "user_id")
 
     def __init__(self, hs):
-        super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.federation_handler = hs.get_handlers().federation_handler
         self.store = hs.get_datastore()
@@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("invite_event_id",)
 
     def __init__(self, hs: "HomeServer"):
-        super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
     CACHE = False  # No point caching as should return instantly.
 
     def __init__(self, hs):
-        super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.registeration_handler = hs.get_registration_handler()
         self.store = hs.get_datastore()
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index a02b27474d..7b12ec9060 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(ReplicationRegisterServlet, self).__init__(hs)
+        super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
 
@@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
     PATH_ARGS = ("user_id",)
 
     def __init__(self, hs):
-        super(ReplicationPostRegisterActionsServlet, self).__init__(hs)
+        super().__init__(hs)
         self.store = hs.get_datastore()
         self.registration_handler = hs.get_registration_handler()
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index f13d452426..9a3a694d5d 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
     PATH_ARGS = ("event_id",)
 
     def __init__(self, hs):
-        super(ReplicationSendEventRestServlet, self).__init__(hs)
+        super().__init__(hs)
 
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 60f2e1245f..d0089fe06c 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -26,16 +26,18 @@ logger = logging.getLogger(__name__)
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
+                stream_name="caches",
                 instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bb66ba9b80..4268565fc8 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
             ],
         )
 
-        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index a6fdedde63..1f8dafe7ea 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -22,7 +22,7 @@ from ._base import BaseSlavedStore
 
 class SlavedClientIpStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 533d927701..5b045bed02 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_inbox", "stream_id"
         )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 3b788c9625..e0d86240dd 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index da1cc836cf..fbffe6d85c 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -56,7 +56,7 @@ class SlavedEventStore(
     BaseSlavedStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedEventStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 2562b6fc38..6a23252861 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -21,7 +21,7 @@ from ._base import BaseSlavedStore
 
 class SlavedFilteringStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 567b4a5cc1..30955bcbfe 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 025f6f6be8..55620c03d8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class SlavedPresenceStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 9da218bfe8..c418730ba8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 5c2986e050..6195917376 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 80ae803ad9..109ac6bea1 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e82b9e386f..e165429cad 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import UserID
+from synapse.types import PersistedEventPosition, UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -151,8 +151,12 @@ class ReplicationDataHandler:
                 extra_users = ()  # type: Tuple[UserID, ...]
                 if event.type == EventTypes.Member:
                     extra_users = (UserID.from_string(event.state_key),)
-                max_token = self.store.get_room_max_stream_ordering()
-                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+                max_token = self.store.get_room_max_token()
+                event_pos = PersistedEventPosition(instance_name, token)
+                self.notifier.on_new_room_event(
+                    event, event_pos, max_token, extra_users
+                )
 
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..e92da7b263 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -251,10 +251,9 @@ class ReplicationCommandHandler:
         using TCP.
         """
         if hs.config.redis.redis_enabled:
-            import txredisapi
-
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
+                lazyConnection,
             )
 
             logger.info(
@@ -271,7 +270,8 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = txredisapi.lazyConnection(
+            outbound_redis_connection = lazyConnection(
+                reactor=hs.get_reactor(),
                 host=hs.config.redis_host,
                 port=hs.config.redis_port,
                 password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 from prometheus_client import Counter
 
+from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
-        self.time_we_closed = None  # When we requested the connection be closed
+        # When we requested the connection be closed
+        self.time_we_closed = None  # type: Optional[int]
 
-        self.received_ping = False  # Have we reecived a ping from the other side
+        self.received_ping = False  # Have we received a ping from the other side
 
         self.state = ConnectionStates.CONNECTING
 
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
-        self._send_ping_loop = None
+        self._send_ping_loop = None  # type: Optional[task.LoopingCall]
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import txredisapi
 
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         p.password = self.password
 
         return p
+
+
+def lazyConnection(
+    reactor,
+    host: str = "localhost",
+    port: int = 6379,
+    dbid: Optional[int] = None,
+    reconnect: bool = True,
+    charset: str = "utf-8",
+    password: Optional[str] = None,
+    connectTimeout: Optional[int] = None,
+    replyTimeout: Optional[int] = None,
+    convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+    reactor.
+    """
+
+    isLazy = True
+    poolsize = 1
+
+    uuid = "%s:%d" % (host, port)
+    factory = txredisapi.RedisFactory(
+        uuid,
+        dbid,
+        poolsize,
+        isLazy,
+        txredisapi.ConnectionHandler,
+        charset,
+        password,
+        replyTimeout,
+        convertNumbers,
+    )
+    factory.continueTrying = reconnect
+    for x in range(poolsize):
+        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    return factory.handler
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 1f609f158c..54dccd15a6 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -345,7 +345,7 @@ class PushRulesStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-        super(PushRulesStream, self).__init__(
+        super().__init__(
             hs.get_instance_name(),
             self._current_token,
             self.store.get_all_push_rule_updates,