summary refs log tree commit diff
path: root/synapse/replication/tcp/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r--synapse/replication/tcp/client.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
 """A replication client for use by synapse workers.
 """
 
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from .commands import (
-    FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+    FederationAckCommand,
+    InvalidateCacheCommand,
+    RemovePusherCommand,
     UserIpCommand,
+    UserSyncCommand,
 )
 from .protocol import ClientReplicationStreamProtocol
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.server_name = hs.config.server_name
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
-        reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
 
     def startedConnecting(self, connector):
         logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
         factory = ReplicationClientFactory(hs, client_name, self)
         host = hs.config.worker_replication_host
         port = hs.config.worker_replication_port
-        reactor.connectTCP(host, port, factory)
+        hs.get_reactor().connectTCP(host, port, factory)
 
     def on_rdata(self, stream_name, token, rows):
         """Called when we get new replication data. By default this just pokes
@@ -104,7 +107,7 @@ class ReplicationClientHandler(object):
         Can be overriden in subclasses to handle more.
         """
         logger.info("Received rdata %s -> %s", stream_name, token)
-        self.store.process_replication_rows(stream_name, token, rows)
+        return self.store.process_replication_rows(stream_name, token, rows)
 
     def on_position(self, stream_name, token):
         """Called when we get new position data. By default this just pokes
@@ -112,7 +115,7 @@ class ReplicationClientHandler(object):
 
         Can be overriden in subclasses to handle more.
         """
-        self.store.process_replication_rows(stream_name, token, [])
+        return self.store.process_replication_rows(stream_name, token, [])
 
     def on_sync(self, data):
         """When we received a SYNC we wake up any deferreds that were waiting
@@ -189,7 +192,7 @@ class ReplicationClientHandler(object):
         """Returns a deferred that is resolved when we receive a SYNC command
         with given data.
 
-        Used by tests.
+        [Not currently] used by tests.
         """
         return self.awaiting_syncs.setdefault(data, defer.Deferred())