diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 882f3bbbdc..418b556108 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,20 +19,28 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import logging
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import Membership
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.rest import admin
from synapse.rest.admin import register_servlets_for_client_rest_resource
-from synapse.rest.client import login, room
+from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
+from synapse.storage.roommember import MemberSummary
from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
from tests.test_utils import event_injection
+from tests.unittest import skip_unless
+
+logger = logging.getLogger(__name__)
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -240,6 +248,397 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
+class RoomSummaryTestCase(unittest.HomeserverTestCase):
+ """
+ Test `/sync` room summary related logic like `get_room_summary(...)` and
+ `extract_heroes_from_room_summary(...)`
+ """
+
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
+ self.store = self.hs.get_datastores().main
+
+ def _assert_member_summary(
+ self,
+ actual_member_summary: MemberSummary,
+ expected_member_list: List[str],
+ *,
+ expected_member_count: Optional[int] = None,
+ ) -> None:
+ """
+ Assert that the `MemberSummary` object has the expected members.
+ """
+ self.assertListEqual(
+ [
+ user_id
+ for user_id, _membership_event_id in actual_member_summary.members
+ ],
+ expected_member_list,
+ )
+ self.assertEqual(
+ actual_member_summary.count,
+ (
+ expected_member_count
+ if expected_member_count is not None
+ else len(expected_member_list)
+ ),
+ )
+
+ def test_get_room_summary_membership(self) -> None:
+ """
+ Test that `get_room_summary(...)` gets every kind of membership when there
+ aren't that many members in the room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+
+ # Setup a room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 is invited by user1
+ self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
+
+ # User4 leaves
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # User5 joins
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+ empty_ms = MemberSummary([], 0)
+
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.JOIN, empty_ms),
+ [user1_id, user5_id],
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.BAN, empty_ms), [user2_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.KNOCK, empty_ms),
+ [
+ # No one knocked
+ ],
+ )
+
+ def test_get_room_summary_membership_order(self) -> None:
+ """
+ Test that `get_room_summary(...)` stacks our limit of 6 in this order: joins ->
+ invites -> leave -> everything else (bans/knocks)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+ user6_id = self.register_user("user6", "pass")
+ user6_tok = self.login(user6_id, "pass")
+ user7_id = self.register_user("user7", "pass")
+ user7_tok = self.login(user7_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # We expect the order to be joins -> invites -> leave -> bans so setup the users
+ # *NOT* in that same order to make sure we're actually sorting them.
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 is invited by user1
+ self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
+
+ # User4 leaves
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # User5, User6, User7 joins
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+ self.helper.join(room_id, user6_id, tok=user6_tok)
+ self.helper.join(room_id, user7_id, tok=user7_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+ empty_ms = MemberSummary([], 0)
+
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.JOIN, empty_ms),
+ [user1_id, user5_id, user6_id, user7_id],
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.BAN, empty_ms),
+ [
+ # The banned user is not in the summary because the summary can only fit
+ # 6 members and prefers everything else before bans
+ #
+ # user2_id
+ ],
+ # But we still see the count of banned users
+ expected_member_count=1,
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.KNOCK, empty_ms),
+ [
+ # No one knocked
+ ],
+ )
+
+ def test_extract_heroes_from_room_summary_excludes_self(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` does not include the user
+ itself.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 joins
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ # We first ask from the perspective of a random fake user
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # Make sure user1 is in the room (ensure our test setup is correct)
+ self.assertListEqual(hero_user_ids, [user1_id, user2_id])
+
+ # Now, we ask for the room summary from the perspective of user1
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me=user1_id
+ )
+
+ # User1 should not be included in the list of heroes because they are the one
+ # asking
+ self.assertListEqual(hero_user_ids, [user2_id])
+
+ def test_extract_heroes_from_room_summary_first_five_joins(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` returns the first 5 joins.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+ user6_id = self.register_user("user6", "pass")
+ user6_tok = self.login(user6_id, "pass")
+ user7_id = self.register_user("user7", "pass")
+ user7_tok = self.login(user7_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 -> User7 joins
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+ self.helper.join(room_id, user6_id, tok=user6_tok)
+ self.helper.join(room_id, user7_id, tok=user7_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakuser"
+ )
+
+ # First 5 users to join the room
+ self.assertListEqual(
+ hero_user_ids, [user1_id, user2_id, user3_id, user4_id, user5_id]
+ )
+
+ def test_extract_heroes_from_room_summary_membership_order(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` prefers joins/invites over
+ everything else.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # We expect the order to be joins -> invites -> leave -> bans so setup the users
+ # *NOT* in that same order to make sure we're actually sorting them.
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 is invited by user1
+ self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
+
+ # User4 leaves
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # User5 joins
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # Prefer joins -> invites, over everything else
+ self.assertListEqual(
+ hero_user_ids,
+ [
+ # The joins
+ user1_id,
+ user5_id,
+ # The invites
+ user3_id,
+ ],
+ )
+
+ @skip_unless(
+ False,
+ "Test is not possible because when everyone leaves the room, "
+ + "the server is `no_longer_in_room` and we don't have any `current_state_events` to query",
+ )
+ def test_extract_heroes_from_room_summary_fallback_leave_ban(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` falls back to leave/ban if
+ there aren't any joins/invites.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 leaves
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+ self.helper.leave(room_id, user3_id, tok=user3_tok)
+
+ # User1 leaves (we're doing this last because they're the room creator)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # Fallback to people who left -> banned
+ self.assertListEqual(
+ hero_user_ids,
+ [user3_id, user1_id, user3_id],
+ )
+
+ def test_extract_heroes_from_room_summary_excludes_knocks(self) -> None:
+ """
+ People who knock on the room have (potentially) never been in the room before
+ and are total outsiders. Plus the spec doesn't mention them at all for heroes.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Setup the knock room (user1 is the creator and is joined to the room)
+ knock_room_id = self.helper.create_room_as(
+ user1_id, tok=user1_tok, room_version=RoomVersions.V7.identifier
+ )
+ self.helper.send_state(
+ knock_room_id,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=user1_tok,
+ )
+
+ # User2 knocks on the room
+ knock_channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/knock/%s" % (knock_room_id,),
+ b"{}",
+ user2_tok,
+ )
+ self.assertEqual(knock_channel.code, 200, knock_channel.result)
+
+ room_membership_summary = self.get_success(
+ self.store.get_room_summary(knock_room_id)
+ )
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # user1 is the creator and is joined to the room (should show up as a hero)
+ # user2 is knocking on the room (should not show up as a hero)
+ self.assertListEqual(
+ hero_user_ids,
+ [user1_id],
+ )
+
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
|