Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--  synapse/replication/tcp/client.py  241
1 file changed, 234 insertions, 7 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ced69ee904..ce5d651cb8 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,22 +14,36 @@
 """A replication client for use by synapse workers.
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet.defer import Deferred
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.api.constants import EventTypes
+from synapse.federation import send_queue
+from synapse.federation.sender import FederationSender
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.streams import TypingStream
+from synapse.replication.tcp.streams import (
+    AccountDataStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+)
 from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import PersistedEventPosition, UserID
-from synapse.util.async_helpers import timeout_deferred
+from synapse.types import PersistedEventPosition, ReadReceipt, UserID
+from synapse.util.async_helpers import Linearizer, timeout_deferred
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -105,6 +119,15 @@ class ReplicationDataHandler:
         self._instance_name = hs.get_instance_name()
         self._typing_handler = hs.get_typing_handler()
 
+        # Whether this worker is responsible for sending out push notifications.
+        self._notify_pushers = hs.config.start_pushers
+        self._pusher_pool = hs.get_pusherpool()
+        self._presence_handler = hs.get_presence_handler()
+
+        self.send_handler = None  # type: Optional[FederationSenderHandler]
+        if hs.should_send_federation():
+            self.send_handler = FederationSenderHandler(hs)
+
         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
         self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
@@ -125,13 +148,58 @@
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
 
+        if self.send_handler:
+            await self.send_handler.process_replication_rows(stream_name, token, rows)
+
         if stream_name == TypingStream.NAME:
             self._typing_handler.process_replication_rows(token, rows)
             self.notifier.on_new_event(
                 "typing_key", token, rooms=[row.room_id for row in rows]
             )
-
-        if stream_name == EventsStream.NAME:
+        elif stream_name == PushRulesStream.NAME:
+            self.notifier.on_new_event(
+                "push_rules_key", token, users=[row.user_id for row in rows]
+            )
+        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+            self.notifier.on_new_event(
+                "account_data_key", token, users=[row.user_id for row in rows]
+            )
+        elif stream_name == ReceiptsStream.NAME:
+            self.notifier.on_new_event(
+                "receipt_key", token, rooms=[row.room_id for row in rows]
+            )
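+            # Rows in a batch share one token, hence the (token, token) range.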
+            await self._pusher_pool.on_new_receipts(
+                token, token, {row.room_id for row in rows}
+            )
+        elif stream_name == ToDeviceStream.NAME:
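+            # Only wake local users (entities starting with '@'); remote
+            # destinations are handled by the federation sender.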
+            entities = [row.entity for row in rows if row.entity.startswith("@")]
+            if entities:
+                self.notifier.on_new_event("to_device_key", token, users=entities)
+        elif stream_name == DeviceListsStream.NAME:
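+            # Find the rooms the user is in, so that anyone sharing a room
+            # with them can be notified of the device change.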
+            all_room_ids = set()  # type: Set[str]
+            for row in rows:
+                if row.entity.startswith("@"):
+                    room_ids = await self.store.get_rooms_for_user(row.entity)
+                    all_room_ids.update(room_ids)
+            self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+        elif stream_name == GroupServerStream.NAME:
+            self.notifier.on_new_event(
+                "groups_key", token, users=[row.user_id for row in rows]
+            )
+        elif stream_name == PushersStream.NAME:
+            for row in rows:
+                if row.deleted:
+                    self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                else:
+                    await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+        elif stream_name == PresenceStream.NAME:
+            await self._presence_handler.process_replication_rows(token, rows)
+        elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
             for row in rows:
@@ -190,7 +258,9 @@
         waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
 
     async def on_position(self, stream_name: str, instance_name: str, token: int):
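+        # A POSITION is treated like RDATA with no rows: handlers (such as
+        # the federation sender) still get to see the new token.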
-        self.store.process_replication_rows(stream_name, instance_name, token, [])
+        await self.on_rdata(stream_name, instance_name, token, [])
 
         # We poke the generic "replication" notifier to wake anything up that
         # may be streaming.
@@ -199,6 +269,11 @@
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
 
+        # Let's wake up the transaction queue for the server in case we have
+        # pending stuff to send to it.
+        if self.send_handler:
+            self.send_handler.wake_destination(server)
+
     async def wait_for_stream_position(
         self, instance_name: str, stream_name: str, position: int
     ):
@@ -235,3 +310,155 @@
             logger.info(
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )
+
+    def stop_pusher(self, user_id, app_id, pushkey):
+        if not self._notify_pushers:
+            return
+
+        key = "%s:%s" % (app_id, pushkey)
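+        # The pusher pool indexes each user's pushers by "app_id:pushkey".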
+        pushers_for_user = self._pusher_pool.pushers.get(user_id, {})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()
+
+    async def start_pusher(self, user_id, app_id, pushkey):
+        if not self._notify_pushers:
+            return
+
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+
+
+class FederationSenderHandler:
+    """Processes the fedration replication stream
+
+    This class is only instantiated on the worker responsible for sending outbound
+    federation transactions. It receives rows from the replication stream and forwards
+    the appropriate entries to the FederationSender class.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        assert hs.should_send_federation()
+
+        self.store = hs.get_datastore()
+        self._is_mine_id = hs.is_mine_id
+        self._hs = hs
+
+        # We need to make a temporary value to ensure that mypy picks up the
+        # right type. We know we should have a federation sender instance since
+        # `should_send_federation` is True.
+        sender = hs.get_federation_sender()
+        assert isinstance(sender, FederationSender)
+        self.federation_sender = sender
+
+        # Stores the latest position in the federation stream we've gotten up
+        # to. This is always set before we use it.
+        self.federation_position = None  # type: Optional[int]
+
+        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
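+        # Serialises persisting/acking the position so updates don't race.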
+
+    def wake_destination(self, server: str):
+        self.federation_sender.wake_destination(server)
+
+    async def process_replication_rows(self, stream_name, token, rows):
+        # The federation stream contains things that we want to send out, e.g.
+        # presence, typing, etc.
+        if stream_name == "federation":
+            send_queue.process_rows_for_federation(self.federation_sender, rows)
+            await self.update_token(token)
+
+        # ... and when new receipts happen
+        elif stream_name == ReceiptsStream.NAME:
+            await self._on_new_receipts(rows)
+
+        # ... as well as device updates and messages
+        elif stream_name == DeviceListsStream.NAME:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+        elif stream_name == ToDeviceStream.NAME:
+            # The to_device stream includes stuff to be pushed to both local
+            # clients and remote servers, so we ignore entities that start with
+            # '@' (since they'll be local users rather than destinations).
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+    async def _on_new_receipts(self, rows):
+        """
+        Args:
+            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
+                new receipts to be processed
+        """
+        for receipt in rows:
+            # we only want to send on receipts for our own users
+            if not self._is_mine_id(receipt.user_id):
+                continue
+            receipt_info = ReadReceipt(
+                receipt.room_id,
+                receipt.receipt_type,
+                receipt.user_id,
+                [receipt.event_id],
+                receipt.data,
+            )
+            await self.federation_sender.send_read_receipt(receipt_info)
+
+    async def update_token(self, token):
+        """Update the record of where we have processed to in the federation stream.
+
+        Called after we have processed an update received over replication. Sends
+        a FEDERATION_ACK back to the master, and stores the token that we have processed
+        in `federation_stream_position` so that we can restart where we left off.
+        """
+        self.federation_position = token
+
+        # We save and send the ACK to master asynchronously, so we don't block
+        # processing on persistence. We don't need to do this operation for
+        # every single RDATA we receive; we just need to do it periodically.
+
+        if self._fed_position_linearizer.is_queued(None):
+            # There is already a task queued up to save and send the token, so
+            # no need to queue up another task.
+            return
+
+        run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
+
+    async def _save_and_send_ack(self):
+        """Save the current federation position in the database and send an ACK
+        to master with where we're up to.
+        """
+        # We should only be calling this once we've got a token.
+        assert self.federation_position is not None
+
+        try:
+            # We linearize here to ensure we don't have races updating the token
+            #
+            # XXX this appears to be redundant, since the ReplicationCommandHandler
+            # has a linearizer which ensures that we only process one line of
+            # replication data at a time. Should we remove it, or is it doing useful
+            # service for robustness? Or could we replace it with an assertion that
+            # we're not being re-entered?
+
+            with (await self._fed_position_linearizer.queue(None)):
+                # We persist and ack the same position, so we take a copy of it
+                # here as otherwise it can get modified from underneath us.
+                current_position = self.federation_position
+
+                await self.store.update_federation_out_pos(
+                    "federation", current_position
+                )
+
+                # We ACK this token over replication so that the master can drop
+                # its in-memory queues.
+                self._hs.get_tcp_replication().send_federation_ack(current_position)
+        except Exception:
+            logger.exception("Error updating federation stream position")