diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index eb4b0a05c7..a7aa9bb8af 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -19,7 +19,7 @@
#
import logging
from copy import deepcopy
-from typing import Optional
+from typing import Dict, Optional
from unittest.mock import patch
from parameterized import parameterized
@@ -37,12 +37,16 @@ from synapse.api.constants import (
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.handlers.sliding_sync import RoomSyncConfig, StateValues
+from synapse.handlers.sliding_sync import (
+ RoomSyncConfig,
+ StateValues,
+ _RoomMembershipForUser,
+)
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamToken, UserID
from synapse.types.handlers import SlidingSyncConfig
from synapse.util import Clock
@@ -581,9 +585,9 @@ class RoomSyncConfigTestCase(TestCase):
self._assert_room_config_equal(room_sync_config_b, expected, "A into B")
-class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
+class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
"""
- Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it returns
+ Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns
the correct list of rooms IDs.
"""
@@ -616,7 +620,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
now_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=now_token,
to_token=now_token,
@@ -643,7 +647,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room_token,
to_token=after_room_token,
@@ -657,9 +661,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id_results[room_id].event_id,
join_response["event_id"],
)
+ self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
# We should be considered `newly_joined` because we joined during the token
# range
self.assertEqual(room_id_results[room_id].newly_joined, True)
+ self.assertEqual(room_id_results[room_id].newly_left, False)
def test_get_already_joined_room(self) -> None:
"""
@@ -676,7 +682,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room_token,
to_token=after_room_token,
@@ -690,8 +696,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id_results[room_id].event_id,
join_response["event_id"],
)
+ self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
# We should *NOT* be `newly_joined` because we joined before the token range
self.assertEqual(room_id_results[room_id].newly_joined, False)
+ self.assertEqual(room_id_results[room_id].newly_left, False)
def test_get_invited_banned_knocked_room(self) -> None:
"""
@@ -748,7 +756,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room_token,
to_token=after_room_token,
@@ -770,19 +778,25 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id_results[invited_room_id].event_id,
invite_response["event_id"],
)
+ self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
+ self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[invited_room_id].newly_left, False)
+
self.assertEqual(
room_id_results[ban_room_id].event_id,
ban_response["event_id"],
)
+ self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
+ self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[ban_room_id].newly_left, False)
+
self.assertEqual(
room_id_results[knock_room_id].event_id,
knock_room_membership_state_event.event_id,
)
- # We should *NOT* be `newly_joined` because we were not joined at the the time
- # of the `to_token`.
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[knock_room_id].newly_left, False)
def test_get_kicked_room(self) -> None:
"""
@@ -814,7 +828,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_kick_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_kick_token,
to_token=after_kick_token,
@@ -828,9 +842,12 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id_results[kick_room_id].event_id,
kick_response["event_id"],
)
+ self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
+ self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
# We should *NOT* be `newly_joined` because we were not joined at the the time
# of the `to_token`.
self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[kick_room_id].newly_left, False)
def test_forgotten_rooms(self) -> None:
"""
@@ -904,7 +921,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room_forgets,
to_token=before_room_forgets,
@@ -914,52 +931,58 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We shouldn't see the room because it was forgotten
self.assertEqual(room_id_results.keys(), set())
- def test_only_newly_left_rooms_show_up(self) -> None:
+ def test_newly_left_rooms(self) -> None:
"""
- Test that newly_left rooms still show up in the sync response but rooms that
- were left before the `from_token` don't show up. See condition "2)" comments in
- the `get_sync_room_ids_for_user` method.
+ Test that newly_left are marked properly
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Leave before we calculate the `from_token`
room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Leave during the from_token/to_token range (newly_left)
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
after_room2_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room2_token,
)
)
- # Only the newly_left room should show up
- self.assertEqual(room_id_results.keys(), {room_id2})
- # It should be pointing to the latest membership event in the from/to range but
- # the `event_id` is `None` because we left the room causing the server to leave
- # the room because no other local users are in it (quirk of the
- # `current_state_delta_stream` table that we source things from)
+ self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ leave_response1["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined` or `newly_left` because that happened before
+ # the from/to range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
+
self.assertEqual(
room_id_results[room_id2].event_id,
- None, # _leave_response2["event_id"],
+ leave_response2["event_id"],
)
+ self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
# We should *NOT* be `newly_joined` because we are instead `newly_left`
self.assertEqual(room_id_results[room_id2].newly_joined, False)
+ self.assertEqual(room_id_results[room_id2].newly_left, True)
def test_no_joins_after_to_token(self) -> None:
"""
Rooms we join after the `to_token` should *not* show up. See condition "1b)"
- comments in the `get_sync_room_ids_for_user()` method.
+ comments in the `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -978,7 +1001,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.join(room_id2, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -991,14 +1014,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id_results[room_id1].event_id,
join_response1["event_id"],
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be `newly_joined` because we joined during the token range
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_join_during_range_and_left_room_after_to_token(self) -> None:
"""
Room still shows up if we left the room but were joined during the
from_token/to_token. See condition "1a)" comments in the
- `get_sync_room_ids_for_user()` method.
+ `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1016,7 +1041,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1038,14 +1063,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be `newly_joined` because we joined during the token range
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_join_before_range_and_left_room_after_to_token(self) -> None:
"""
Room still shows up if we left the room but were joined before the `from_token`
so it should show up. See condition "1a)" comments in the
- `get_sync_room_ids_for_user()` method.
+ `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1061,7 +1088,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room1_token,
@@ -1082,14 +1109,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should *NOT* be `newly_joined` because we joined before the token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_kicked_before_range_and_left_after_to_token(self) -> None:
"""
Room still shows up if we left the room but were kicked before the `from_token`
so it should show up. See condition "1a)" comments in the
- `get_sync_room_ids_for_user()` method.
+ `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1123,7 +1152,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_kick_token,
to_token=after_kick_token,
@@ -1146,14 +1175,17 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
+ self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
# We should *NOT* be `newly_joined` because we were kicked
self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[kick_room_id].newly_left, False)
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
"""
Newly left room should show up. But we're also testing that joining and leaving
after the `to_token` doesn't mess with the results. See condition "2)" and "1a)"
- comments in the `get_sync_room_ids_for_user()` method.
+ comments in the `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1176,7 +1208,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1199,14 +1231,17 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
- # We should *NOT* be `newly_joined` because we left during the token range
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined` because we are actually `newly_left` during
+ # the token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, True)
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
"""
Newly left room should show up. But we're also testing that joining after the
`to_token` doesn't mess with the results. See condition "2)" and "1b)" comments
- in the `get_sync_room_ids_for_user()` method.
+ in the `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1228,7 +1263,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1250,16 +1285,19 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
- # We should *NOT* be `newly_joined` because we left during the token range
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined` because we are actually `newly_left` during
+ # the token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, True)
def test_no_from_token(self) -> None:
"""
- Test that if we don't provide a `from_token`, we get all the rooms that we we're
- joined up to the `to_token`.
+ Test that if we don't provide a `from_token`, we get all the rooms that we had
+ membership in up to the `to_token`.
- Providing `from_token` only really has the effect that it adds `newly_left`
- rooms to the response.
+ Providing `from_token` only really has the effect that it marks rooms as
+ `newly_left` in the response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1276,7 +1314,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Join and leave the room2 before the `to_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
- self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -1284,7 +1322,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.join(room_id2, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=None,
to_token=after_room1_token,
@@ -1292,15 +1330,31 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
# Only rooms we were joined to before the `to_token` should show up
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+
+ # Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
room_id_results[room_id1].event_id,
join_response1["event_id"],
)
- # We should *NOT* be `newly_joined` because there is no `from_token` to
- # define a "live" range to compare against
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ # We should *NOT* be `newly_joined`/`newly_left` because there is no
+ # `from_token` to define a "live" range to compare against
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
+
+ # Room2
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id2].event_id,
+ leave_response2["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined`/`newly_left` because there is no
+ # `from_token` to define a "live" range to compare against
+ self.assertEqual(room_id_results[room_id2].newly_joined, False)
+ self.assertEqual(room_id_results[room_id2].newly_left, False)
def test_from_token_ahead_of_to_token(self) -> None:
"""
@@ -1319,28 +1373,28 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join room1 before `before_room_token`
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # Join room1 before `to_token`
+ join_room1_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- # Join and leave the room2 before `before_room_token`
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ # Join and leave the room2 before `to_token`
+ _join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok)
+ leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
# Note: These are purposely swapped. The `from_token` should come after
# the `to_token` in this test
to_token = self.event_sources.get_current_token()
- # Join room2 after `before_room_token`
- self.helper.join(room_id2, user1_id, tok=user1_tok)
+ # Join room2 after `to_token`
+ _join_room2_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
# --------
- # Join room3 after `before_room_token`
- self.helper.join(room_id3, user1_id, tok=user1_tok)
+ # Join room3 after `to_token`
+ _join_room3_response1 = self.helper.join(room_id3, user1_id, tok=user1_tok)
- # Join and leave the room4 after `before_room_token`
- self.helper.join(room_id4, user1_id, tok=user1_tok)
- self.helper.leave(room_id4, user1_id, tok=user1_tok)
+ # Join and leave the room4 after `to_token`
+ _join_room4_response1 = self.helper.join(room_id4, user1_id, tok=user1_tok)
+ _leave_room4_response1 = self.helper.leave(room_id4, user1_id, tok=user1_tok)
# Note: These are purposely swapped. The `from_token` should come after the
# `to_token` in this test
@@ -1350,31 +1404,59 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.join(room_id4, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=from_token,
to_token=to_token,
)
)
- # Only rooms we were joined to before the `to_token` should show up
- #
- # There won't be any newly_left rooms because the `from_token` is ahead of the
- # `to_token` and that range will give no membership changes to check.
- self.assertEqual(room_id_results.keys(), {room_id1})
+ # In the "current" state snapshot, we're joined to all of the rooms but in the
+ # from/to token range...
+ self.assertIncludes(
+ room_id_results.keys(),
+ {
+ # Included because we were joined before both tokens
+ room_id1,
+ # Included because we had membership before the to_token
+ room_id2,
+ # Excluded because we joined after the `to_token`
+ # room_id3,
+ # Excluded because we joined after the `to_token`
+ # room_id4,
+ },
+ exact=True,
+ )
+
+ # Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
room_id_results[room_id1].event_id,
- join_response1["event_id"],
+ join_room1_response1["event_id"],
)
- # We should *NOT* be `newly_joined` because we joined `room1` before either of the tokens
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1`
+ # before either of the tokens
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
+
+ # Room2
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id2].event_id,
+ leave_room2_response1["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
+ # `room1` before either of the tokens
+ self.assertEqual(room_id_results[room_id2].newly_joined, False)
+ self.assertEqual(room_id_results[room_id2].newly_left, False)
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
"""
- Old left room shouldn't show up. But we're also testing that joining and leaving
- after the `to_token` doesn't mess with the results. See condition "1a)" comments
- in the `get_sync_room_ids_for_user()` method.
+ Test old left rooms. But we're also testing that joining and leaving after the
+ `to_token` doesn't mess with the results. See condition "1a)" comments in the
+ `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1386,7 +1468,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room before the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -1395,21 +1477,30 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
- # Room shouldn't show up because it was left before the `from_token`
- self.assertEqual(room_id_results.keys(), set())
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ leave_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
+ # `room1` before either of the tokens
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_leave_before_range_and_join_after_to_token(self) -> None:
"""
- Old left room shouldn't show up. But we're also testing that joining after the
- `to_token` doesn't mess with the results. See condition "1b)" comments in the
- `get_sync_room_ids_for_user()` method.
+ Test old left room. But we're also testing that joining after the `to_token`
+ doesn't mess with the results. See condition "1b)" comments in the
+ `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1421,7 +1512,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room before the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -1429,24 +1520,32 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.join(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
- # Room shouldn't show up because it was left before the `from_token`
- self.assertEqual(room_id_results.keys(), set())
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ leave_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
+ # `room1` before either of the tokens
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_join_leave_multiple_times_during_range_and_after_to_token(
self,
) -> None:
"""
Join and leave multiple times shouldn't affect rooms from showing up. It just
- matters that we were joined or newly_left in the from/to range. But we're also
- testing that joining and leaving after the `to_token` doesn't mess with the
- results.
+ matters that we had membership in the from/to range. But we're also testing that
+ joining and leaving after the `to_token` doesn't mess with the results.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1458,7 +1557,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We create the room with user2 so the room isn't left with no members when we
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join, leave, join back to the room before the from/to range
+ # Join, leave, join back to the room during the from/to range
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
@@ -1471,7 +1570,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1496,15 +1595,19 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be `newly_joined` because we joined during the token range
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ # We should *NOT* be `newly_left` because we joined during the token range and
+ # was still joined at the end of the range
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_join_leave_multiple_times_before_range_and_after_to_token(
self,
) -> None:
"""
Join and leave multiple times before the from/to range shouldn't affect rooms
- from showing up. It just matters that we were joined or newly_left in the
+ from showing up. It just matters that we had membership in the
from/to range. But we're also testing that joining and leaving after the
`to_token` doesn't mess with the results.
"""
@@ -1529,7 +1632,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room1_token,
@@ -1554,8 +1657,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should *NOT* be `newly_joined` because we joined before the token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_invite_before_range_and_join_leave_after_to_token(
self,
@@ -1563,7 +1668,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
Make it look like we joined after the token range but we were invited before the
from/to range so the room should still show up. See condition "1a)" comments in
- the `get_sync_room_ids_for_user()` method.
+ the `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1586,7 +1691,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room1_token,
@@ -1608,9 +1713,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE)
# We should *NOT* be `newly_joined` because we were only invited before the
# token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_join_and_display_name_changes_in_token_range(
self,
@@ -1658,7 +1765,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1684,8 +1791,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be `newly_joined` because we joined during the token range
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_display_name_changes_in_token_range(
self,
@@ -1721,7 +1830,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_change1_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_change1_token,
@@ -1744,8 +1853,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should *NOT* be `newly_joined` because we joined before the token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_display_name_changes_before_and_after_token_range(
self,
@@ -1791,7 +1902,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room1_token,
@@ -1817,8 +1928,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should *NOT* be `newly_joined` because we joined before the token range
self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_display_name_changes_leave_after_token_range(
self,
@@ -1828,7 +1941,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
if there are multiple `join` membership events in a row indicating
`displayname`/`avatar_url` updates and we leave after the `to_token`.
- See condition "1a)" comments in the `get_sync_room_ids_for_user()` method.
+ See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1871,7 +1984,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1897,8 +2010,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be `newly_joined` because we joined during the token range
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_display_name_changes_join_after_token_range(
self,
@@ -1908,7 +2023,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
indicating `displayname`/`avatar_url` updates doesn't affect the results (we
joined after the token range so it shouldn't show up)
- See condition "1b)" comments in the `get_sync_room_ids_for_user()` method.
+ See condition "1b)" comments in the `get_room_membership_for_user_at_to_token()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1937,7 +2052,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -1973,7 +2088,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_more_changes_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_more_changes_token,
@@ -1987,9 +2102,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id_results[room_id1].event_id,
join_response2["event_id"],
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be considered `newly_joined` because there is some non-join event in
# between our latest join event.
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_newly_joined_only_joins_during_token_range(
self,
@@ -2036,7 +2153,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
after_room1_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
@@ -2062,8 +2179,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
}
),
)
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
# We should be `newly_joined` because we first joined during the token range
self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
def test_multiple_rooms_are_not_confused(
self,
@@ -2086,16 +2205,18 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Invited and left the room before the token
self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
# Invited to room2
- self.helper.invite(room_id2, src=user2_id, targ=user1_id, tok=user2_tok)
+ invite_room2_response = self.helper.invite(
+ room_id2, src=user2_id, targ=user1_id, tok=user2_tok
+ )
before_room3_token = self.event_sources.get_current_token()
# Invited and left room3 during the from/to range
room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
self.helper.invite(room_id3, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.leave(room_id3, user1_id, tok=user1_tok)
+ leave_room3_response = self.helper.leave(room_id3, user1_id, tok=user1_tok)
after_room3_token = self.event_sources.get_current_token()
@@ -2108,7 +2229,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.helper.leave(room_id3, user1_id, tok=user1_tok)
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_room3_token,
to_token=after_room3_token,
@@ -2118,19 +2239,158 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.assertEqual(
room_id_results.keys(),
{
- # `room_id1` shouldn't show up because we left before the from/to range
- #
- # Room should show up because we were invited before the from/to range
+ # Left before the from/to range
+ room_id1,
+ # Invited before the from/to range
room_id2,
- # Room should show up because it was newly_left during the from/to range
+ # `newly_left` during the from/to range
room_id3,
},
)
+ # Room1
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ leave_room1_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left
+ # before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
+
+ # Room2
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id2].event_id,
+ invite_room2_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE)
+ # We should *NOT* be `newly_joined`/`newly_left` because we were invited before
+ # the token range
+ self.assertEqual(room_id_results[room_id2].newly_joined, False)
+ self.assertEqual(room_id_results[room_id2].newly_left, False)
+
+ # Room3
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id3].event_id,
+ leave_room3_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE)
+ # We should be `newly_left` because we were invited and left during
+ # the token range
+ self.assertEqual(room_id_results[room_id3].newly_joined, False)
+ self.assertEqual(room_id_results[room_id3].newly_left, True)
+
+ def test_state_reset(self) -> None:
+ """
+ Test a state reset scenario where the user gets removed from the room (when
+ there is no corresponding leave event)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # The room where the state reset will happen
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Join another room so we don't hit the short-circuit and return early if they
+ # have no room membership
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
+
+ before_reset_token = self.event_sources.get_current_token()
+
+ # Send another state event to make a position for the state reset to happen at
+ dummy_state_response = self.helper.send_state(
+ room_id1,
+ event_type="foobarbaz",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ dummy_state_pos = self.get_success(
+ self.store.get_position_for_event(dummy_state_response["event_id"])
+ )
+
+ # Mock a state reset removing the membership for user1 in the current state
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="current_state_events",
+ keyvalues={
+ "room_id": room_id1,
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ },
+ desc="state reset user in current_state_events",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="local_current_membership",
+ keyvalues={
+ "room_id": room_id1,
+ "user_id": user1_id,
+ },
+ desc="state reset user in local_current_membership",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="current_state_delta_stream",
+ values={
+ "stream_id": dummy_state_pos.stream,
+ "room_id": room_id1,
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ "event_id": None,
+ "prev_event_id": join_response1["event_id"],
+ "instance_name": dummy_state_pos.instance_name,
+ },
+ desc="state reset user in current_state_delta_stream",
+ )
+ )
+
+ # Manually bust the cache since we we're just manually messing with the database
+ # and not causing an actual state reset.
+ self.store._membership_stream_cache.entity_has_changed(
+ user1_id, dummy_state_pos.stream
+ )
+
+ after_reset_token = self.event_sources.get_current_token()
+
+ # The function under test
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ UserID.from_string(user1_id),
+ from_token=before_reset_token,
+ to_token=after_reset_token,
+ )
+ )
-class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
+ # Room1 should show up because it was `newly_left` via state reset during the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ # It should be pointing to no event because we were removed from the room
+ # without a corresponding leave event
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ None,
+ )
+ # State reset caused us to leave the room and there is no corresponding leave event
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ # We should be `newly_left` because we were removed via state reset during the from/to range
+ self.assertEqual(room_id_results[room_id1].newly_left, True)
+
+
+class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase):
"""
- Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it works with
+ Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with
sharded event stream_writers enabled
"""
@@ -2189,7 +2449,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
We then send some events to advance the stream positions of worker1 and worker3
but worker2 is lagging behind because it's stuck. We are specifically testing
- that `get_sync_room_ids_for_user(from_token=xxx, to_token=xxx)` should work
+ that `get_room_membership_for_user_at_to_token(from_token=xxx, to_token=xxx)` should work
correctly in these adverse conditions.
"""
user1_id = self.register_user("user1", "pass")
@@ -2228,7 +2488,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
# Leave room2
- self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok)
join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok)
# Leave room3
self.helper.leave(room_id3, user1_id, tok=user1_tok)
@@ -2265,7 +2525,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
# For room_id1/worker1: leave and join the room to advance the stream position
# and generate membership changes.
self.helper.leave(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_room1_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
# For room_id2/worker2: which is currently stuck, join the room.
join_on_worker2_response = self.helper.join(room_id2, user1_id, tok=user1_tok)
# For room_id3/worker3: leave and join the room to advance the stream position
@@ -2319,7 +2579,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
# The function under test
room_id_results = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
UserID.from_string(user1_id),
from_token=before_stuck_activity_token,
to_token=stuck_activity_token,
@@ -2330,18 +2590,411 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
room_id_results.keys(),
{
room_id1,
- # room_id2 shouldn't show up because we left before the from/to range
- # and the join event during the range happened while worker2 was stuck.
- # This means that from the perspective of the master, where the
- # `stuck_activity_token` is generated, the stream position for worker2
- # wasn't advanced to the join yet. Looking at the `instance_map`, the
- # join technically comes after `stuck_activity_token``.
- #
- # room_id2,
+ room_id2,
room_id3,
},
)
+ # Room1
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_room1_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertEqual(room_id_results[room_id1].newly_left, False)
+
+ # Room2
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id2].event_id,
+ leave_room2_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
+ # room_id2 should *NOT* be considered `newly_left` because we left before the
+ # from/to range and the join event during the range happened while worker2 was
+ # stuck. This means that from the perspective of the master, where the
+ # `stuck_activity_token` is generated, the stream position for worker2 wasn't
+ # advanced to the join yet. Looking at the `instance_map`, the join technically
+ # comes after `stuck_activity_token`.
+ self.assertEqual(room_id_results[room_id2].newly_joined, False)
+ self.assertEqual(room_id_results[room_id2].newly_left, False)
+
+ # Room3
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id3].event_id,
+ join_on_worker3_response["event_id"],
+ )
+ self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN)
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id3].newly_joined, True)
+ self.assertEqual(room_id_results[room_id3].newly_left, False)
+
+
+class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
+ """
+ Tests Sliding Sync handler `filter_rooms_relevant_for_sync()` to make sure it returns
+ the correct list of rooms IDs.
+ """
+
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ # Enable sliding sync
+ config["experimental_features"] = {"msc3575_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
+ self.store = self.hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.storage_controllers = hs.get_storage_controllers()
+
+ def _get_sync_room_ids_for_user(
+ self,
+ user: UserID,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Dict[str, _RoomMembershipForUser]:
+ """
+ Get the rooms the user should be syncing with
+ """
+ room_membership_for_user_map = self.get_success(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ user=user,
+ from_token=from_token,
+ to_token=to_token,
+ )
+ )
+ filtered_sync_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ user=user,
+ room_membership_for_user_map=room_membership_for_user_map,
+ )
+ )
+
+ return filtered_sync_room_map
+
+ def test_no_rooms(self) -> None:
+ """
+ Test when the user has never joined any rooms before
+ """
+ user1_id = self.register_user("user1", "pass")
+ # user1_tok = self.login(user1_id, "pass")
+
+ now_token = self.event_sources.get_current_token()
+
+ room_id_results = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=now_token,
+ to_token=now_token,
+ )
+
+ self.assertEqual(room_id_results.keys(), set())
+
+ def test_basic_rooms(self) -> None:
+ """
+ Test that rooms that the user is joined to, invited to, banned from, and knocked
+ on show up.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room_token = self.event_sources.get_current_token()
+
+ join_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(join_room_id, user1_id, tok=user1_tok)
+
+ # Setup the invited room (user2 invites user1 to the room)
+ invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ invite_response = self.helper.invite(
+ invited_room_id, targ=user1_id, tok=user2_tok
+ )
+
+ # Setup the ban room (user2 bans user1 from the room)
+ ban_room_id = self.helper.create_room_as(
+ user2_id, tok=user2_tok, is_public=True
+ )
+ self.helper.join(ban_room_id, user1_id, tok=user1_tok)
+ ban_response = self.helper.ban(
+ ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+
+ # Setup the knock room (user1 knocks on the room)
+ knock_room_id = self.helper.create_room_as(
+ user2_id, tok=user2_tok, room_version=RoomVersions.V7.identifier
+ )
+ self.helper.send_state(
+ knock_room_id,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=user2_tok,
+ )
+ # User1 knocks on the room
+ knock_channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/knock/%s" % (knock_room_id,),
+ b"{}",
+ user1_tok,
+ )
+ self.assertEqual(knock_channel.code, 200, knock_channel.result)
+ knock_room_membership_state_event = self.get_success(
+ self.storage_controllers.state.get_current_state_event(
+ knock_room_id, EventTypes.Member, user1_id
+ )
+ )
+ assert knock_room_membership_state_event is not None
+
+ after_room_token = self.event_sources.get_current_token()
+
+ room_id_results = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_room_token,
+ to_token=after_room_token,
+ )
+
+ # Ensure that the invited, ban, and knock rooms show up
+ self.assertEqual(
+ room_id_results.keys(),
+ {
+ join_room_id,
+ invited_room_id,
+ ban_room_id,
+ knock_room_id,
+ },
+ )
+ # It should be pointing to the the respective membership event (latest
+ # membership event in the from/to range)
+ self.assertEqual(
+ room_id_results[join_room_id].event_id,
+ join_response["event_id"],
+ )
+ self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN)
+ self.assertEqual(room_id_results[join_room_id].newly_joined, True)
+ self.assertEqual(room_id_results[join_room_id].newly_left, False)
+
+ self.assertEqual(
+ room_id_results[invited_room_id].event_id,
+ invite_response["event_id"],
+ )
+ self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
+ self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[invited_room_id].newly_left, False)
+
+ self.assertEqual(
+ room_id_results[ban_room_id].event_id,
+ ban_response["event_id"],
+ )
+ self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
+ self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[ban_room_id].newly_left, False)
+
+ self.assertEqual(
+ room_id_results[knock_room_id].event_id,
+ knock_room_membership_state_event.event_id,
+ )
+ self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
+ self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[knock_room_id].newly_left, False)
+
+ def test_only_newly_left_rooms_show_up(self) -> None:
+ """
+ Test that `newly_left` rooms still show up in the sync response but rooms that
+ were left before the `from_token` don't show up. See condition "2)" comments in
+ the `get_room_membership_for_user_at_to_token()` method.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Leave before we calculate the `from_token`
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Leave during the from_token/to_token range (newly_left)
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+
+ after_room2_token = self.event_sources.get_current_token()
+
+ room_id_results = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=after_room1_token,
+ to_token=after_room2_token,
+ )
+
+ # Only the `newly_left` room should show up
+ self.assertEqual(room_id_results.keys(), {room_id2})
+ self.assertEqual(
+ room_id_results[room_id2].event_id,
+ _leave_response2["event_id"],
+ )
+ # We should *NOT* be `newly_joined` because we are instead `newly_left`
+ self.assertEqual(room_id_results[room_id2].newly_joined, False)
+ self.assertEqual(room_id_results[room_id2].newly_left, True)
+
+ def test_get_kicked_room(self) -> None:
+ """
+ Test that a room that the user was kicked from still shows up. When the user
+ comes back to their client, they should see that they were kicked.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Setup the kick room (user2 kicks user1 from the room)
+ kick_room_id = self.helper.create_room_as(
+ user2_id, tok=user2_tok, is_public=True
+ )
+ self.helper.join(kick_room_id, user1_id, tok=user1_tok)
+ # Kick user1 from the room
+ kick_response = self.helper.change_membership(
+ room=kick_room_id,
+ src=user2_id,
+ targ=user1_id,
+ tok=user2_tok,
+ membership=Membership.LEAVE,
+ extra_data={
+ "reason": "Bad manners",
+ },
+ )
+
+ after_kick_token = self.event_sources.get_current_token()
+
+ room_id_results = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=after_kick_token,
+ to_token=after_kick_token,
+ )
+
+ # The kicked room should show up
+ self.assertEqual(room_id_results.keys(), {kick_room_id})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[kick_room_id].event_id,
+ kick_response["event_id"],
+ )
+ self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
+ self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
+ # We should *NOT* be `newly_joined` because we were not joined at the the time
+ # of the `to_token`.
+ self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+
+ def test_state_reset(self) -> None:
+ """
+ Test a state reset scenario where the user gets removed from the room (when
+ there is no corresponding leave event)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # The room where the state reset will happen
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Join another room so we don't hit the short-circuit and return early if they
+ # have no room membership
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
+
+ before_reset_token = self.event_sources.get_current_token()
+
+ # Send another state event to make a position for the state reset to happen at
+ dummy_state_response = self.helper.send_state(
+ room_id1,
+ event_type="foobarbaz",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ dummy_state_pos = self.get_success(
+ self.store.get_position_for_event(dummy_state_response["event_id"])
+ )
+
+ # Mock a state reset removing the membership for user1 in the current state
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="current_state_events",
+ keyvalues={
+ "room_id": room_id1,
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ },
+ desc="state reset user in current_state_events",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="local_current_membership",
+ keyvalues={
+ "room_id": room_id1,
+ "user_id": user1_id,
+ },
+ desc="state reset user in local_current_membership",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="current_state_delta_stream",
+ values={
+ "stream_id": dummy_state_pos.stream,
+ "room_id": room_id1,
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ "event_id": None,
+ "prev_event_id": join_response1["event_id"],
+ "instance_name": dummy_state_pos.instance_name,
+ },
+ desc="state reset user in current_state_delta_stream",
+ )
+ )
+
+ # Manually bust the cache since we we're just manually messing with the database
+ # and not causing an actual state reset.
+ self.store._membership_stream_cache.entity_has_changed(
+ user1_id, dummy_state_pos.stream
+ )
+
+ after_reset_token = self.event_sources.get_current_token()
+
+ # The function under test
+ room_id_results = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_reset_token,
+ to_token=after_reset_token,
+ )
+
+ # Room1 should show up because it was `newly_left` via state reset during the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ # It should be pointing to no event because we were removed from the room
+ # without a corresponding leave event
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ None,
+ )
+ # State reset caused us to leave the room and there is no corresponding leave event
+ self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ # We should be `newly_left` because we were removed via state reset during the from/to range
+ self.assertEqual(room_id_results[room_id1].newly_left, True)
+
class FilterRoomsTestCase(HomeserverTestCase):
"""
@@ -2367,6 +3020,31 @@ class FilterRoomsTestCase(HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
+ def _get_sync_room_ids_for_user(
+ self,
+ user: UserID,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Dict[str, _RoomMembershipForUser]:
+ """
+ Get the rooms the user should be syncing with
+ """
+ room_membership_for_user_map = self.get_success(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ user=user,
+ from_token=from_token,
+ to_token=to_token,
+ )
+ )
+ filtered_sync_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ user=user,
+ room_membership_for_user_map=room_membership_for_user_map,
+ )
+ )
+
+ return filtered_sync_room_map
+
def _create_dm_room(
self,
inviter_user_id: str,
@@ -2438,12 +3116,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Try with `is_dm=True`
@@ -2496,12 +3172,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Try with `is_encrypted=True`
@@ -2552,12 +3226,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Try with `is_invite=True`
@@ -2621,12 +3293,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Try finding only normal rooms
@@ -2714,12 +3384,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Try finding *NOT* normal rooms
@@ -2838,12 +3506,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
filtered_room_map = self.get_success(
@@ -2884,6 +3550,31 @@ class SortRoomsTestCase(HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
+ def _get_sync_room_ids_for_user(
+ self,
+ user: UserID,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Dict[str, _RoomMembershipForUser]:
+ """
+ Get the rooms the user should be syncing with
+ """
+ room_membership_for_user_map = self.get_success(
+ self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ user=user,
+ from_token=from_token,
+ to_token=to_token,
+ )
+ )
+ filtered_sync_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ user=user,
+ room_membership_for_user_map=room_membership_for_user_map,
+ )
+ )
+
+ return filtered_sync_room_map
+
def test_sort_activity_basic(self) -> None:
"""
Rooms with newer activity are sorted first.
@@ -2903,12 +3594,10 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Sort the rooms (what we're testing)
@@ -2986,12 +3675,10 @@ class SortRoomsTestCase(HomeserverTestCase):
self.helper.send(room_id3, "activity in room3", tok=user2_tok)
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_rooms_token,
+ to_token=after_rooms_token,
)
# Sort the rooms (what we're testing)
@@ -3052,12 +3739,10 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self.get_success(
- self.sliding_sync_handler.get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
+ sync_room_map = self._get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
)
# Sort the rooms (what we're testing)
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 70912e22f8..e55001fb40 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -1057,13 +1057,15 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel.code == 200
+ @override_config({"remote_media_download_burst_count": "87M"})
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
)
- def test_download_ratelimit_max_size_sub(self) -> None:
+ def test_download_ratelimit_unknown_length(self) -> None:
"""
- Test that if no content-length is provided, the default max size is applied instead
+ Test that if no content-length is provided, ratelimit will still be applied after
+ download once length is known
"""
# mock out actually sending the request
@@ -1077,19 +1079,48 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
self.client._send_request = _send_request # type: ignore
- # ten requests should go through using the max size (500MB/50MB)
- for i in range(10):
- channel2 = self.make_request(
+ # 3 requests should go through (note 3rd one would technically violate ratelimit but
+ # is applied *after* download - the next one will be ratelimited)
+ for i in range(3):
+ channel = self.make_request(
"GET",
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
shorthand=False,
)
- assert channel2.code == 200
+ assert channel.code == 200
- # eleventh will hit ratelimit
- channel3 = self.make_request(
+ # 4th will hit ratelimit
+ channel2 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
shorthand=False,
)
- assert channel3.code == 429
+ assert channel2.code == 429
+
+ @override_config({"max_upload_size": "29M"})
+ @patch(
+ "synapse.http.matrixfederationclient.read_body_with_max_size",
+ read_body_with_max_size_30MiB,
+ )
+ def test_max_download_respected(self) -> None:
+ """
+ Test that the max download size is enforced - note that max download size is determined
+ by the max_upload_size
+ """
+
+ # mock out actually sending the request
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 31457280
+ resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ channel = self.make_request(
+ "GET", "/_matrix/media/v3/download/remote.org/abcd", shorthand=False
+ )
+ assert channel.code == 502
+ assert channel.json_body["errcode"] == "M_TOO_LARGE"
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 7f2caed7d5..466c5a0b70 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -1809,13 +1809,19 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel.code == 200
+ @override_config(
+ {
+ "remote_media_download_burst_count": "87M",
+ }
+ )
@patch(
"synapse.http.matrixfederationclient.read_multipart_response",
read_multipart_response_30MiB,
)
- def test_download_ratelimit_max_size_sub(self) -> None:
+ def test_download_ratelimit_unknown_length(self) -> None:
"""
- Test that if no content-length is provided, the default max size is applied instead
+ Test that if no content-length is provided, ratelimiting is still applied after
+ media is downloaded and length is known
"""
# mock out actually sending the request
@@ -1831,8 +1837,9 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
self.client._send_request = _send_request # type: ignore
- # ten requests should go through using the max size (500MB/50MB)
- for i in range(10):
+ # first 3 will go through (note that 3rd request technically violates rate limit but
+ # that since the ratelimiting is applied *after* download it goes through, but next one fails)
+ for i in range(3):
channel2 = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/remote.org/abc{i}",
@@ -1841,7 +1848,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel2.code == 200
- # eleventh will hit ratelimit
+ # 4th will hit ratelimit
channel3 = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcd",
@@ -1850,6 +1857,39 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel3.code == 429
+ @override_config({"max_upload_size": "29M"})
+ @patch(
+ "synapse.http.matrixfederationclient.read_multipart_response",
+ read_multipart_response_30MiB,
+ )
+ def test_max_download_respected(self) -> None:
+ """
+ Test that the max download size is enforced - note that max download size is determined
+ by the max_upload_size
+ """
+
+ # mock out actually sending the request, returns a 30MiB response
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 31457280
+ resp.headers = Headers(
+ {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+ )
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcd",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 502
+ assert channel.json_body["errcode"] == "M_TOO_LARGE"
+
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 304c0d4d3d..a008ee465b 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -20,7 +20,8 @@
#
import json
import logging
-from typing import AbstractSet, Any, Dict, Iterable, List, Optional
+from http import HTTPStatus
+from typing import Any, Dict, Iterable, List
from parameterized import parameterized, parameterized_class
@@ -1259,7 +1260,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
exact: bool = False,
) -> None:
"""
- Wrapper around `_assertIncludes` to give slightly better looking diff error
+ Wrapper around `assertIncludes` to give slightly better looking diff error
messages that include some context "$event_id (type, state_key)".
Args:
@@ -1275,7 +1276,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
for event in actual_required_state:
assert isinstance(event, dict)
- self._assertIncludes(
+ self.assertIncludes(
{
f'{event["event_id"]} ("{event["type"]}", "{event["state_key"]}")'
for event in actual_required_state
@@ -1289,56 +1290,6 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
message=str(actual_required_state),
)
- def _assertIncludes(
- self,
- actual_items: AbstractSet[str],
- expected_items: AbstractSet[str],
- exact: bool = False,
- message: Optional[str] = None,
- ) -> None:
- """
- Assert that all of the `expected_items` are included in the `actual_items`.
-
- This assert could also be called `assertContains`, `assertItemsInSet`
-
- Args:
- actual_items: The container
- expected_items: The items to check for in the container
- exact: Whether the actual state should be exactly equal to the expected
- state (no extras).
- message: Optional message to include in the failure message.
- """
- # Check that each set has the same items
- if exact and actual_items == expected_items:
- return
- # Check for a superset
- elif not exact and actual_items >= expected_items:
- return
-
- expected_lines: List[str] = []
- for expected_item in expected_items:
- is_expected_in_actual = expected_item in actual_items
- expected_lines.append(
- "{} {}".format(" " if is_expected_in_actual else "?", expected_item)
- )
-
- actual_lines: List[str] = []
- for actual_item in actual_items:
- is_actual_in_expected = actual_item in expected_items
- actual_lines.append(
- "{} {}".format("+" if is_actual_in_expected else " ", actual_item)
- )
-
- newline = "\n"
- expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}"
- actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}"
- first_message = (
- "Items must match exactly" if exact else "Some expected items are missing."
- )
- diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
-
- self.fail(f"{diff_message}\n{message}")
-
def _add_new_dm_to_global_account_data(
self, source_user_id: str, target_user_id: str, target_room_id: str
) -> None:
@@ -1662,6 +1613,20 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
list(channel.json_body["lists"]["room-invites"]),
)
+ # Ensure DM's are correctly marked
+ self.assertDictEqual(
+ {
+ room_id: room.get("is_dm")
+ for room_id, room in channel.json_body["rooms"].items()
+ },
+ {
+ invite_room_id: None,
+ room_id: None,
+ invited_dm_room_id: True,
+ joined_dm_room_id: True,
+ },
+ )
+
def test_sort_list(self) -> None:
"""
Test that the `lists` are sorted by `stream_ordering`
@@ -1813,8 +1778,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
def test_rooms_meta_when_joined(self) -> None:
"""
- Test that the `rooms` `name` and `avatar` (soon to test `heroes`) are included
- in the response when the user is joined to the room.
+ Test that the `rooms` `name` and `avatar` are included in the response and
+ reflect the current state of the room when the user is joined to the room.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1866,11 +1831,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
"mxc://DUMMY_MEDIA_ID",
channel.json_body["rooms"][room_id1],
)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["joined_count"],
+ 2,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["invited_count"],
+ 0,
+ )
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("is_dm"),
+ )
def test_rooms_meta_when_invited(self) -> None:
"""
- Test that the `rooms` `name` and `avatar` (soon to test `heroes`) are included
- in the response when the user is invited to the room.
+ Test that the `rooms` `name` and `avatar` are included in the response and
+ reflect the current state of the room when the user is invited to the room.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1892,7 +1868,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
tok=user2_tok,
)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # User1 is invited to the room
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
# Update the room name after user1 has left
self.helper.send_state(
@@ -1938,11 +1915,22 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
"mxc://UPDATED_DUMMY_MEDIA_ID",
channel.json_body["rooms"][room_id1],
)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["joined_count"],
+ 1,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["invited_count"],
+ 1,
+ )
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("is_dm"),
+ )
def test_rooms_meta_when_banned(self) -> None:
"""
- Test that the `rooms` `name` and `avatar` (soon to test `heroes`) reflect the
- state of the room when the user was banned (do not leak current state).
+ Test that the `rooms` `name` and `avatar` reflect the state of the room when the
+ user was banned (do not leak current state).
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -2010,6 +1998,273 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
"mxc://DUMMY_MEDIA_ID",
channel.json_body["rooms"][room_id1],
)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["joined_count"],
+ # FIXME: The actual number should be "1" (user2) but we currently don't
+ # support this for rooms where the user has left/been banned.
+ 0,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["invited_count"],
+ 0,
+ )
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("is_dm"),
+ )
+
+ def test_rooms_meta_heroes(self) -> None:
+ """
+ Test that the `rooms` `heroes` are included in the response when the room
+ doesn't have a room name set.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ room_id2 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room2",
+ },
+ )
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id2, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Room1 has a name so we shouldn't see any `heroes` which the client would use
+ # the calculate the room name themselves.
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["name"],
+ "my super room",
+ channel.json_body["rooms"][room_id1],
+ )
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("heroes"))
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["joined_count"],
+ 2,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["invited_count"],
+ 1,
+ )
+
+ # Room2 doesn't have a name so we should see `heroes` populated
+ self.assertIsNone(channel.json_body["rooms"][room_id2].get("name"))
+ self.assertCountEqual(
+ [
+ hero["user_id"]
+ for hero in channel.json_body["rooms"][room_id2].get("heroes", [])
+ ],
+ # Heroes shouldn't include the user themselves (we shouldn't see user1)
+ [user2_id, user3_id],
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id2]["joined_count"],
+ 2,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id2]["invited_count"],
+ 1,
+ )
+
+ # We didn't request any state so we shouldn't see any `required_state`
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("required_state"))
+ self.assertIsNone(channel.json_body["rooms"][room_id2].get("required_state"))
+
+ def test_rooms_meta_heroes_max(self) -> None:
+ """
+ Test that the `rooms` `heroes` only includes the first 5 users (not including
+ yourself).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+ user6_id = self.register_user("user6", "pass")
+ user6_tok = self.login(user6_id, "pass")
+ user7_id = self.register_user("user7", "pass")
+ user7_tok = self.login(user7_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room",
+ },
+ )
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+ self.helper.join(room_id1, user5_id, tok=user5_tok)
+ self.helper.join(room_id1, user6_id, tok=user6_tok)
+ self.helper.join(room_id1, user7_id, tok=user7_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Room2 doesn't have a name so we should see `heroes` populated
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("name"))
+ self.assertCountEqual(
+ [
+ hero["user_id"]
+ for hero in channel.json_body["rooms"][room_id1].get("heroes", [])
+ ],
+ # Heroes should be the first 5 users in the room (excluding the user
+ # themselves, we shouldn't see `user1`)
+ [user2_id, user3_id, user4_id, user5_id, user6_id],
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["joined_count"],
+ 7,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["invited_count"],
+ 0,
+ )
+
+ # We didn't request any state so we shouldn't see any `required_state`
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("required_state"))
+
+ def test_rooms_meta_heroes_when_banned(self) -> None:
+ """
+ Test that the `rooms` `heroes` are included in the response when the room
+ doesn't have a room name set but doesn't leak information past their ban.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ _user5_tok = self.login(user5_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room",
+ },
+ )
+ # User1 joins the room
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ # User1 is banned from the room
+ self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # User4 joins the room after user1 is banned
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+ # User5 is invited after user1 is banned
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Room2 doesn't have a name so we should see `heroes` populated
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("name"))
+ self.assertCountEqual(
+ [
+ hero["user_id"]
+ for hero in channel.json_body["rooms"][room_id1].get("heroes", [])
+ ],
+ # Heroes shouldn't include the user themselves (we shouldn't see user1). We
+ # also shouldn't see user4 since they joined after user1 was banned.
+ #
+ # FIXME: The actual result should be `[user2_id, user3_id]` but we currently
+ # don't support this for rooms where the user has left/been banned.
+ [],
+ )
+
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["joined_count"],
+ # FIXME: The actual number should be "1" (user2) but we currently don't
+ # support this for rooms where the user has left/been banned.
+ 0,
+ )
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["invited_count"],
+ # We shouldn't see user5 since they were invited after user1 was banned.
+ #
+ # FIXME: The actual number should be "1" (user3) but we currently don't
+ # support this for rooms where the user has left/been banned.
+ 0,
+ )
def test_rooms_limited_initial_sync(self) -> None:
"""
@@ -3081,11 +3336,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
# Nothing to see for this banned user in the room in the token range
- self.assertEqual(
- channel.json_body["rooms"][room_id1]["timeline"],
- [],
- channel.json_body["rooms"][room_id1]["timeline"],
- )
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("timeline"))
# No events returned in the timeline so nothing is "live"
self.assertEqual(
channel.json_body["rooms"][room_id1]["num_live"],
@@ -3565,6 +3816,13 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
body={"foo": "bar"},
tok=user2_tok,
)
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "qux"},
+ tok=user2_tok,
+ )
# Make the Sliding Sync request with wildcards for the `state_key`
channel = self.make_request(
@@ -3588,16 +3846,13 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
],
"timeline_limit": 0,
},
- }
- # TODO: Room subscription should also combine with the `required_state`
- # "room_subscriptions": {
- # room_id1: {
- # "required_state": [
- # ["org.matrix.bar_state", ""]
- # ],
- # "timeline_limit": 0,
- # }
- # }
+ },
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [["org.matrix.bar_state", ""]],
+ "timeline_limit": 0,
+ }
+ },
},
access_token=user1_tok,
)
@@ -3614,6 +3869,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
state_map[(EventTypes.Member, user1_id)],
state_map[(EventTypes.Member, user2_id)],
state_map[("org.matrix.foo_state", "")],
+ state_map[("org.matrix.bar_state", "")],
},
exact=True,
)
@@ -3706,6 +3962,271 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["lists"]["foo-list"],
)
+ def test_room_subscriptions_with_join_membership(self) -> None:
+ """
+ Test `room_subscriptions` with a joined room should give us timeline and current
+ state events.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request with just the room subscription
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # We should see some state
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
+
+ # We should see some events
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ join_response["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+ # No "live" events in an initial sync (no `from_token` to define the "live"
+ # range)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 0,
+ channel.json_body["rooms"][room_id1],
+ )
+ # There are more events to paginate to
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ True,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_room_subscriptions_with_leave_membership(self) -> None:
+ """
+ Test `room_subscriptions` with a leave room should give us timeline and state
+ events up to the leave event.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Send some events after user1 leaves
+ self.helper.send(room_id1, "activity after leave", tok=user2_tok)
+ # Update state after user1 leaves
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "qux"},
+ tok=user2_tok,
+ )
+
+ # Make the Sliding Sync request with just the room subscription
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [
+ ["org.matrix.foo_state", ""],
+ ],
+ "timeline_limit": 2,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We should see the state at the time of the leave
+ self._assertRequiredStateIncludes(
+ channel.json_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[("org.matrix.foo_state", "")],
+ },
+ exact=True,
+ )
+ self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
+
+ # We should see some before we left (nothing after)
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ join_response["event_id"],
+ leave_response["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+ # No "live" events in an initial sync (no `from_token` to define the "live"
+ # range)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 0,
+ channel.json_body["rooms"][room_id1],
+ )
+ # There are more events to paginate to
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ True,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_room_subscriptions_no_leak_private_room(self) -> None:
+ """
+ Test `room_subscriptions` with a private room we have never been in should not
+ leak any data to the user.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=False)
+
+ # We should not be able to join the private room
+ self.helper.join(
+ room_id1, user1_id, tok=user1_tok, expect_code=HTTPStatus.FORBIDDEN
+ )
+
+ # Make the Sliding Sync request with just the room subscription
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We should not see the room at all (we're not in it)
+ self.assertIsNone(
+ channel.json_body["rooms"].get(room_id1), channel.json_body["rooms"]
+ )
+
+ def test_room_subscriptions_world_readable(self) -> None:
+ """
+ Test `room_subscriptions` with a room that has `world_readable` history visibility
+
+ FIXME: We should be able to see the room timeline and state
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a room with `world_readable` history visibility
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "preset": "public_chat",
+ "initial_state": [
+ {
+ "content": {
+ "history_visibility": HistoryVisibility.WORLD_READABLE
+ },
+ "state_key": "",
+ "type": EventTypes.RoomHistoryVisibility,
+ }
+ ],
+ },
+ )
+ # Ensure we're testing with a room with `world_readable` history visibility
+ # which means events are visible to anyone even without membership.
+ history_visibility_response = self.helper.get_state(
+ room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
+ )
+ self.assertEqual(
+ history_visibility_response.get("history_visibility"),
+ HistoryVisibility.WORLD_READABLE,
+ )
+
+ # Note: We never join the room
+
+ # Make the Sliding Sync request with just the room subscription
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # FIXME: In the future, we should be able to see the room because it's
+ # `world_readable` but currently we don't support this.
+ self.assertIsNone(
+ channel.json_body["rooms"].get(room_id1), channel.json_body["rooms"]
+ )
+
class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
"""Tests for the to-device sliding sync extension"""
diff --git a/tests/server.py b/tests/server.py
index f1cd0f76be..85602e6953 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -289,10 +289,6 @@ class FakeChannel:
self._reactor.run()
while not self.is_finished():
- # If there's a producer, tell it to resume producing so we get content
- if self._producer:
- self._producer.resumeProducing()
-
if self._reactor.seconds() > end_time:
raise TimedOutException("Timed out waiting for request to finish.")
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 882f3bbbdc..418b556108 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,20 +19,28 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import logging
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import Membership
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.rest import admin
from synapse.rest.admin import register_servlets_for_client_rest_resource
-from synapse.rest.client import login, room
+from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
+from synapse.storage.roommember import MemberSummary
from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
from tests.test_utils import event_injection
+from tests.unittest import skip_unless
+
+logger = logging.getLogger(__name__)
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -240,6 +248,397 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
+class RoomSummaryTestCase(unittest.HomeserverTestCase):
+ """
+ Test `/sync` room summary related logic like `get_room_summary(...)` and
+ `extract_heroes_from_room_summary(...)`
+ """
+
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
+ self.store = self.hs.get_datastores().main
+
+ def _assert_member_summary(
+ self,
+ actual_member_summary: MemberSummary,
+ expected_member_list: List[str],
+ *,
+ expected_member_count: Optional[int] = None,
+ ) -> None:
+ """
+ Assert that the `MemberSummary` object has the expected members.
+ """
+ self.assertListEqual(
+ [
+ user_id
+ for user_id, _membership_event_id in actual_member_summary.members
+ ],
+ expected_member_list,
+ )
+ self.assertEqual(
+ actual_member_summary.count,
+ (
+ expected_member_count
+ if expected_member_count is not None
+ else len(expected_member_list)
+ ),
+ )
+
+ def test_get_room_summary_membership(self) -> None:
+ """
+ Test that `get_room_summary(...)` gets every kind of membership when there
+ aren't that many members in the room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+
+ # Setup a room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 is invited by user1
+ self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
+
+ # User4 leaves
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # User5 joins
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+ empty_ms = MemberSummary([], 0)
+
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.JOIN, empty_ms),
+ [user1_id, user5_id],
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.BAN, empty_ms), [user2_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.KNOCK, empty_ms),
+ [
+ # No one knocked
+ ],
+ )
+
+ def test_get_room_summary_membership_order(self) -> None:
+ """
+ Test that `get_room_summary(...)` stacks our limit of 6 in this order: joins ->
+ invites -> leave -> everything else (bans/knocks)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+ user6_id = self.register_user("user6", "pass")
+ user6_tok = self.login(user6_id, "pass")
+ user7_id = self.register_user("user7", "pass")
+ user7_tok = self.login(user7_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # We expect the order to be joins -> invites -> leave -> bans so setup the users
+ # *NOT* in that same order to make sure we're actually sorting them.
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 is invited by user1
+ self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
+
+ # User4 leaves
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # User5, User6, User7 joins
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+ self.helper.join(room_id, user6_id, tok=user6_tok)
+ self.helper.join(room_id, user7_id, tok=user7_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+ empty_ms = MemberSummary([], 0)
+
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.JOIN, empty_ms),
+ [user1_id, user5_id, user6_id, user7_id],
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id]
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.BAN, empty_ms),
+ [
+ # The banned user is not in the summary because the summary can only fit
+ # 6 members and prefers everything else before bans
+ #
+ # user2_id
+ ],
+ # But we still see the count of banned users
+ expected_member_count=1,
+ )
+ self._assert_member_summary(
+ room_membership_summary.get(Membership.KNOCK, empty_ms),
+ [
+ # No one knocked
+ ],
+ )
+
+ def test_extract_heroes_from_room_summary_excludes_self(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` does not include the user
+ itself.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 joins
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ # We first ask from the perspective of a random fake user
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # Make sure user1 is in the room (ensure our test setup is correct)
+ self.assertListEqual(hero_user_ids, [user1_id, user2_id])
+
+ # Now, we ask for the room summary from the perspective of user1
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me=user1_id
+ )
+
+ # User1 should not be included in the list of heroes because they are the one
+ # asking
+ self.assertListEqual(hero_user_ids, [user2_id])
+
+ def test_extract_heroes_from_room_summary_first_five_joins(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` returns the first 5 joins.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+ user6_id = self.register_user("user6", "pass")
+ user6_tok = self.login(user6_id, "pass")
+ user7_id = self.register_user("user7", "pass")
+ user7_tok = self.login(user7_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 -> User7 joins
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+ self.helper.join(room_id, user6_id, tok=user6_tok)
+ self.helper.join(room_id, user7_id, tok=user7_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakuser"
+ )
+
+ # First 5 users to join the room
+ self.assertListEqual(
+ hero_user_ids, [user1_id, user2_id, user3_id, user4_id, user5_id]
+ )
+
+ def test_extract_heroes_from_room_summary_membership_order(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` prefers joins/invites over
+ everything else.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # We expect the order to be joins -> invites -> leave -> bans so setup the users
+ # *NOT* in that same order to make sure we're actually sorting them.
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 is invited by user1
+ self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
+
+ # User4 leaves
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # User5 joins
+ self.helper.join(room_id, user5_id, tok=user5_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # Prefer joins -> invites, over everything else
+ self.assertListEqual(
+ hero_user_ids,
+ [
+ # The joins
+ user1_id,
+ user5_id,
+ # The invites
+ user3_id,
+ ],
+ )
+
+ @skip_unless(
+ False,
+ "Test is not possible because when everyone leaves the room, "
+ + "the server is `no_longer_in_room` and we don't have any `current_state_events` to query",
+ )
+ def test_extract_heroes_from_room_summary_fallback_leave_ban(self) -> None:
+ """
+ Test that `extract_heroes_from_room_summary(...)` falls back to leave/ban if
+ there aren't any joins/invites.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ # Setup the room (user1 is the creator and is joined to the room)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # User2 is banned
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+ self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
+
+ # User3 leaves
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+ self.helper.leave(room_id, user3_id, tok=user3_tok)
+
+ # User1 leaves (we're doing this last because they're the room creator)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # Fallback to people who left -> banned
+ self.assertListEqual(
+ hero_user_ids,
+ [user3_id, user1_id, user3_id],
+ )
+
+ def test_extract_heroes_from_room_summary_excludes_knocks(self) -> None:
+ """
+ People who knock on the room have (potentially) never been in the room before
+ and are total outsiders. Plus the spec doesn't mention them at all for heroes.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Setup the knock room (user1 is the creator and is joined to the room)
+ knock_room_id = self.helper.create_room_as(
+ user1_id, tok=user1_tok, room_version=RoomVersions.V7.identifier
+ )
+ self.helper.send_state(
+ knock_room_id,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=user1_tok,
+ )
+
+ # User2 knocks on the room
+ knock_channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/knock/%s" % (knock_room_id,),
+ b"{}",
+ user2_tok,
+ )
+ self.assertEqual(knock_channel.code, 200, knock_channel.result)
+
+ room_membership_summary = self.get_success(
+ self.store.get_room_summary(knock_room_id)
+ )
+
+ hero_user_ids = extract_heroes_from_room_summary(
+ room_membership_summary, me="@fakeuser"
+ )
+
+ # user1 is the creator and is joined to the room (should show up as a hero)
+ # user2 is knocking on the room (should not show up as a hero)
+ self.assertListEqual(
+ hero_user_ids,
+ [user1_id],
+ )
+
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
diff --git a/tests/test_types.py b/tests/test_types.py
index 944aa784fc..00adc65a5a 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -19,9 +19,18 @@
#
#
+from typing import Type
+from unittest import skipUnless
+
+from immutabledict import immutabledict
+from parameterized import parameterized_class
+
from synapse.api.errors import SynapseError
from synapse.types import (
+ AbstractMultiWriterStreamToken,
+ MultiWriterStreamToken,
RoomAlias,
+ RoomStreamToken,
UserID,
get_domain_from_id,
get_localpart_from_id,
@@ -29,6 +38,7 @@ from synapse.types import (
)
from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
+
+
+@parameterized_class(
+ ("token_type",),
+ [
+ (MultiWriterStreamToken,),
+ (RoomStreamToken,),
+ ],
+ class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
+)
+class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
+ """Tests for the different types of multi writer tokens."""
+
+ token_type: Type[AbstractMultiWriterStreamToken]
+
+ def test_basic_token(self) -> None:
+ """Test that a simple stream token can be serialized and unserialized"""
+ store = self.hs.get_datastores().main
+
+ token = self.token_type(stream=5)
+
+ string_token = self.get_success(token.to_string(store))
+
+ if isinstance(token, RoomStreamToken):
+ self.assertEqual(string_token, "s5")
+ else:
+ self.assertEqual(string_token, "5")
+
+ parsed_token = self.get_success(self.token_type.parse(store, string_token))
+ self.assertEqual(parsed_token, token)
+
+ @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
+ def test_instance_map(self) -> None:
+ """Test for stream token with instance map"""
+ store = self.hs.get_datastores().main
+
+ token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
+
+ string_token = self.get_success(token.to_string(store))
+ self.assertEqual(string_token, "m5~1.6")
+
+ parsed_token = self.get_success(self.token_type.parse(store, string_token))
+ self.assertEqual(parsed_token, token)
+
+ def test_instance_map_assertion(self) -> None:
+ """Test that we assert values in the instance map are greater than the
+ min stream position"""
+
+ with self.assertRaises(ValueError):
+ self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
+
+ with self.assertRaises(ValueError):
+ self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
+
+ def test_parse_bad_token(self) -> None:
+ """Test that we can parse tokens produced by a bug in Synapse of the
+ form `m5~`"""
+ store = self.hs.get_datastores().main
+
+ parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
+ self.assertEqual(parsed_token, self.token_type(stream=5))
diff --git a/tests/unittest.py b/tests/unittest.py
index a7c20556a0..4aa7f56106 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -28,6 +28,7 @@ import logging
import secrets
import time
from typing import (
+ AbstractSet,
Any,
Awaitable,
Callable,
@@ -269,6 +270,56 @@ class TestCase(unittest.TestCase):
required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)
+ def assertIncludes(
+ self,
+ actual_items: AbstractSet[str],
+ expected_items: AbstractSet[str],
+ exact: bool = False,
+ message: Optional[str] = None,
+ ) -> None:
+ """
+ Assert that all of the `expected_items` are included in the `actual_items`.
+
+ This assert could also be called `assertContains`, `assertItemsInSet`
+
+ Args:
+ actual_items: The container
+ expected_items: The items to check for in the container
+ exact: Whether the actual state should be exactly equal to the expected
+ state (no extras).
+ message: Optional message to include in the failure message.
+ """
+ # Check that each set has the same items
+ if exact and actual_items == expected_items:
+ return
+ # Check for a superset
+ elif not exact and actual_items >= expected_items:
+ return
+
+ expected_lines: List[str] = []
+ for expected_item in expected_items:
+ is_expected_in_actual = expected_item in actual_items
+ expected_lines.append(
+ "{} {}".format(" " if is_expected_in_actual else "?", expected_item)
+ )
+
+ actual_lines: List[str] = []
+ for actual_item in actual_items:
+ is_actual_in_expected = actual_item in expected_items
+ actual_lines.append(
+ "{} {}".format("+" if is_actual_in_expected else " ", actual_item)
+ )
+
+ newline = "\n"
+ expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}"
+ actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}"
+ first_message = (
+ "Items must match exactly" if exact else "Some expected items are missing."
+ )
+ diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
+
+ self.fail(f"{diff_message}\n{message}")
+
def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG.
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index fb67146c69..13a4e6ddaa 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -21,6 +21,7 @@
from contextlib import contextmanager
from os import PathLike
+from pathlib import Path
from typing import Generator, Optional, Union
from unittest.mock import patch
@@ -41,7 +42,7 @@ class DummyDistribution(metadata.Distribution):
def version(self) -> str:
return self._version
- def locate_file(self, path: Union[str, PathLike]) -> PathLike:
+ def locate_file(self, path: Union[str, PathLike]) -> Path:
raise NotImplementedError()
def read_text(self, filename: str) -> None:
|