summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/__init__.py1
-rw-r--r--synapse/replication/tcp/client.py233
-rw-r--r--synapse/replication/tcp/commands.py1
-rw-r--r--synapse/replication/tcp/external_cache.py1
-rw-r--r--synapse/replication/tcp/handler.py1
-rw-r--r--synapse/replication/tcp/protocol.py1
-rw-r--r--synapse/replication/tcp/redis.py1
-rw-r--r--synapse/replication/tcp/resource.py1
-rw-r--r--synapse/replication/tcp/streams/__init__.py4
-rw-r--r--synapse/replication/tcp/streams/_base.py25
-rw-r--r--synapse/replication/tcp/streams/events.py1
-rw-r--r--synapse/replication/tcp/streams/federation.py1
12 files changed, 252 insertions, 19 deletions
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 1b8718b11d..1fa60af8e6 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3455839d67..4f3c6a18b6 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,22 +14,35 @@
 """A replication client for use by synapse workers.
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from twisted.internet.defer import Deferred
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.api.constants import EventTypes
+from synapse.federation import send_queue
+from synapse.federation.sender import FederationSender
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.streams import TypingStream
+from synapse.replication.tcp.streams import (
+    AccountDataStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+)
 from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import PersistedEventPosition, UserID
-from synapse.util.async_helpers import timeout_deferred
+from synapse.types import PersistedEventPosition, ReadReceipt, UserID
+from synapse.util.async_helpers import Linearizer, timeout_deferred
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -106,6 +118,14 @@ class ReplicationDataHandler:
         self._instance_name = hs.get_instance_name()
         self._typing_handler = hs.get_typing_handler()
 
+        self._notify_pushers = hs.config.start_pushers
+        self._pusher_pool = hs.get_pusherpool()
+        self._presence_handler = hs.get_presence_handler()
+
+        self.send_handler = None  # type: Optional[FederationSenderHandler]
+        if hs.should_send_federation():
+            self.send_handler = FederationSenderHandler(hs)
+
         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
         self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
@@ -126,13 +146,51 @@ class ReplicationDataHandler:
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
 
+        if self.send_handler:
+            await self.send_handler.process_replication_rows(stream_name, token, rows)
+
         if stream_name == TypingStream.NAME:
             self._typing_handler.process_replication_rows(token, rows)
             self.notifier.on_new_event(
                 "typing_key", token, rooms=[row.room_id for row in rows]
             )
-
-        if stream_name == EventsStream.NAME:
+        elif stream_name == PushRulesStream.NAME:
+            self.notifier.on_new_event(
+                "push_rules_key", token, users=[row.user_id for row in rows]
+            )
+        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+            self.notifier.on_new_event(
+                "account_data_key", token, users=[row.user_id for row in rows]
+            )
+        elif stream_name == ReceiptsStream.NAME:
+            self.notifier.on_new_event(
+                "receipt_key", token, rooms=[row.room_id for row in rows]
+            )
+            await self._pusher_pool.on_new_receipts(
+                token, token, {row.room_id for row in rows}
+            )
+        elif stream_name == ToDeviceStream.NAME:
+            entities = [row.entity for row in rows if row.entity.startswith("@")]
+            if entities:
+                self.notifier.on_new_event("to_device_key", token, users=entities)
+        elif stream_name == DeviceListsStream.NAME:
+            all_room_ids = set()  # type: Set[str]
+            for row in rows:
+                if row.entity.startswith("@"):
+                    room_ids = await self.store.get_rooms_for_user(row.entity)
+                    all_room_ids.update(room_ids)
+            self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+        elif stream_name == GroupServerStream.NAME:
+            self.notifier.on_new_event(
+                "groups_key", token, users=[row.user_id for row in rows]
+            )
+        elif stream_name == PushersStream.NAME:
+            for row in rows:
+                if row.deleted:
+                    self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                else:
+                    await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+        elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
             for row in rows:
@@ -160,6 +218,10 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+        await self._presence_handler.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
@@ -191,7 +253,7 @@ class ReplicationDataHandler:
         waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
 
     async def on_position(self, stream_name: str, instance_name: str, token: int):
-        self.store.process_replication_rows(stream_name, instance_name, token, [])
+        await self.on_rdata(stream_name, instance_name, token, [])
 
         # We poke the generic "replication" notifier to wake anything up that
         # may be streaming.
@@ -200,6 +262,11 @@ class ReplicationDataHandler:
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
 
+        # Let's wake up the transaction queue for the server in case we have
+        # pending stuff to send to it.
+        if self.send_handler:
+            self.send_handler.wake_destination(server)
+
     async def wait_for_stream_position(
         self, instance_name: str, stream_name: str, position: int
     ):
@@ -236,3 +303,153 @@ class ReplicationDataHandler:
             logger.info(
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )
+
+    def stop_pusher(self, user_id, app_id, pushkey):
+        if not self._notify_pushers:
+            return
+
+        key = "%s:%s" % (app_id, pushkey)
+        pushers_for_user = self._pusher_pool.pushers.get(user_id, {})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()
+
+    async def start_pusher(self, user_id, app_id, pushkey):
+        if not self._notify_pushers:
+            return
+
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+
+
+class FederationSenderHandler:
+    """Processes the fedration replication stream
+
+    This class is only instantiate on the worker responsible for sending outbound
+    federation transactions. It receives rows from the replication stream and forwards
+    the appropriate entries to the FederationSender class.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        assert hs.should_send_federation()
+
+        self.store = hs.get_datastore()
+        self._is_mine_id = hs.is_mine_id
+        self._hs = hs
+
+        # We need to make a temporary value to ensure that mypy picks up the
+        # right type. We know we should have a federation sender instance since
+        # `should_send_federation` is True.
+        sender = hs.get_federation_sender()
+        assert isinstance(sender, FederationSender)
+        self.federation_sender = sender
+
+        # Stores the latest position in the federation stream we've gotten up
+        # to. This is always set before we use it.
+        self.federation_position = None  # type: Optional[int]
+
+        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+    def wake_destination(self, server: str):
+        self.federation_sender.wake_destination(server)
+
+    async def process_replication_rows(self, stream_name, token, rows):
+        # The federation stream contains things that we want to send out, e.g.
+        # presence, typing, etc.
+        if stream_name == "federation":
+            send_queue.process_rows_for_federation(self.federation_sender, rows)
+            await self.update_token(token)
+
+        # ... and when new receipts happen
+        elif stream_name == ReceiptsStream.NAME:
+            await self._on_new_receipts(rows)
+
+        # ... as well as device updates and messages
+        elif stream_name == DeviceListsStream.NAME:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+        elif stream_name == ToDeviceStream.NAME:
+            # The to_device stream includes stuff to be pushed to both local
+            # clients and remote servers, so we ignore entities that start with
+            # '@' (since they'll be local users rather than destinations).
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+    async def _on_new_receipts(self, rows):
+        """
+        Args:
+            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
+                new receipts to be processed
+        """
+        for receipt in rows:
+            # we only want to send on receipts for our own users
+            if not self._is_mine_id(receipt.user_id):
+                continue
+            receipt_info = ReadReceipt(
+                receipt.room_id,
+                receipt.receipt_type,
+                receipt.user_id,
+                [receipt.event_id],
+                receipt.data,
+            )
+            await self.federation_sender.send_read_receipt(receipt_info)
+
+    async def update_token(self, token):
+        """Update the record of where we have processed to in the federation stream.
+
+        Called after we have processed a an update received over replication. Sends
+        a FEDERATION_ACK back to the master, and stores the token that we have processed
+         in `federation_stream_position` so that we can restart where we left off.
+        """
+        self.federation_position = token
+
+        # We save and send the ACK to master asynchronously, so we don't block
+        # processing on persistence. We don't need to do this operation for
+        # every single RDATA we receive, we just need to do it periodically.
+
+        if self._fed_position_linearizer.is_queued(None):
+            # There is already a task queued up to save and send the token, so
+            # no need to queue up another task.
+            return
+
+        run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
+
+    async def _save_and_send_ack(self):
+        """Save the current federation position in the database and send an ACK
+        to master with where we're up to.
+        """
+        # We should only be calling this once we've got a token.
+        assert self.federation_position is not None
+
+        try:
+            # We linearize here to ensure we don't have races updating the token
+            #
+            # XXX this appears to be redundant, since the ReplicationCommandHandler
+            # has a linearizer which ensures that we only process one line of
+            # replication data at a time. Should we remove it, or is it doing useful
+            # service for robustness? Or could we replace it with an assertion that
+            # we're not being re-entered?
+
+            with (await self._fed_position_linearizer.queue(None)):
+                # We persist and ack the same position, so we take a copy of it
+                # here as otherwise it can get modified from underneath us.
+                current_position = self.federation_position
+
+                await self.store.update_federation_out_pos(
+                    "federation", current_position
+                )
+
+                # We ACK this token over replication so that the master can drop
+                # its in memory queues
+                self._hs.get_tcp_replication().send_federation_ack(current_position)
+        except Exception:
+            logger.exception("Error updating federation stream position")
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8abed1f52d..505d450e19 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index d89a36f25a..1a3b051e3c 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a8894beadf..2ce1b9f222 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2020 The Matrix.org Foundation C.I.C.
 #
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index ba753318bd..2df028f315 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 98bdeb0ec6..6a2c2655e4 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 2018f9f29e..bd47d84258 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index d1a61c3314..4c0023c68a 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2019 New Vector Ltd
 #
@@ -31,6 +30,7 @@ from synapse.replication.tcp.streams._base import (
     CachesStream,
     DeviceListsStream,
     GroupServerStream,
+    PresenceFederationStream,
     PresenceStream,
     PublicRoomsStream,
     PushersStream,
@@ -51,6 +51,7 @@ STREAMS_MAP = {
         EventsStream,
         BackfillStream,
         PresenceStream,
+        PresenceFederationStream,
         TypingStream,
         ReceiptsStream,
         PushRulesStream,
@@ -72,6 +73,7 @@ __all__ = [
     "Stream",
     "BackfillStream",
     "PresenceStream",
+    "PresenceFederationStream",
     "TypingStream",
     "ReceiptsStream",
     "PushRulesStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 3dfee76743..9d75a89f1c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2019 New Vector Ltd
 #
@@ -291,6 +290,30 @@ class PresenceStream(Stream):
         )
 
 
+class PresenceFederationStream(Stream):
+    """A stream used to send ad hoc presence updates over federation.
+
+    Streams the remote destination and the user ID of the presence state to
+    send.
+    """
+
+    @attr.s(slots=True, auto_attribs=True)
+    class PresenceFederationStreamRow:
+        destination: str
+        user_id: str
+
+    NAME = "presence_federation"
+    ROW_TYPE = PresenceFederationStreamRow
+
+    def __init__(self, hs: "HomeServer"):
+        federation_queue = hs.get_presence_handler().get_federation_queue()
+        super().__init__(
+            hs.get_instance_name(),
+            federation_queue.get_current_token,
+            federation_queue.get_replication_rows,
+        )
+
+
 class TypingStream(Stream):
     TypingStreamRow = namedtuple(
         "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index fa5e37ba7b..e7e87bac92 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2019 New Vector Ltd
 #
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9bb8e9e177..096a85d363 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2019 New Vector Ltd
 #