diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 76c4cc54d1..8994891aeb 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -31,9 +31,10 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
-from synapse.util.async import sleep
+from synapse.util.async import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.manhole import manhole
@@ -59,7 +60,23 @@ class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedDeviceStore,
):
- pass
+ def __init__(self, db_conn, hs):
+ super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+ self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+ def _get_federation_out_pos(self, db_conn):
+ sql = (
+ "SELECT stream_id FROM federation_stream_position"
+ " WHERE type = ?"
+ )
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, ("federation",))
+ rows = txn.fetchall()
+ txn.close()
+
+ return rows[0][0] if rows else -1
class FederationSenderServer(HomeServer):
@@ -127,26 +144,29 @@ class FederationSenderServer(HomeServer):
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
- @defer.inlineCallbacks
- def replicate(self):
- http_client = self.get_simple_http_client()
- store = self.get_datastore()
- replication_url = self.config.worker_replication_url
- send_handler = FederationSenderHandler(self)
-
- send_handler.on_start()
-
- while True:
- try:
- args = store.stream_positions()
- args.update((yield send_handler.stream_positions()))
- args["timeout"] = 30000
- result = yield http_client.get_json(replication_url, args=args)
- yield store.process_replication(result)
- yield send_handler.process_replication(result)
- except:
- logger.exception("Error replicating from %r", replication_url)
- yield sleep(30)
+ self.get_tcp_replication().start_replication(self)
+
+ def build_tcp_replication(self):
+ return FederationSenderReplicationHandler(self)
+
+
+class FederationSenderReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
+ self.send_handler = FederationSenderHandler(hs)
+
+ def on_rdata(self, stream_name, token, rows):
+ super(FederationSenderReplicationHandler, self).on_rdata(
+ stream_name, token, rows
+ )
+ self.send_handler.process_replication_rows(stream_name, token, rows)
+ if stream_name == "federation":
+ self.send_federation_ack(token)
+
+ def get_streams_to_replicate(self):
+ args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
+ args.update(self.send_handler.stream_positions())
+ return args
def start(config_options):
@@ -205,7 +225,6 @@ def start(config_options):
reactor.run()
def start():
- ps.replicate()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
@@ -233,6 +252,9 @@ class FederationSenderHandler(object):
self.store = hs.get_datastore()
self.federation_sender = hs.get_federation_sender()
+ self.federation_position = self.store.federation_out_pos_startup
+ self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
self._room_serials = {}
self._room_typing = {}
@@ -243,25 +265,13 @@ class FederationSenderHandler(object):
self.store.get_room_max_stream_ordering()
)
- @defer.inlineCallbacks
def stream_positions(self):
- stream_id = yield self.store.get_federation_out_pos("federation")
- defer.returnValue({
- "federation": stream_id,
-
- # Ack stuff we've "processed", this should only be called from
- # one process.
- "federation_ack": stream_id,
- })
+ return {"federation": self.federation_position}
- @defer.inlineCallbacks
- def process_replication(self, result):
+ def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
- fed_stream = result.get("federation")
- if fed_stream:
- latest_id = int(fed_stream["position"])
-
+ if stream_name == "federation":
# The federation stream containis a bunch of different types of
# rows that need to be handled differently. We parse the rows, put
# them into the appropriate collection and then send them off.
@@ -272,8 +282,9 @@ class FederationSenderHandler(object):
device_destinations = set()
# Parse the rows in the stream
- for row in fed_stream["rows"]:
- position, typ, content_js = row
+ for row in rows:
+ typ = row.type
+ content_js = row.data
content = json.loads(content_js)
if typ == send_queue.PRESENCE_TYPE:
@@ -325,16 +336,19 @@ class FederationSenderHandler(object):
for destination in device_destinations:
self.federation_sender.send_device_messages(destination)
- # Record where we are in the stream.
- yield self.store.update_federation_out_pos(
- "federation", latest_id
- )
+ self.update_token(token)
# We also need to poke the federation sender when new events happen
- event_stream = result.get("events")
- if event_stream:
- latest_pos = event_stream["position"]
- self.federation_sender.notify_new_events(latest_pos)
+ elif stream_name == "events":
+ self.federation_sender.notify_new_events(token)
+
+ @defer.inlineCallbacks
+ def update_token(self, token):
+ self.federation_position = token
+ with (yield self._fed_position_linearizer.queue(None)):
+ yield self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
if __name__ == '__main__':
|