summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9882.misc1
-rw-r--r--changelog.d/9902.feature1
-rw-r--r--changelog.d/9910.misc1
-rw-r--r--changelog.d/9916.misc1
-rw-r--r--docs/sample_config.yaml8
-rw-r--r--synapse/app/generic_worker.py3
-rw-r--r--synapse/app/homeserver.py3
-rw-r--r--synapse/config/server.py9
-rw-r--r--synapse/federation/federation_client.py10
-rw-r--r--synapse/handlers/federation.py21
-rw-r--r--synapse/handlers/presence.py134
-rw-r--r--synapse/http/matrixfederationclient.py3
-rw-r--r--synapse/metrics/__init__.py179
-rw-r--r--tests/handlers/test_presence.py14
14 files changed, 305 insertions, 83 deletions
diff --git a/changelog.d/9882.misc b/changelog.d/9882.misc
new file mode 100644
index 0000000000..facfa31f38
--- /dev/null
+++ b/changelog.d/9882.misc
@@ -0,0 +1 @@
+Export jemalloc stats to Prometheus if it is being used.
diff --git a/changelog.d/9902.feature b/changelog.d/9902.feature
new file mode 100644
index 0000000000..4d9f324d4e
--- /dev/null
+++ b/changelog.d/9902.feature
@@ -0,0 +1 @@
+Add limits to how often Synapse will GC, ensuring that large servers do not end up GC thrashing if `gc_thresholds` has not been correctly set.
diff --git a/changelog.d/9910.misc b/changelog.d/9910.misc
new file mode 100644
index 0000000000..54165cce18
--- /dev/null
+++ b/changelog.d/9910.misc
@@ -0,0 +1 @@
+Improve performance after joining a large room when presence is enabled.
diff --git a/changelog.d/9916.misc b/changelog.d/9916.misc
new file mode 100644
index 0000000000..401298fa3d
--- /dev/null
+++ b/changelog.d/9916.misc
@@ -0,0 +1 @@
+Improve performance of handling presence when joining large rooms.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index e0350279ad..ebf364cf40 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -152,6 +152,14 @@ presence:
 #
 #gc_thresholds: [700, 10, 10]
 
+# The minimum time in seconds between each GC for a generation, regardless of
+# the GC thresholds. This ensures that we don't do GC too frequently.
+#
+# A value of `[1, 10, 30]` indicates that a second must pass between consecutive
+# generation 0 GCs, etc.
+#
+# gc_min_seconds_between: [1, 10, 30]
+
 # Set the limit on the returned events in the timeline in the get
 # and sync operations. The default value is 100. -1 means no upper limit.
 #
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1a15ceee81..a3fe9a3f38 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -455,6 +455,9 @@ def start(config_options):
 
     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
+    if config.server.gc_seconds:
+        synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
+
     hs = GenericWorkerServer(
         config.server_name,
         config=config,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8e78134bbe..6a823da10d 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -342,6 +342,9 @@ def setup(config_options):
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
+    if config.server.gc_seconds:
+        synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
+
     hs = SynapseHomeServer(
         config.server_name,
         config=config,
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 21ca7b33e3..ca1c9711f8 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -572,6 +572,7 @@ class ServerConfig(Config):
             _warn_if_webclient_configured(self.listeners)
 
         self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
+        self.gc_seconds = read_gc_thresholds(config.get("gc_min_seconds_between", None))
 
         @attr.s
         class LimitRemoteRoomsConfig:
@@ -917,6 +918,14 @@ class ServerConfig(Config):
         #
         #gc_thresholds: [700, 10, 10]
 
+        # The minimum time in seconds between each GC for a generation, regardless of
+        # the GC thresholds. This ensures that we don't do GC too frequently.
+        #
+        # A value of `[1, 10, 30]` indicates that a second must pass between consecutive
+        # generation 0 GCs, etc.
+        #
+        # gc_min_seconds_between: [1, 10, 30]
+
         # Set the limit on the returned events in the timeline in the get
         # and sync operations. The default value is 100. -1 means no upper limit.
         #
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 481f3f6438..40225abf81 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -670,13 +670,15 @@ class FederationClient(FederationBase):
 
             logger.debug("Got content: %s", content)
 
+            # logger.info("send_join content: %d", len(content))
+
             content.seek(0)
             state = [
                 event_from_pdu_json(p, room_version, outlier=True)
                 for p in ijson.items(content, "state.item")
             ]
 
-            logger.debug("state: %s", state)
+            logger.info("Parsed auth chain: %d", len(state))
             content.seek(0)
 
             auth_chain = [
@@ -684,7 +686,7 @@ class FederationClient(FederationBase):
                 for p in ijson.items(content, "auth_chain.item")
             ]
 
-            logger.debug("auth_chain: %s", auth_chain)
+            logger.info("Parsed auth chain: %d", len(auth_chain))
 
             pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
 
@@ -718,6 +720,8 @@ class FederationClient(FederationBase):
                 room_version=room_version,
             )
 
+            logger.info("_check_sigs_and_hash_and_fetch done")
+
             valid_pdus_map = {p.event_id: p for p in valid_pdus}
 
             # NB: We *need* to copy to ensure that we don't have multiple
@@ -751,6 +755,8 @@ class FederationClient(FederationBase):
                     % (auth_chain_create_events,)
                 )
 
+            logger.info("Returning from send_join")
+
             return {
                 "state": signed_state,
                 "auth_chain": signed_auth,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 9d867aaf4d..69055a14b3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1452,7 +1452,7 @@ class FederationHandler(BaseHandler):
         # room stuff after join currently doesn't work on workers.
         assert self.config.worker.worker_app is None
 
-        logger.debug("Joining %s to %s", joinee, room_id)
+        logger.info("Joining %s to %s", joinee, room_id)
 
         origin, event, room_version_obj = await self._make_and_verify_event(
             target_hosts,
@@ -1463,6 +1463,8 @@ class FederationHandler(BaseHandler):
             params={"ver": KNOWN_ROOM_VERSIONS},
         )
 
+        logger.info("make_join done from %s", origin)
+
         # This shouldn't happen, because the RoomMemberHandler has a
         # linearizer lock which only allows one operation per user per room
         # at a time - so this is just paranoia.
@@ -1482,10 +1484,13 @@ class FederationHandler(BaseHandler):
             except ValueError:
                 pass
 
+            logger.info("Sending join")
             ret = await self.federation_client.send_join(
                 host_list, event, room_version_obj
             )
 
+            logger.info("send join done")
+
             origin = ret["origin"]
             state = ret["state"]
             auth_chain = ret["auth_chain"]
@@ -1510,10 +1515,14 @@ class FederationHandler(BaseHandler):
                 room_version=room_version_obj,
             )
 
+            logger.info("Persisting auth true")
+
             max_stream_id = await self._persist_auth_tree(
                 origin, room_id, auth_chain, state, event, room_version_obj
             )
 
+            logger.info("Persisted auth true")
+
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
             await self._replication.wait_for_stream_position(
@@ -2166,6 +2175,8 @@ class FederationHandler(BaseHandler):
             ctx = await self.state_handler.compute_event_context(e)
             events_to_context[e.event_id] = ctx
 
+        logger.info("Computed contexts")
+
         event_map = {
             e.event_id: e for e in itertools.chain(auth_events, state, [event])
         }
@@ -2207,6 +2218,8 @@ class FederationHandler(BaseHandler):
             else:
                 logger.info("Failed to find auth event %r", e_id)
 
+        logger.info("Got missing events")
+
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
@@ -2231,6 +2244,8 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
+        logger.info("Authed events")
+
         await self.persist_events_and_notify(
             room_id,
             [
@@ -2239,10 +2254,14 @@ class FederationHandler(BaseHandler):
             ],
         )
 
+        logger.info("Persisted events")
+
         new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
+        logger.info("Computed context")
+
         return await self.persist_events_and_notify(
             room_id, [(event, new_event_context)]
         )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index ebbc234334..e9e0f1338f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1196,8 +1196,11 @@ class PresenceHandler(BasePresenceHandler):
         """Process current state deltas to find new joins that need to be
         handled.
         """
-        # A map of destination to a set of user state that they should receive
-        presence_destinations = {}  # type: Dict[str, Set[UserPresenceState]]
+
+        # Sets of newly joined users. Note that if the local server is
+        # joining a remote room for the first time we'll see both the joining
+        # user and all remote users as newly joined.
+        newly_joined_users = set()
 
         for delta in deltas:
             typ = delta["type"]
@@ -1231,72 +1234,55 @@ class PresenceHandler(BasePresenceHandler):
                     # Ignore changes to join events.
                     continue
 
-            # Retrieve any user presence state updates that need to be sent as a result,
-            # and the destinations that need to receive it
-            destinations, user_presence_states = await self._on_user_joined_room(
-                room_id, state_key
-            )
-
-            # Insert the destinations and respective updates into our destinations dict
-            for destination in destinations:
-                presence_destinations.setdefault(destination, set()).update(
-                    user_presence_states
-                )
-
-        # Send out user presence updates for each destination
-        for destination, user_state_set in presence_destinations.items():
-            self._federation_queue.send_presence_to_destinations(
-                destinations=[destination], states=user_state_set
-            )
-
-    async def _on_user_joined_room(
-        self, room_id: str, user_id: str
-    ) -> Tuple[List[str], List[UserPresenceState]]:
-        """Called when we detect a user joining the room via the current state
-        delta stream. Returns the destinations that need to be updated and the
-        presence updates to send to them.
-
-        Args:
-            room_id: The ID of the room that the user has joined.
-            user_id: The ID of the user that has joined the room.
-
-        Returns:
-            A tuple of destinations and presence updates to send to them.
-        """
-        if self.is_mine_id(user_id):
-            # If this is a local user then we need to send their presence
-            # out to hosts in the room (who don't already have it)
-
-            # TODO: We should be able to filter the hosts down to those that
-            # haven't previously seen the user
-
-            remote_hosts = await self.state.get_current_hosts_in_room(room_id)
-
-            # Filter out ourselves.
-            filtered_remote_hosts = [
-                host for host in remote_hosts if host != self.server_name
-            ]
-
-            state = await self.current_state_for_user(user_id)
-            return filtered_remote_hosts, [state]
-        else:
-            # A remote user has joined the room, so we need to:
-            #   1. Check if this is a new server in the room
-            #   2. If so send any presence they don't already have for
-            #      local users in the room.
-
-            # TODO: We should be able to filter the users down to those that
-            # the server hasn't previously seen
+            newly_joined_users.add(state_key)
 
-            # TODO: Check that this is actually a new server joining the
-            # room.
-
-            remote_host = get_domain_from_id(user_id)
+        if not newly_joined_users:
+            # If nobody has joined then there's nothing to do.
+            return
 
-            users = await self.state.get_current_users_in_room(room_id)
-            user_ids = list(filter(self.is_mine_id, users))
+        # We want to send:
+        #   1. presence states of all local users in the room to newly joined
+        #      remote servers
+        #   2. presence states of newly joined users to all remote servers in
+        #      the room.
+        #
+        # TODO: Only send presence states to remote hosts that don't already
+        # have them (because they already share rooms).
+
+        # Get all the users who were already in the room, by fetching the
+        # current users in the room and removing the newly joined users.
+        users = await self.store.get_users_in_room(room_id)
+        prev_users = set(users) - newly_joined_users
+
+        # Construct sets for all the local users and remote hosts that were
+        # already in the room
+        prev_local_users = set()
+        prev_remote_hosts = set()
+        for user_id in prev_users:
+            if self.is_mine_id(user_id):
+                prev_local_users.add(user_id)
+            else:
+                prev_remote_hosts.add(get_domain_from_id(user_id))
+
+        # Similarly, construct sets for all the local users and remote hosts
+        # that were *not* already in the room. Care needs to be taken with the
+        # calculating the remote hosts, as a host may have already been in the
+        # room even if there is a newly joined user from that host.
+        newly_joined_local_users = set()
+        newly_joined_remote_hosts = set()
+        for user_id in newly_joined_users:
+            if self.is_mine_id(user_id):
+                newly_joined_local_users.add(user_id)
+            else:
+                host = get_domain_from_id(user_id)
+                if host not in prev_remote_hosts:
+                    newly_joined_remote_hosts.add(host)
 
-            states_d = await self.current_state_for_users(user_ids)
+        # Send presence states of all local users in the room to newly joined
+        # remote servers. (We actually only send states for local users already
+        # in the room, as we'll send states for newly joined local users below.)
+        if prev_local_users and newly_joined_remote_hosts:
+            local_states = await self.current_state_for_users(prev_local_users)
 
             # Filter out old presence, i.e. offline presence states where
             # the user hasn't been active for a week. We can change this
@@ -1306,13 +1292,27 @@ class PresenceHandler(BasePresenceHandler):
             now = self.clock.time_msec()
             states = [
                 state
-                for state in states_d.values()
+                for state in local_states.values()
                 if state.state != PresenceState.OFFLINE
                 or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
                 or state.status_msg is not None
             ]
 
-            return [remote_host], states
+            self._federation_queue.send_presence_to_destinations(
+                destinations=newly_joined_remote_hosts,
+                states=states,
+            )
+
+        # Send presence states of newly joined users to all remote servers in
+        # the room
+        if newly_joined_local_users and (
+            prev_remote_hosts or newly_joined_remote_hosts
+        ):
+            local_states = await self.current_state_for_users(newly_joined_local_users)
+            self._federation_queue.send_presence_to_destinations(
+                destinations=prev_remote_hosts | newly_joined_remote_hosts,
+                states=list(local_states.values()),
+            )
 
 
 def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index d44fa0cb9b..6db1aece35 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -226,12 +226,13 @@ async def _handle_json_response(
     time_taken_secs = reactor.seconds() - start_ms / 1000
 
     logger.info(
-        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs got %dB - %s %s",
         request.txn_id,
         request.destination,
         response.code,
         response.phrase.decode("ascii", errors="replace"),
         time_taken_secs,
+        len(buf.getvalue()),
         request.method,
         request.uri.decode("ascii"),
     )
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 31b7b3c256..c841363b1e 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import ctypes
+import ctypes.util
 import functools
 import gc
 import itertools
 import logging
 import os
 import platform
+import re
 import threading
 import time
 from typing import Callable, Dict, Iterable, Optional, Tuple, Union
@@ -535,6 +538,13 @@ class ReactorLastSeenMetric:
 
 REGISTRY.register(ReactorLastSeenMetric())
 
+# The minimum time in seconds between GCs for each generation, regardless of the current GC
+# thresholds and counts.
+MIN_TIME_BETWEEN_GCS = [1, 10, 30]
+
+# The time in seconds of the last time we did a GC for each generation.
+_last_gc = [0, 0, 0]
+
 
 def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
@@ -575,11 +585,16 @@ def runUntilCurrentTimer(reactor, func):
             return ret
 
         # Check if we need to do a manual GC (since its been disabled), and do
-        # one if necessary.
+        # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may
+        # promote an object into gen 2, and we don't want to handle the same
+        # object multiple times.
         threshold = gc.get_threshold()
         counts = gc.get_count()
         for i in (2, 1, 0):
-            if threshold[i] < counts[i]:
+            # We check if we need to do one based on a straightforward
+            # comparison between the threshold and count. We also do an extra
+            # check to make sure that we don't a GC too often.
+            if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]:
                 if i == 0:
                     logger.debug("Collecting gc %d", i)
                 else:
@@ -589,6 +604,8 @@ def runUntilCurrentTimer(reactor, func):
                 unreachable = gc.collect(i)
                 end = time.time()
 
+                _last_gc[i] = int(end)
+
                 gc_time.labels(i).observe(end - start)
                 gc_unreachable.labels(i).set(unreachable)
 
@@ -597,6 +614,163 @@ def runUntilCurrentTimer(reactor, func):
     return f
 
 
+def _setup_jemalloc_stats():
+    """Checks to see if jemalloc is loaded, and hooks up a collector to record
+    statistics exposed by jemalloc.
+    """
+
+    # Try to find the loaded jemalloc shared library, if any. We need to
+    # introspect into what is loaded, rather than loading whatever is on the
+    # path, as if we load a *different* jemalloc version things will seg fault.
+    pid = os.getpid()
+
+    # We're looking for a path at the end of the line that includes
+    # "libjemalloc".
+    regex = re.compile(r"/\S+/libjemalloc.*$")
+
+    jemalloc_path = None
+    with open(f"/proc/{pid}/maps") as f:
+        for line in f.readlines():
+            match = regex.search(line.strip())
+            if match:
+                jemalloc_path = match.group()
+
+    if not jemalloc_path:
+        # No loaded jemalloc was found.
+        return
+
+    jemalloc = ctypes.CDLL(jemalloc_path)
+
+    def _mallctl(
+        name: str, read: bool = True, write: Optional[int] = None
+    ) -> Optional[int]:
+        """Wrapper around `mallctl` for reading and writing integers to
+        jemalloc.
+
+        Args:
+            name: The name of the option to read from/write to.
+            read: Whether to try and read the value.
+            write: The value to write, if given.
+
+        Returns:
+            The value read if `read` is True, otherwise None.
+
+        Raises:
+            An exception if `mallctl` returns a non-zero error code.
+        """
+
+        input_var = None
+        input_var_ref = None
+        input_len_ref = None
+        if read:
+            input_var = ctypes.c_size_t(0)
+            input_len = ctypes.c_size_t(ctypes.sizeof(input_var))
+
+            input_var_ref = ctypes.byref(input_var)
+            input_len_ref = ctypes.byref(input_len)
+
+        write_var_ref = None
+        write_len = ctypes.c_size_t(0)
+        if write is not None:
+            write_var = ctypes.c_size_t(write)
+            write_len = ctypes.c_size_t(ctypes.sizeof(write_var))
+
+            write_var_ref = ctypes.byref(write_var)
+
+        # The interface is:
+        #
+        #   int mallctl(
+        #       const char *name,
+        #       void *oldp,
+        #       size_t *oldlenp,
+        #       void *newp,
+        #       size_t newlen
+        #   )
+        #
+        # Where oldp/oldlenp is a buffer where the old value will be written to
+        # (if not null), and newp/newlen is the buffer with the new value to set
+        # (if not null). Note that they're all references *except* newlen.
+        result = jemalloc.mallctl(
+            name.encode("ascii"),
+            input_var_ref,
+            input_len_ref,
+            write_var_ref,
+            write_len,
+        )
+
+        if result != 0:
+            raise Exception("Failed to call mallctl")
+
+        if input_var is None:
+            return None
+
+        return input_var.value
+
+    def _jemalloc_refresh_stats() -> None:
+        """Request that jemalloc updates its internal statistics. This needs to
+        be called before querying for stats, otherwise it will return stale
+        values.
+        """
+        try:
+            _mallctl("epoch", read=False, write=1)
+        except Exception:
+            pass
+
+    class JemallocCollector:
+        """Metrics for internal jemalloc stats."""
+
+        def collect(self):
+            _jemalloc_refresh_stats()
+
+            g = GaugeMetricFamily(
+                "jemalloc_stats_app_memory",
+                "The stats reported by jemalloc",
+                labels=["type"],
+            )
+
+            # Read the relevant global stats from jemalloc. Note that these may
+            # not be accurate if python is configured to use its internal small
+            # object allocator (which is on by default, disable by setting the
+            # env `PYTHONMALLOC=malloc`).
+            #
+            # See the jemalloc manpage for details about what each value means,
+            # roughly:
+            #   - allocated ─ Total number of bytes allocated by the app
+            #   - active ─ Total number of bytes in active pages allocated by
+            #     the application, this is bigger than `allocated`.
+            #   - resident ─ Maximum number of bytes in physically resident data
+            #     pages mapped by the allocator, comprising all pages dedicated
+            #     to allocator metadata, pages backing active allocations, and
+            #     unused dirty pages. This is bigger than `active`.
+            #   - mapped ─ Total number of bytes in active extents mapped by the
+            #     allocator.
+            #   - metadata ─ Total number of bytes dedicated to jemalloc
+            #     metadata.
+            for t in (
+                "allocated",
+                "active",
+                "resident",
+                "mapped",
+                "metadata",
+            ):
+                try:
+                    value = _mallctl(f"stats.{t}")
+                except Exception:
+                    # There was an error fetching the value, skip.
+                    continue
+
+                g.add_metric([t], value=value)
+
+            yield g
+
+    REGISTRY.register(JemallocCollector())
+
+
+try:
+    _setup_jemalloc_stats()
+except Exception:
+    logger.info("Failed to setup collector to record jemalloc stats.")
+
 try:
     # Ensure the reactor has all the attributes we expect
     reactor.seconds  # type: ignore
@@ -615,6 +789,7 @@ try:
 except AttributeError:
     pass
 
+
 __all__ = [
     "MetricsResource",
     "generate_latest",
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index ce330e79cc..1ffab709fc 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -729,7 +729,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server2"], states={expected_state}
+            destinations={"server2"}, states=[expected_state]
         )
 
         #
@@ -740,7 +740,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         self._add_new_user(room_id, "@bob:server3")
 
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server3"], states={expected_state}
+            destinations={"server3"}, states=[expected_state]
         )
 
     def test_remote_gets_presence_when_local_user_joins(self):
@@ -788,14 +788,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             self.presence_handler.current_state_for_user("@test2:server")
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
-        self.assertEqual(
-            self.federation_sender.send_presence_to_destinations.call_count, 2
-        )
-        self.federation_sender.send_presence_to_destinations.assert_any_call(
-            destinations=["server3"], states={expected_state}
-        )
-        self.federation_sender.send_presence_to_destinations.assert_any_call(
-            destinations=["server2"], states={expected_state}
+        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+            destinations={"server2", "server3"}, states=[expected_state]
         )
 
     def _add_new_user(self, room_id, user_id):