summary refs log tree commit diff
path: root/synapse/app/federation_sender.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/federation_sender.py')
-rw-r--r--synapse/app/federation_sender.py110
1 files changed, 62 insertions, 48 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 76c4cc54d1..8994891aeb 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -31,9 +31,10 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
-from synapse.util.async import sleep
+from synapse.util.async import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.manhole import manhole
@@ -59,7 +60,23 @@ class FederationSenderSlaveStore(
     SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
     SlavedRegistrationStore, SlavedDeviceStore,
 ):
-    pass
+    def __init__(self, db_conn, hs):
+        super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+        self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+    def _get_federation_out_pos(self, db_conn):
+        sql = (
+            "SELECT stream_id FROM federation_stream_position"
+            " WHERE type = ?"
+        )
+        sql = self.database_engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, ("federation",))
+        rows = txn.fetchall()
+        txn.close()
+
+        return rows[0][0] if rows else -1
 
 
 class FederationSenderServer(HomeServer):
@@ -127,26 +144,29 @@ class FederationSenderServer(HomeServer):
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        send_handler = FederationSenderHandler(self)
-
-        send_handler.on_start()
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args.update((yield send_handler.stream_positions()))
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                yield send_handler.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return FederationSenderReplicationHandler(self)
+
+
+class FederationSenderReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
+        self.send_handler = FederationSenderHandler(hs)
+
+    def on_rdata(self, stream_name, token, rows):
+        super(FederationSenderReplicationHandler, self).on_rdata(
+            stream_name, token, rows
+        )
+        self.send_handler.process_replication_rows(stream_name, token, rows)
+        if stream_name == "federation":
+            self.send_federation_ack(token)
+
+    def get_streams_to_replicate(self):
+        args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
+        args.update(self.send_handler.stream_positions())
+        return args
 
 
 def start(config_options):
@@ -205,7 +225,6 @@ def start(config_options):
             reactor.run()
 
     def start():
-        ps.replicate()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()
 
@@ -233,6 +252,9 @@ class FederationSenderHandler(object):
         self.store = hs.get_datastore()
         self.federation_sender = hs.get_federation_sender()
 
+        self.federation_position = self.store.federation_out_pos_startup
+        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
         self._room_serials = {}
         self._room_typing = {}
 
@@ -243,25 +265,13 @@ class FederationSenderHandler(object):
             self.store.get_room_max_stream_ordering()
         )
 
-    @defer.inlineCallbacks
     def stream_positions(self):
-        stream_id = yield self.store.get_federation_out_pos("federation")
-        defer.returnValue({
-            "federation": stream_id,
-
-            # Ack stuff we've "processed", this should only be called from
-            # one process.
-            "federation_ack": stream_id,
-        })
+        return {"federation": self.federation_position}
 
-    @defer.inlineCallbacks
-    def process_replication(self, result):
+    def process_replication_rows(self, stream_name, token, rows):
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
-        fed_stream = result.get("federation")
-        if fed_stream:
-            latest_id = int(fed_stream["position"])
-
+        if stream_name == "federation":
             # The federation stream containis a bunch of different types of
             # rows that need to be handled differently. We parse the rows, put
             # them into the appropriate collection and then send them off.
@@ -272,8 +282,9 @@ class FederationSenderHandler(object):
             device_destinations = set()
 
             # Parse the rows in the stream
-            for row in fed_stream["rows"]:
-                position, typ, content_js = row
+            for row in rows:
+                typ = row.type
+                content_js = row.data
                 content = json.loads(content_js)
 
                 if typ == send_queue.PRESENCE_TYPE:
@@ -325,16 +336,19 @@ class FederationSenderHandler(object):
             for destination in device_destinations:
                 self.federation_sender.send_device_messages(destination)
 
-            # Record where we are in the stream.
-            yield self.store.update_federation_out_pos(
-                "federation", latest_id
-            )
+            self.update_token(token)
 
         # We also need to poke the federation sender when new events happen
-        event_stream = result.get("events")
-        if event_stream:
-            latest_pos = event_stream["position"]
-            self.federation_sender.notify_new_events(latest_pos)
+        elif stream_name == "events":
+            self.federation_sender.notify_new_events(token)
+
+    @defer.inlineCallbacks
+    def update_token(self, token):
+        self.federation_position = token
+        with (yield self._fed_position_linearizer.queue(None)):
+            yield self.store.update_federation_out_pos(
+                "federation", self.federation_position
+            )
 
 
 if __name__ == '__main__':