summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/presence.py243
-rw-r--r--synapse/replication/tcp/client.py7
-rw-r--r--synapse/replication/tcp/streams/__init__.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py24
4 files changed, 254 insertions, 23 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index bd2382193f..598466c9bd 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,6 +24,7 @@ The methods that define policy are:
 import abc
 import contextlib
 import logging
+from bisect import bisect
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
@@ -53,7 +54,9 @@ from synapse.replication.http.presence import (
     ReplicationBumpPresenceActiveTime,
     ReplicationPresenceSetState,
 )
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
+from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
 from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
@@ -128,10 +131,10 @@ class BasePresenceHandler(abc.ABC):
         self.is_mine_id = hs.is_mine_id
 
         self._federation = None
-        if hs.should_send_federation() or not hs.config.worker_app:
+        if hs.should_send_federation():
             self._federation = hs.get_federation_sender()
 
-        self._send_federation = hs.should_send_federation()
+        self._federation_queue = PresenceFederationQueue(hs, self)
 
         self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
 
@@ -254,9 +257,17 @@ class BasePresenceHandler(abc.ABC):
         """
         pass
 
-    async def process_replication_rows(self, token, rows):
-        """Process presence stream rows received over replication."""
-        pass
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        """Process streams received over replication."""
+        await self._federation_queue.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
+    def get_federation_queue(self) -> "PresenceFederationQueue":
+        """Get the presence federation queue."""
+        return self._federation_queue
 
     async def maybe_send_presence_to_interested_destinations(
         self, states: List[UserPresenceState]
@@ -266,12 +277,9 @@ class BasePresenceHandler(abc.ABC):
         users.
         """
 
-        if not self._send_federation:
+        if not self._federation:
             return
 
-        # If this worker sends federation we must have a FederationSender.
-        assert self._federation
-
         states = [s for s in states if self.is_mine_id(s.user_id)]
 
         if not states:
@@ -427,7 +435,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # If this is a federation sender, notify about presence updates.
         await self.maybe_send_presence_to_interested_destinations(states)
 
-    async def process_replication_rows(self, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        await super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        if stream_name != PresenceStream.NAME:
+            return
+
         states = [
             UserPresenceState(
                 row.user_id,
@@ -729,12 +744,10 @@ class PresenceHandler(BasePresenceHandler):
                     self.state,
                 )
 
-                # Since this is master we know that we have a federation sender or
-                # queue, and so this will be defined.
-                assert self._federation
-
                 for destinations, states in hosts_and_states:
-                    self._federation.send_presence_to_destinations(states, destinations)
+                    self._federation_queue.send_presence_to_destinations(
+                        states, destinations
+                    )
 
     async def _handle_timeouts(self):
         """Checks the presence of users that have timed out and updates as
@@ -1213,13 +1226,9 @@ class PresenceHandler(BasePresenceHandler):
                     user_presence_states
                 )
 
-        # Since this is master we know that we have a federation sender or
-        # queue, and so this will be defined.
-        assert self._federation
-
         # Send out user presence updates for each destination
         for destination, user_state_set in presence_destinations.items():
-            self._federation.send_presence_to_destinations(
+            self._federation_queue.send_presence_to_destinations(
                 destinations=[destination], states=user_state_set
             )
 
@@ -1864,3 +1873,197 @@ async def get_interested_remotes(
         hosts_and_states.append(([host], states))
 
     return hosts_and_states
+
+
+class PresenceFederationQueue:
+    """Handles sending ad hoc presence updates over federation, which are *not*
+    due to state updates (that get handled via the presence stream), e.g.
+    federation pings and sending existing present states to newly joined hosts.
+
+    Only the last N minutes will be queued, so if a federation sender instance
+    is down for longer then some updates will be dropped. This is OK as presence
+    is ephemeral, and so it will self correct eventually.
+
+    On workers the class tracks the last received position of the stream from
+    replication, and handles querying for missed updates over HTTP replication,
+    c.f. `get_current_token` and `get_replication_rows`.
+    """
+
+    # How long to keep entries in the queue for. Workers that are down for
+    # longer than this duration will miss out on older updates.
+    _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000
+
+    # How often to check if we can expire entries from the queue.
+    _CLEAR_ITEMS_EVERY_MS = 60 * 1000
+
+    def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler):
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
+        self._presence_handler = presence_handler
+        self._repl_client = ReplicationGetStreamUpdates.make_client(hs)
+
+        # Should we keep a queue of recent presence updates? We only bother if
+        # another process may be handling federation sending.
+        self._queue_presence_updates = True
+
+        # Whether this instance is a presence writer.
+        self._presence_writer = hs.config.worker.worker_app is None
+
+        # The FederationSender instance, if this process sends federation traffic directly.
+        self._federation = None
+
+        if hs.should_send_federation():
+            self._federation = hs.get_federation_sender()
+
+            # We don't bother queuing up presence states if only this instance
+            # is sending federation.
+            if hs.config.worker.federation_shard_config.instances == [
+                self._instance_name
+            ]:
+                self._queue_presence_updates = False
+
+        # The queue of recently queued updates as tuples of: `(timestamp,
+        # stream_id, destinations, user_ids)`. We don't store the full states
+        # for efficiency, and remote workers will already have the full states
+        # cached.
+        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]
+
+        self._next_id = 1
+
+        # Map from instance name to current token
+        self._current_tokens = {}  # type: Dict[str, int]
+
+        if self._queue_presence_updates:
+            self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
+
+    def _clear_queue(self):
+        """Clear out older entries from the queue."""
+        clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
+
+        # The queue is sorted by timestamp, so we can bisect to find the right
+        # place to purge before. Note that we are searching using a 1-tuple with
+        # the time, which does The Right Thing since the queue is a tuple where
+        # the first item is a timestamp.
+        index = bisect(self._queue, (clear_before,))
+        self._queue = self._queue[index:]
+
+    def send_presence_to_destinations(
+        self, states: Collection[UserPresenceState], destinations: Collection[str]
+    ) -> None:
+        """Send the presence states to the given destinations.
+
+        Will forward to the local federation sender (if there is one) and queue
+        to send over replication (if there are other federation sender instances.).
+
+        Must only be called on the master process.
+        """
+
+        # This should only be called on a presence writer.
+        assert self._presence_writer
+
+        if self._federation:
+            self._federation.send_presence_to_destinations(
+                states=states,
+                destinations=destinations,
+            )
+
+        if not self._queue_presence_updates:
+            return
+
+        now = self._clock.time_msec()
+
+        stream_id = self._next_id
+        self._next_id += 1
+
+        self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))
+
+        self._notifier.notify_replication()
+
+    def get_current_token(self, instance_name: str) -> int:
+        """Get the current position of the stream.
+
+        On workers this returns the last stream ID received from replication.
+        """
+        if instance_name == self._instance_name:
+            return self._next_id - 1
+        else:
+            return self._current_tokens.get(instance_name, 0)
+
+    async def get_replication_rows(
+        self,
+        instance_name: str,
+        from_token: int,
+        upto_token: int,
+        target_row_count: int,
+    ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
+        """Get all the updates between the two tokens.
+
+        We return rows in the form of `(destination, user_id)` to keep the size
+        of each row bounded (rather than returning the sets in a row).
+
+        On workers this will query the master process via HTTP replication.
+        """
+        if instance_name != self._instance_name:
+            # If not local we query over http replication from the master
+            result = await self._repl_client(
+                instance_name=instance_name,
+                stream_name=PresenceFederationStream.NAME,
+                from_token=from_token,
+                upto_token=upto_token,
+            )
+            return result["updates"], result["upto_token"], result["limited"]
+
+        # We can find the correct position in the queue by noting that there is
+        # exactly one entry per stream ID, and that the last entry has an ID of
+        # `self._next_id - 1`, so we can count backwards from the end.
+        #
+        # Since the start of the queue is periodically truncated we need to
+        # handle the case where `from_token` stream ID has already been dropped.
+        start_idx = max(from_token - self._next_id, -len(self._queue))
+
+        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
+        limited = False
+        new_id = upto_token
+        for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
+            if stream_id > upto_token:
+                break
+
+            new_id = stream_id
+
+            to_send.extend(
+                (stream_id, (destination, user_id))
+                for destination in destinations
+                for user_id in user_ids
+            )
+
+            if len(to_send) > target_row_count:
+                limited = True
+                break
+
+        return to_send, new_id, limited
+
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        if stream_name != PresenceFederationStream.NAME:
+            return
+
+        # We keep track of the current tokens (so that we can catch up with anything we missed after a disconnect)
+        self._current_tokens[instance_name] = token
+
+        # If we're a federation sender we pull out the presence states to send
+        # and forward them on.
+        if not self._federation:
+            return
+
+        hosts_to_users = {}  # type: Dict[str, Set[str]]
+        for row in rows:
+            hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
+
+        for host, user_ids in hosts_to_users.items():
+            states = await self._presence_handler.current_state_for_users(user_ids)
+            self._federation.send_presence_to_destinations(
+                states=states.values(),
+                destinations=[host],
+            )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ce5d651cb8..4f3c6a18b6 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,6 @@ from synapse.replication.tcp.streams import (
     AccountDataStream,
     DeviceListsStream,
     GroupServerStream,
-    PresenceStream,
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
@@ -191,8 +190,6 @@ class ReplicationDataHandler:
                     self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                 else:
                     await self.start_pusher(row.user_id, row.app_id, row.pushkey)
-        elif stream_name == PresenceStream.NAME:
-            await self._presence_handler.process_replication_rows(token, rows)
         elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
@@ -221,6 +218,10 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+        await self._presence_handler.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index fb74ac4e98..4c0023c68a 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -30,6 +30,7 @@ from synapse.replication.tcp.streams._base import (
     CachesStream,
     DeviceListsStream,
     GroupServerStream,
+    PresenceFederationStream,
     PresenceStream,
     PublicRoomsStream,
     PushersStream,
@@ -50,6 +51,7 @@ STREAMS_MAP = {
         EventsStream,
         BackfillStream,
         PresenceStream,
+        PresenceFederationStream,
         TypingStream,
         ReceiptsStream,
         PushRulesStream,
@@ -71,6 +73,7 @@ __all__ = [
     "Stream",
     "BackfillStream",
     "PresenceStream",
+    "PresenceFederationStream",
     "TypingStream",
     "ReceiptsStream",
     "PushRulesStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 520c45f151..9d75a89f1c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -290,6 +290,30 @@ class PresenceStream(Stream):
         )
 
 
+class PresenceFederationStream(Stream):
+    """A stream used to send ad hoc presence updates over federation.
+
+    Streams the remote destination and the user ID of the presence state to
+    send.
+    """
+
+    @attr.s(slots=True, auto_attribs=True)
+    class PresenceFederationStreamRow:
+        destination: str
+        user_id: str
+
+    NAME = "presence_federation"
+    ROW_TYPE = PresenceFederationStreamRow
+
+    def __init__(self, hs: "HomeServer"):
+        federation_queue = hs.get_presence_handler().get_federation_queue()
+        super().__init__(
+            hs.get_instance_name(),
+            federation_queue.get_current_token,
+            federation_queue.get_replication_rows,
+        )
+
+
 class TypingStream(Stream):
     TypingStreamRow = namedtuple(
         "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)