summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-04-15 10:42:25 +0100
committerErik Johnston <erik@matrix.org>2021-04-15 12:03:23 +0100
commit5c63b653c8d03b56ba56a77c07799b057d5bd901 (patch)
tree48a14d3fd0d382c40808077fc2e08cf00ba1e578
parentAlways use send_presence_to_destinations rather than send_presence (diff)
downloadsynapse-5c63b653c8d03b56ba56a77c07799b057d5bd901.tar.xz
Add a presence federation replication stream
-rw-r--r--synapse/handlers/presence.py248
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/replication/tcp/streams/__init__.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py24
-rw-r--r--tests/handlers/test_presence.py162
5 files changed, 428 insertions, 17 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index a0cc779869..1255ceca55 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,6 +24,7 @@ The methods that define policy are:
 import abc
 import contextlib
 import logging
+from bisect import bisect
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
@@ -53,7 +54,9 @@ from synapse.replication.http.presence import (
     ReplicationBumpPresenceActiveTime,
     ReplicationPresenceSetState,
 )
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
+from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
 from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
@@ -124,6 +127,8 @@ class BasePresenceHandler(abc.ABC):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
+        self.federation_queue = PresenceFederationQueue(hs, self)
+
         self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
 
         active_presence = self.store.take_presence_startup_info()
@@ -245,9 +250,17 @@ class BasePresenceHandler(abc.ABC):
         """
         pass
 
-    async def process_replication_rows(self, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Process presence stream rows received over replication."""
-        pass
+        await self.federation_queue.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
+    def get_federation_queue(self) -> "PresenceFederationQueue":
+        """Get the presence federation queue, if any."""
+        return self.federation_queue
 
 
 class _NullContextManager(ContextManager[None]):
@@ -265,6 +278,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         self.presence_router = hs.get_presence_router()
         self._presence_enabled = hs.config.use_presence
+        self.state = hs.get_state_handler()
 
         # The number of ongoing syncs on this process, by user id.
         # Empty if _presence_enabled is false.
@@ -273,6 +287,10 @@ class WorkerPresenceHandler(BasePresenceHandler):
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
+        self._federation = None
+        if hs.should_send_federation():
+            self._federation = hs.get_federation_sender()
+
         # user_id -> last_sync_ms. Lists the users that have stopped syncing
         # but we haven't notified the master of that yet
         self.users_going_offline = {}
@@ -388,7 +406,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
             users=users_to_states.keys(),
         )
 
-    async def process_replication_rows(self, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        await super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        if stream_name != PresenceStream.NAME:
+            return
+
         states = [
             UserPresenceState(
                 row.user_id,
@@ -408,6 +433,20 @@ class WorkerPresenceHandler(BasePresenceHandler):
         stream_id = token
         await self.notify_from_replication(states, stream_id)
 
+        # Handle poking the local federation sender, if there is one.
+        if not self._federation:
+            return
+
+        hosts_and_states = await get_interested_remotes(
+            self.store,
+            self.presence_router,
+            states,
+            self.state,
+        )
+
+        for destinations, states in hosts_and_states:
+            self._federation.send_presence_to_destinations(states, destinations)
+
     def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
         return [
             user_id
@@ -463,11 +502,14 @@ class PresenceHandler(BasePresenceHandler):
         self.server_name = hs.hostname
         self.wheel_timer = WheelTimer()
         self.notifier = hs.get_notifier()
-        self.federation = hs.get_federation_sender()
         self.state = hs.get_state_handler()
         self.presence_router = hs.get_presence_router()
         self._presence_enabled = hs.config.use_presence
 
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
         federation_registry = hs.get_federation_registry()
 
         federation_registry.register_edu_handler("m.presence", self.incoming_presence)
@@ -680,7 +722,17 @@ class PresenceHandler(BasePresenceHandler):
             if to_federation_ping:
                 federation_presence_out_counter.inc(len(to_federation_ping))
 
-                await self._push_to_remotes(to_federation_ping.values())
+                hosts_and_states = await get_interested_remotes(
+                    self.store,
+                    self.presence_router,
+                    list(to_federation_ping.values()),
+                    self.state,
+                )
+
+                for destinations, states in hosts_and_states:
+                    self.federation_queue.send_presence_to_destinations(
+                        states, destinations
+                    )
 
     async def _handle_timeouts(self):
         """Checks the presence of users that have timed out and updates as
@@ -920,14 +972,12 @@ class PresenceHandler(BasePresenceHandler):
             users=[UserID.from_string(u) for u in users_to_states],
         )
 
-        await self._push_to_remotes(states)
-
-    async def _push_to_remotes(self, states):
-        """Sends state updates to remote servers.
+        # We only need to tell the local federation sender, if any, that new
+        # presence has happened. Other federation senders will get notified via
+        # the presence replication stream.
+        if not self.federation_sender:
+            return
 
-        Args:
-            states (list(UserPresenceState))
-        """
         hosts_and_states = await get_interested_remotes(
             self.store,
             self.presence_router,
@@ -936,7 +986,7 @@ class PresenceHandler(BasePresenceHandler):
         )
 
         for destinations, states in hosts_and_states:
-            self.federation.send_presence_to_destinations(states, destinations)
+            self.federation_sender.send_presence_to_destinations(states, destinations)
 
     async def incoming_presence(self, origin, content):
         """Called when we receive a `m.presence` EDU from a remote server."""
@@ -1174,7 +1224,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Send out user presence updates for each destination
         for destination, user_state_set in presence_destinations.items():
-            self.federation.send_presence_to_destinations(
+            self.federation_queue.send_presence_to_destinations(
                 destinations=[destination], states=user_state_set
             )
 
@@ -1819,3 +1869,173 @@ async def get_interested_remotes(
         hosts_and_states.append(([host], states))
 
     return hosts_and_states
+
+
+class PresenceFederationQueue:
+    """Handles sending ad hoc presence updates over federation, which are *not*
+    due to state updates (that get handled via the presence stream), e.g.
+    federation pings and sending existing present states to newly joined hosts.
+
+    Only the last N minutes will be queued, so if a federation sender instance
+    is down for longer then some updates will be dropped. This is OK as presence
+    is ephemeral, and so it will self correct eventually.
+    """
+
+    # How long to keep entries in the queue for. Workers that are down for
+    # longer than this duration will miss out on older updates.
+    _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000
+
+    # How often to check if we can expire entries from the queue.
+    _CLEAR_ITEMS_EVERY_MS = 60 * 1000
+
+    def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler):
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
+        self._presence_handler = presence_handler
+        self._repl_client = ReplicationGetStreamUpdates.make_client(hs)
+
+        # Should we keep a queue of recent presence updates? We only bother if
+        # another process may be handling federation sending.
+        self._queue_presence_updates = True
+
+        # The federation sender if this instance is a federation sender.
+        self._federation = None
+
+        if hs.should_send_federation():
+            self._federation = hs.get_federation_sender()
+
+            # We don't bother queuing up presence states if only this instance
+            # is sending federation.
+            if hs.config.worker.federation_shard_config.instances == [
+                self._instance_name
+            ]:
+                self._queue_presence_updates = False
+
+        # The queue of recently queued updates as tuples of: `(timestamp,
+        # stream_id, destinations, user_ids)`. We don't store the full states
+        # for efficiency, and remote workers will already have the full states
+        # cached.
+        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]
+
+        self._next_id = 1
+
+        # Map from instance name to current token
+        self._current_tokens = {}  # type: Dict[str, int]
+
+        if self._queue_presence_updates:
+            self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
+
+    def _clear_queue(self):
+        """Clear out older entries from the queue."""
+        clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
+
+        # The queue is sorted by timestamp, so we can bisect to find the right
+        # place to purge before. Note that we are searching using a 1-tuple with
+        # the time, which does The Right Thing since the queue is a tuple where
+        # the first item is a timestamp.
+        index = bisect(self._queue, (clear_before,))
+        self._queue = self._queue[index:]
+
+    def send_presence_to_destinations(
+        self, states: Collection[UserPresenceState], destinations: Collection[str]
+    ) -> None:
+        """Send the presence states to the given destinations.
+
+        Will forward to the local federation sender (if there is one) and queue
+        to send over replication (if there are other federation sender instances.).
+        """
+
+        if self._federation:
+            self._federation.send_presence_to_destinations(states, destinations)
+
+        if not self._queue_presence_updates:
+            return
+
+        now = self._clock.time_msec()
+
+        stream_id = self._next_id
+        self._next_id += 1
+
+        self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))
+
+        self._notifier.notify_replication()
+
+    def get_current_token(self, instance_name: str) -> int:
+        if instance_name == self._instance_name:
+            return self._next_id - 1
+        else:
+            return self._current_tokens.get(instance_name, 0)
+
+    async def get_replication_rows(
+        self,
+        instance_name: str,
+        from_token: int,
+        upto_token: int,
+        target_row_count: int,
+    ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
+        """Get all the updates between the two tokens.
+
+        We return rows in the form of `(destination, user_id)` to keep the size
+        of each row bounded (rather than returning the sets in a row).
+        """
+        if instance_name != self._instance_name:
+            # If not local we query over replication.
+            result = await self._repl_client(
+                instance_name=instance_name,
+                stream_name=PresenceFederationStream.NAME,
+                from_token=from_token,
+                upto_token=upto_token,
+            )
+            return result["updates"], result["upto_token"], result["limited"]
+
+        # We can find the correct position in the queue by noting that there is
+        # exactly one entry per stream ID, and that the last entry has an ID of
+        # `self._next_id - 1`, so we can count backwards from the end.
+        #
+        # Since the start of the queue is periodically truncated we need to
+        # handle the case where `from_token` stream ID has already been dropped.
+        start_idx = max(from_token - self._next_id, -len(self._queue))
+
+        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
+        limited = False
+        new_id = upto_token
+        for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
+            if stream_id > upto_token:
+                break
+
+            new_id = stream_id
+
+            to_send.extend(
+                (stream_id, (destination, user_id))
+                for destination in destinations
+                for user_id in user_ids
+            )
+
+            if len(to_send) > target_row_count:
+                limited = True
+                break
+
+        return to_send, new_id, limited
+
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        if stream_name != PresenceFederationStream.NAME:
+            return
+
+        # We keep track of the current tokens
+        self._current_tokens[instance_name] = token
+
+        # If we're a federation sender we pull out the presence states to send
+        # and forward them on.
+        if not self._federation:
+            return
+
+        hosts_to_users = {}  # type: Dict[str, Set[str]]
+        for row in rows:
+            hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
+
+        for host, user_ids in hosts_to_users.items():
+            states = await self._presence_handler.current_state_for_users(user_ids)
+            self._federation.send_presence_to_destinations(states.values(), [host])
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ce5d651cb8..874cb9c25e 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,6 @@ from synapse.replication.tcp.streams import (
     AccountDataStream,
     DeviceListsStream,
     GroupServerStream,
-    PresenceStream,
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
@@ -191,8 +190,6 @@ class ReplicationDataHandler:
                     self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                 else:
                     await self.start_pusher(row.user_id, row.app_id, row.pushkey)
-        elif stream_name == PresenceStream.NAME:
-            await self._presence_handler.process_replication_rows(token, rows)
         elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
@@ -221,6 +218,10 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+        await self._presence_handler.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
@@ -338,6 +339,7 @@ class FederationSenderHandler:
         self.store = hs.get_datastore()
         self._is_mine_id = hs.is_mine_id
         self._hs = hs
+        self._presence_handler = hs.get_presence_handler()
 
         # We need to make a temporary value to ensure that mypy picks up the
         # right type. We know we should have a federation sender instance since
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index fb74ac4e98..4c0023c68a 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -30,6 +30,7 @@ from synapse.replication.tcp.streams._base import (
     CachesStream,
     DeviceListsStream,
     GroupServerStream,
+    PresenceFederationStream,
     PresenceStream,
     PublicRoomsStream,
     PushersStream,
@@ -50,6 +51,7 @@ STREAMS_MAP = {
         EventsStream,
         BackfillStream,
         PresenceStream,
+        PresenceFederationStream,
         TypingStream,
         ReceiptsStream,
         PushRulesStream,
@@ -71,6 +73,7 @@ __all__ = [
     "Stream",
     "BackfillStream",
     "PresenceStream",
+    "PresenceFederationStream",
     "TypingStream",
     "ReceiptsStream",
     "PushRulesStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 520c45f151..9d75a89f1c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -290,6 +290,30 @@ class PresenceStream(Stream):
         )
 
 
+class PresenceFederationStream(Stream):
+    """A stream used to send ad hoc presence updates over federation.
+
+    Streams the remote destination and the user ID of the presence state to
+    send.
+    """
+
+    @attr.s(slots=True, auto_attribs=True)
+    class PresenceFederationStreamRow:
+        destination: str
+        user_id: str
+
+    NAME = "presence_federation"
+    ROW_TYPE = PresenceFederationStreamRow
+
+    def __init__(self, hs: "HomeServer"):
+        federation_queue = hs.get_presence_handler().get_federation_queue()
+        super().__init__(
+            hs.get_instance_name(),
+            federation_queue.get_current_token,
+            federation_queue.get_replication_rows,
+        )
+
+
 class TypingStream(Stream):
     TypingStreamRow = namedtuple(
         "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 2d12e82897..6cda602fce 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -471,6 +471,168 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(state.state, PresenceState.OFFLINE)
 
 
+class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.presence_handler = hs.get_presence_handler()
+        self.clock = hs.get_clock()
+        self.instance_name = hs.get_instance_name()
+
+        self.queue = self.presence_handler.get_federation_queue()
+
+    def test_send_and_get(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (1, ("dest1", "@user1:test")),
+            (1, ("dest2", "@user1:test")),
+            (1, ("dest1", "@user2:test")),
+            (1, ("dest2", "@user2:test")),
+            (2, ("dest3", "@user3:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+    def test_send_and_get_split(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (1, ("dest1", "@user1:test")),
+            (1, ("dest2", "@user1:test")),
+            (1, ("dest1", "@user2:test")),
+            (1, ("dest2", "@user2:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+    def test_clear_queue_all(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        self.reactor.advance(10 * 60 * 1000)
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+        self.assertCountEqual(rows, [])
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (3, ("dest1", "@user1:test")),
+            (3, ("dest2", "@user1:test")),
+            (3, ("dest1", "@user2:test")),
+            (3, ("dest2", "@user2:test")),
+            (4, ("dest3", "@user3:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+    def test_partially_clear_queue(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+
+        self.reactor.advance(2 * 60 * 1000)
+
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        self.reactor.advance(4 * 60 * 1000)
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (2, ("dest3", "@user3:test")),
+        ]
+        self.assertCountEqual(rows, [])
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (3, ("dest1", "@user1:test")),
+            (3, ("dest2", "@user1:test")),
+            (3, ("dest1", "@user2:test")),
+            (3, ("dest2", "@user2:test")),
+            (4, ("dest3", "@user3:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+
 class PresenceJoinTestCase(unittest.HomeserverTestCase):
     """Tests remote servers get told about presence of users in the room when
     they join and when new local users join.