summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14442.feature1
-rw-r--r--synapse/federation/federation_server.py23
-rw-r--r--synapse/handlers/sync.py20
-rw-r--r--synapse/storage/databases/main/roommember.py30
-rw-r--r--tests/federation/test_federation_server.py11
5 files changed, 60 insertions, 25 deletions
diff --git a/changelog.d/14442.feature b/changelog.d/14442.feature
new file mode 100644
index 0000000000..917e7edfb3
--- /dev/null
+++ b/changelog.d/14442.feature
@@ -0,0 +1 @@
+Faster joins: include heroes' membership events in the partial join response, for rooms without a name or canonical alias.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 59e351595b..bb20af6e91 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -74,6 +74,8 @@ from synapse.replication.http.federation import (
 )
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.lock import Lock
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
+from synapse.storage.roommember import MemberSummary
 from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
@@ -691,8 +693,9 @@ class FederationServer(FederationBase):
         state_event_ids: Collection[str]
         servers_in_room: Optional[Collection[str]]
         if caller_supports_partial_state:
+            summary = await self.store.get_room_summary(room_id)
             state_event_ids = _get_event_ids_for_partial_state_join(
-                event, prev_state_ids
+                event, prev_state_ids, summary
             )
             servers_in_room = await self.state.get_hosts_in_room_at_events(
                 room_id, event_ids=event.prev_event_ids()
@@ -1495,6 +1498,7 @@ class FederationHandlerRegistry:
 def _get_event_ids_for_partial_state_join(
     join_event: EventBase,
     prev_state_ids: StateMap[str],
+    summary: Dict[str, MemberSummary],
 ) -> Collection[str]:
     """Calculate state to be retuned in a partial_state send_join
 
@@ -1521,8 +1525,19 @@ def _get_event_ids_for_partial_state_join(
     if current_membership_event_id is not None:
         state_event_ids.add(current_membership_event_id)
 
-    # TODO: return a few more members:
-    #   - those with invites
-    #   - those that are kicked? / banned
+    name_id = prev_state_ids.get((EventTypes.Name, ""))
+    canonical_alias_id = prev_state_ids.get((EventTypes.CanonicalAlias, ""))
+    if not name_id and not canonical_alias_id:
+        # Also include the hero members of the room (for DM rooms without a title).
+        # To do this properly, we should select the correct subset of membership events
+        # from `prev_state_ids`. Instead, we are lazier and use the (cached)
+        # `get_room_summary` function, which is based on the current state of the room.
+        # This introduces races; we choose to ignore them because a) they should be rare
+        # and b) even if it's wrong, joining servers will get the full state eventually.
+        heroes = extract_heroes_from_room_summary(summary, join_event.state_key)
+        for hero in heroes:
+            membership_event_id = prev_state_ids.get((EventTypes.Member, hero))
+            if membership_event_id:
+                state_event_ids.add(membership_event_id)
 
     return state_event_ids
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1db5d68021..259456b55d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -41,6 +41,7 @@ from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -805,18 +806,6 @@ class SyncHandler:
             if canonical_alias and canonical_alias.content.get("alias"):
                 return summary
 
-        me = sync_config.user.to_string()
-
-        joined_user_ids = [
-            r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
-        ]
-        invited_user_ids = [
-            r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
-        ]
-        gone_user_ids = [
-            r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
-        ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
-
         # FIXME: only build up a member_ids list for our heroes
         member_ids = {}
         for membership in (
@@ -828,11 +817,8 @@ class SyncHandler:
             for user_id, event_id in details.get(membership, empty_ms).members:
                 member_ids[user_id] = event_id
 
-        # FIXME: order by stream ordering rather than as returned by SQL
-        if joined_user_ids or invited_user_ids:
-            summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
-        else:
-            summary["m.heroes"] = sorted(gone_user_ids)[0:5]
+        me = sync_config.user.to_string()
+        summary["m.heroes"] = extract_heroes_from_room_summary(details, me)
 
         if not sync_config.filter_collection.lazy_load_members():
             return summary
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e56a13f21e..f02c1d7ea7 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1517,6 +1517,36 @@ class RoomMemberStore(
         await self.db_pool.runInteraction("forget_membership", f)
 
 
+def extract_heroes_from_room_summary(
+    details: Mapping[str, MemberSummary], me: str
+) -> List[str]:
+    """Determine the users that represent a room, from the perspective of the `me` user.
+
+    The rules which say which users we select are specified in the "Room Summary"
+    section of
+    https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync
+
+    Returns a list (possibly empty) of heroes' mxids.
+    """
+    empty_ms = MemberSummary([], 0)
+
+    joined_user_ids = [
+        r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
+    ]
+    invited_user_ids = [
+        r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
+    ]
+    gone_user_ids = [
+        r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
+    ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
+
+    # FIXME: order by stream ordering rather than as returned by SQL
+    if joined_user_ids or invited_user_ids:
+        return sorted(joined_user_ids + invited_user_ids)[0:5]
+    else:
+        return sorted(gone_user_ids)[0:5]
+
+
 @attr.s(slots=True, auto_attribs=True)
 class _JoinedHostsCache:
     """The cached data used by the `_get_joined_hosts_cache`."""
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 3a6ef221ae..177e5b5afc 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -212,7 +212,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
     @override_config({"experimental_features": {"msc3706_enabled": True}})
-    def test_send_join_partial_state(self):
+    def test_send_join_partial_state(self) -> None:
         """When MSC3706 support is enabled, /send_join should return partial state"""
         joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
         join_result = self._make_join(joining_user)
@@ -240,6 +240,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
                 ("m.room.power_levels", ""),
                 ("m.room.join_rules", ""),
                 ("m.room.history_visibility", ""),
+                # Users included here because they're heroes.
+                ("m.room.member", "@kermit:test"),
+                ("m.room.member", "@fozzie:test"),
             ],
         )
 
@@ -249,9 +252,9 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         ]
         self.assertCountEqual(
             returned_auth_chain_events,
-            [
-                ("m.room.member", "@kermit:test"),
-            ],
+            # TODO: change the test so that we get at least one event in the auth chain
+            #   here.
+            [],
         )
 
         # the room should show that the new user is a member