summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-04-20 14:11:24 +0100
committerGitHub <noreply@github.com>2021-04-20 14:11:24 +0100
commitde0d088adc0cf3d5bbd80238b88143426cd6eaca (patch)
treebd045d24d1b2bf83b0742e952c4b7400365fc58d
parentFix bug where we sent remote presence states to remote servers (#9850) (diff)
downloadsynapse-de0d088adc0cf3d5bbd80238b88143426cd6eaca.tar.xz
Add presence federation stream (#9819)
-rw-r--r--changelog.d/9819.feature1
-rw-r--r--synapse/handlers/presence.py243
-rw-r--r--synapse/replication/tcp/client.py7
-rw-r--r--synapse/replication/tcp/streams/__init__.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py24
-rw-r--r--tests/handlers/test_presence.py179
6 files changed, 426 insertions, 31 deletions
diff --git a/changelog.d/9819.feature b/changelog.d/9819.feature
new file mode 100644
index 0000000000..f56b0bb3bd
--- /dev/null
+++ b/changelog.d/9819.feature
@@ -0,0 +1 @@
+Add experimental support for handling presence on a worker.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index bd2382193f..598466c9bd 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,6 +24,7 @@ The methods that define policy are:
 import abc
 import contextlib
 import logging
+from bisect import bisect
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
@@ -53,7 +54,9 @@ from synapse.replication.http.presence import (
     ReplicationBumpPresenceActiveTime,
     ReplicationPresenceSetState,
 )
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
+from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
 from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
@@ -128,10 +131,10 @@ class BasePresenceHandler(abc.ABC):
         self.is_mine_id = hs.is_mine_id
 
         self._federation = None
-        if hs.should_send_federation() or not hs.config.worker_app:
+        if hs.should_send_federation():
             self._federation = hs.get_federation_sender()
 
-        self._send_federation = hs.should_send_federation()
+        self._federation_queue = PresenceFederationQueue(hs, self)
 
         self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
 
@@ -254,9 +257,17 @@ class BasePresenceHandler(abc.ABC):
         """
         pass
 
-    async def process_replication_rows(self, token, rows):
-        """Process presence stream rows received over replication."""
-        pass
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        """Process streams received over replication."""
+        await self._federation_queue.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
+    def get_federation_queue(self) -> "PresenceFederationQueue":
+        """Get the presence federation queue."""
+        return self._federation_queue
 
     async def maybe_send_presence_to_interested_destinations(
         self, states: List[UserPresenceState]
@@ -266,12 +277,9 @@ class BasePresenceHandler(abc.ABC):
         users.
         """
 
-        if not self._send_federation:
+        if not self._federation:
             return
 
-        # If this worker sends federation we must have a FederationSender.
-        assert self._federation
-
         states = [s for s in states if self.is_mine_id(s.user_id)]
 
         if not states:
@@ -427,7 +435,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # If this is a federation sender, notify about presence updates.
         await self.maybe_send_presence_to_interested_destinations(states)
 
-    async def process_replication_rows(self, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        await super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        if stream_name != PresenceStream.NAME:
+            return
+
         states = [
             UserPresenceState(
                 row.user_id,
@@ -729,12 +744,10 @@ class PresenceHandler(BasePresenceHandler):
                     self.state,
                 )
 
-                # Since this is master we know that we have a federation sender or
-                # queue, and so this will be defined.
-                assert self._federation
-
                 for destinations, states in hosts_and_states:
-                    self._federation.send_presence_to_destinations(states, destinations)
+                    self._federation_queue.send_presence_to_destinations(
+                        states, destinations
+                    )
 
     async def _handle_timeouts(self):
         """Checks the presence of users that have timed out and updates as
@@ -1213,13 +1226,9 @@ class PresenceHandler(BasePresenceHandler):
                     user_presence_states
                 )
 
-        # Since this is master we know that we have a federation sender or
-        # queue, and so this will be defined.
-        assert self._federation
-
         # Send out user presence updates for each destination
         for destination, user_state_set in presence_destinations.items():
-            self._federation.send_presence_to_destinations(
+            self._federation_queue.send_presence_to_destinations(
                 destinations=[destination], states=user_state_set
             )
 
@@ -1864,3 +1873,197 @@ async def get_interested_remotes(
         hosts_and_states.append(([host], states))
 
     return hosts_and_states
+
+
+class PresenceFederationQueue:
+    """Handles sending ad hoc presence updates over federation, which are *not*
+    due to state updates (that get handled via the presence stream), e.g.
+    federation pings and sending existing present states to newly joined hosts.
+
+    Only the last N minutes will be queued, so if a federation sender instance
+    is down for longer then some updates will be dropped. This is OK as presence
+    is ephemeral, and so it will self correct eventually.
+
+    On workers the class tracks the last received position of the stream from
+    replication, and handles querying for missed updates over HTTP replication,
+    c.f. `get_current_token` and `get_replication_rows`.
+    """
+
+    # How long to keep entries in the queue for. Workers that are down for
+    # longer than this duration will miss out on older updates.
+    _KEEP_ITEMS_IN_QUEUE_FOR_MS = 5 * 60 * 1000
+
+    # How often to check if we can expire entries from the queue.
+    _CLEAR_ITEMS_EVERY_MS = 60 * 1000
+
+    def __init__(self, hs: "HomeServer", presence_handler: BasePresenceHandler):
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
+        self._presence_handler = presence_handler
+        self._repl_client = ReplicationGetStreamUpdates.make_client(hs)
+
+        # Should we keep a queue of recent presence updates? We only bother if
+        # another process may be handling federation sending.
+        self._queue_presence_updates = True
+
+        # Whether this instance is a presence writer.
+        self._presence_writer = hs.config.worker.worker_app is None
+
+        # The FederationSender instance, if this process sends federation traffic directly.
+        self._federation = None
+
+        if hs.should_send_federation():
+            self._federation = hs.get_federation_sender()
+
+            # We don't bother queuing up presence states if only this instance
+            # is sending federation.
+            if hs.config.worker.federation_shard_config.instances == [
+                self._instance_name
+            ]:
+                self._queue_presence_updates = False
+
+        # The queue of recently queued updates as tuples of: `(timestamp,
+        # stream_id, destinations, user_ids)`. We don't store the full states
+        # for efficiency, and remote workers will already have the full states
+        # cached.
+        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]
+
+        self._next_id = 1
+
+        # Map from instance name to current token
+        self._current_tokens = {}  # type: Dict[str, int]
+
+        if self._queue_presence_updates:
+            self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
+
+    def _clear_queue(self):
+        """Clear out older entries from the queue."""
+        clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
+
+        # The queue is sorted by timestamp, so we can bisect to find the right
+        # place to purge before. Note that we are searching using a 1-tuple with
+        # the time, which does The Right Thing since the queue is a tuple where
+        # the first item is a timestamp.
+        index = bisect(self._queue, (clear_before,))
+        self._queue = self._queue[index:]
+
+    def send_presence_to_destinations(
+        self, states: Collection[UserPresenceState], destinations: Collection[str]
+    ) -> None:
+        """Send the presence states to the given destinations.
+
+        Will forward to the local federation sender (if there is one) and queue
+        to send over replication (if there are other federation sender instances.).
+
+        Must only be called on the master process.
+        """
+
+        # This should only be called on a presence writer.
+        assert self._presence_writer
+
+        if self._federation:
+            self._federation.send_presence_to_destinations(
+                states=states,
+                destinations=destinations,
+            )
+
+        if not self._queue_presence_updates:
+            return
+
+        now = self._clock.time_msec()
+
+        stream_id = self._next_id
+        self._next_id += 1
+
+        self._queue.append((now, stream_id, destinations, {s.user_id for s in states}))
+
+        self._notifier.notify_replication()
+
+    def get_current_token(self, instance_name: str) -> int:
+        """Get the current position of the stream.
+
+        On workers this returns the last stream ID received from replication.
+        """
+        if instance_name == self._instance_name:
+            return self._next_id - 1
+        else:
+            return self._current_tokens.get(instance_name, 0)
+
+    async def get_replication_rows(
+        self,
+        instance_name: str,
+        from_token: int,
+        upto_token: int,
+        target_row_count: int,
+    ) -> Tuple[List[Tuple[int, Tuple[str, str]]], int, bool]:
+        """Get all the updates between the two tokens.
+
+        We return rows in the form of `(destination, user_id)` to keep the size
+        of each row bounded (rather than returning the sets in a row).
+
+        On workers this will query the master process via HTTP replication.
+        """
+        if instance_name != self._instance_name:
+            # If not local we query over http replication from the master
+            result = await self._repl_client(
+                instance_name=instance_name,
+                stream_name=PresenceFederationStream.NAME,
+                from_token=from_token,
+                upto_token=upto_token,
+            )
+            return result["updates"], result["upto_token"], result["limited"]
+
+        # We can find the correct position in the queue by noting that there is
+        # exactly one entry per stream ID, and that the last entry has an ID of
+        # `self._next_id - 1`, so we can count backwards from the end.
+        #
+        # Since the start of the queue is periodically truncated we need to
+        # handle the case where `from_token` stream ID has already been dropped.
+        start_idx = max(from_token - self._next_id, -len(self._queue))
+
+        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
+        limited = False
+        new_id = upto_token
+        for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
+            if stream_id > upto_token:
+                break
+
+            new_id = stream_id
+
+            to_send.extend(
+                (stream_id, (destination, user_id))
+                for destination in destinations
+                for user_id in user_ids
+            )
+
+            if len(to_send) > target_row_count:
+                limited = True
+                break
+
+        return to_send, new_id, limited
+
+    async def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        if stream_name != PresenceFederationStream.NAME:
+            return
+
+        # We keep track of the current tokens (so that we can catch up with anything we missed after a disconnect)
+        self._current_tokens[instance_name] = token
+
+        # If we're a federation sender we pull out the presence states to send
+        # and forward them on.
+        if not self._federation:
+            return
+
+        hosts_to_users = {}  # type: Dict[str, Set[str]]
+        for row in rows:
+            hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
+
+        for host, user_ids in hosts_to_users.items():
+            states = await self._presence_handler.current_state_for_users(user_ids)
+            self._federation.send_presence_to_destinations(
+                states=states.values(),
+                destinations=[host],
+            )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ce5d651cb8..4f3c6a18b6 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,6 @@ from synapse.replication.tcp.streams import (
     AccountDataStream,
     DeviceListsStream,
     GroupServerStream,
-    PresenceStream,
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
@@ -191,8 +190,6 @@ class ReplicationDataHandler:
                     self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                 else:
                     await self.start_pusher(row.user_id, row.app_id, row.pushkey)
-        elif stream_name == PresenceStream.NAME:
-            await self._presence_handler.process_replication_rows(token, rows)
         elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
@@ -221,6 +218,10 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+        await self._presence_handler.process_replication_rows(
+            stream_name, instance_name, token, rows
+        )
+
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index fb74ac4e98..4c0023c68a 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -30,6 +30,7 @@ from synapse.replication.tcp.streams._base import (
     CachesStream,
     DeviceListsStream,
     GroupServerStream,
+    PresenceFederationStream,
     PresenceStream,
     PublicRoomsStream,
     PushersStream,
@@ -50,6 +51,7 @@ STREAMS_MAP = {
         EventsStream,
         BackfillStream,
         PresenceStream,
+        PresenceFederationStream,
         TypingStream,
         ReceiptsStream,
         PushRulesStream,
@@ -71,6 +73,7 @@ __all__ = [
     "Stream",
     "BackfillStream",
     "PresenceStream",
+    "PresenceFederationStream",
     "TypingStream",
     "ReceiptsStream",
     "PushRulesStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 520c45f151..9d75a89f1c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -290,6 +290,30 @@ class PresenceStream(Stream):
         )
 
 
+class PresenceFederationStream(Stream):
+    """A stream used to send ad hoc presence updates over federation.
+
+    Streams the remote destination and the user ID of the presence state to
+    send.
+    """
+
+    @attr.s(slots=True, auto_attribs=True)
+    class PresenceFederationStreamRow:
+        destination: str
+        user_id: str
+
+    NAME = "presence_federation"
+    ROW_TYPE = PresenceFederationStreamRow
+
+    def __init__(self, hs: "HomeServer"):
+        federation_queue = hs.get_presence_handler().get_federation_queue()
+        super().__init__(
+            hs.get_instance_name(),
+            federation_queue.get_current_token,
+            federation_queue.get_replication_rows,
+        )
+
+
 class TypingStream(Stream):
     TypingStreamRow = namedtuple(
         "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 2d12e82897..61271cd084 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events.builder import EventBuilder
+from synapse.federation.sender import FederationSender
 from synapse.handlers.presence import (
     EXTERNAL_PROCESS_EXPIRY,
     FEDERATION_PING_INTERVAL,
@@ -471,6 +472,168 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(state.state, PresenceState.OFFLINE)
 
 
+class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.presence_handler = hs.get_presence_handler()
+        self.clock = hs.get_clock()
+        self.instance_name = hs.get_instance_name()
+
+        self.queue = self.presence_handler.get_federation_queue()
+
+    def test_send_and_get(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (1, ("dest1", "@user1:test")),
+            (1, ("dest2", "@user1:test")),
+            (1, ("dest1", "@user2:test")),
+            (1, ("dest2", "@user2:test")),
+            (2, ("dest3", "@user3:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+    def test_send_and_get_split(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (1, ("dest1", "@user1:test")),
+            (1, ("dest2", "@user1:test")),
+            (1, ("dest1", "@user2:test")),
+            (1, ("dest2", "@user2:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+    def test_clear_queue_all(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        self.reactor.advance(10 * 60 * 1000)
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+        self.assertCountEqual(rows, [])
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (3, ("dest1", "@user1:test")),
+            (3, ("dest2", "@user1:test")),
+            (3, ("dest1", "@user2:test")),
+            (3, ("dest2", "@user2:test")),
+            (4, ("dest3", "@user3:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+    def test_partially_clear_queue(self):
+        state1 = UserPresenceState.default("@user1:test")
+        state2 = UserPresenceState.default("@user2:test")
+        state3 = UserPresenceState.default("@user3:test")
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+
+        self.reactor.advance(2 * 60 * 1000)
+
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        self.reactor.advance(4 * 60 * 1000)
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (2, ("dest3", "@user3:test")),
+        ]
+        self.assertCountEqual(rows, [])
+
+        prev_token = self.queue.get_current_token(self.instance_name)
+
+        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+
+        now_token = self.queue.get_current_token(self.instance_name)
+
+        rows, upto_token, limited = self.get_success(
+            self.queue.get_replication_rows("master", prev_token, now_token, 10)
+        )
+        self.assertEqual(upto_token, now_token)
+        self.assertFalse(limited)
+
+        expected_rows = [
+            (3, ("dest1", "@user1:test")),
+            (3, ("dest2", "@user1:test")),
+            (3, ("dest1", "@user2:test")),
+            (3, ("dest2", "@user2:test")),
+            (4, ("dest3", "@user3:test")),
+        ]
+
+        self.assertCountEqual(rows, expected_rows)
+
+
 class PresenceJoinTestCase(unittest.HomeserverTestCase):
     """Tests remote servers get told about presence of users in the room when
     they join and when new local users join.
@@ -482,10 +645,17 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            "server", federation_http_client=None, federation_sender=Mock()
+            "server",
+            federation_http_client=None,
+            federation_sender=Mock(spec=FederationSender),
         )
         return hs
 
+    def default_config(self):
+        config = super().default_config()
+        config["send_federation"] = True
+        return config
+
     def prepare(self, reactor, clock, hs):
         self.federation_sender = hs.get_federation_sender()
         self.event_builder_factory = hs.get_event_builder_factory()
@@ -529,9 +699,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Add a new remote server to the room
         self._add_new_user(room_id, "@alice:server2")
 
-        # We shouldn't have sent out any local presence *updates*
-        self.federation_sender.send_presence.assert_not_called()
-
         # When new server is joined we send it the local users presence states.
         # We expect to only see user @test2:server, as @test:server is offline
         # and has a zero last_active_ts
@@ -550,7 +717,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         self.federation_sender.reset_mock()
         self._add_new_user(room_id, "@bob:server3")
 
-        self.federation_sender.send_presence.assert_not_called()
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
             destinations=["server3"], states={expected_state}
         )
@@ -595,9 +761,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
         self.reactor.pump([0])  # Wait for presence updates to be handled
 
-        # We shouldn't have sent out any local presence *updates*
-        self.federation_sender.send_presence.assert_not_called()
-
         # We expect to only send test2 presence to server2 and server3
         expected_state = self.get_success(
             self.presence_handler.current_state_for_user("@test2:server")