diff options
Diffstat (limited to 'synapse/app/federation_sender.py')
-rw-r--r-- | synapse/app/federation_sender.py | 110 |
1 files changed, 62 insertions, 48 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 76c4cc54d1..8994891aeb 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -31,9 +31,10 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.tcp.client import ReplicationClientHandler from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState -from synapse.util.async import sleep +from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.manhole import manhole @@ -59,7 +60,23 @@ class FederationSenderSlaveStore( SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedRegistrationStore, SlavedDeviceStore, ): - pass + def __init__(self, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) + + def _get_federation_out_pos(self, db_conn): + sql = ( + "SELECT stream_id FROM federation_stream_position" + " WHERE type = ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, ("federation",)) + rows = txn.fetchall() + txn.close() + + return rows[0][0] if rows else -1 class FederationSenderServer(HomeServer): @@ -127,26 +144,29 @@ class FederationSenderServer(HomeServer): else: logger.warn("Unrecognized listener type: %s", listener["type"]) - @defer.inlineCallbacks - def replicate(self): - http_client = self.get_simple_http_client() - store = self.get_datastore() - replication_url = self.config.worker_replication_url - send_handler = FederationSenderHandler(self) - - send_handler.on_start() - - while True: - try: - args = store.stream_positions() - args.update((yield send_handler.stream_positions())) - args["timeout"] = 30000 - result = yield http_client.get_json(replication_url, args=args) - yield store.process_replication(result) - yield send_handler.process_replication(result) - except: - logger.exception("Error replicating from %r", replication_url) - yield sleep(30) + self.get_tcp_replication().start_replication(self) + + def build_tcp_replication(self): + return FederationSenderReplicationHandler(self) + + +class FederationSenderReplicationHandler(ReplicationClientHandler): + def __init__(self, hs): + super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) + self.send_handler = FederationSenderHandler(hs) + + def on_rdata(self, stream_name, token, rows): + super(FederationSenderReplicationHandler, self).on_rdata( + stream_name, token, rows + ) + self.send_handler.process_replication_rows(stream_name, token, rows) + if stream_name == "federation": + self.send_federation_ack(token) + + def get_streams_to_replicate(self): + args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args.update(self.send_handler.stream_positions()) + return args def start(config_options): @@ -205,7 +225,6 @@ def start(config_options): reactor.run() def start(): - ps.replicate() ps.get_datastore().start_profiling() ps.get_state_handler().start_caching() @@ -233,6 +252,9 @@ class FederationSenderHandler(object): self.store = hs.get_datastore() self.federation_sender = hs.get_federation_sender() + self.federation_position = self.store.federation_out_pos_startup + self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") + self._room_serials = {} self._room_typing = {} @@ -243,25 +265,13 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) - @defer.inlineCallbacks def stream_positions(self): - stream_id = yield self.store.get_federation_out_pos("federation") - defer.returnValue({ - "federation": stream_id, - - # Ack stuff we've "processed", this should only be called from - # one process. - "federation_ack": stream_id, - }) + return {"federation": self.federation_position} - @defer.inlineCallbacks - def process_replication(self, result): + def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. - fed_stream = result.get("federation") - if fed_stream: - latest_id = int(fed_stream["position"]) - + if stream_name == "federation": # The federation stream containis a bunch of different types of # rows that need to be handled differently. We parse the rows, put # them into the appropriate collection and then send them off. @@ -272,8 +282,9 @@ class FederationSenderHandler(object): device_destinations = set() # Parse the rows in the stream - for row in fed_stream["rows"]: - position, typ, content_js = row + for row in rows: + typ = row.type + content_js = row.data content = json.loads(content_js) if typ == send_queue.PRESENCE_TYPE: @@ -325,16 +336,19 @@ class FederationSenderHandler(object): for destination in device_destinations: self.federation_sender.send_device_messages(destination) - # Record where we are in the stream. - yield self.store.update_federation_out_pos( - "federation", latest_id - ) + self.update_token(token) # We also need to poke the federation sender when new events happen - event_stream = result.get("events") - if event_stream: - latest_pos = event_stream["position"] - self.federation_sender.notify_new_events(latest_pos) + elif stream_name == "events": + self.federation_sender.notify_new_events(token) + + @defer.inlineCallbacks + def update_token(self, token): + self.federation_position = token + with (yield self._fed_position_linearizer.queue(None)): + yield self.store.update_federation_out_pos( + "federation", self.federation_position + ) if __name__ == '__main__': |