diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py
index 2c396adbe3..142f73cfdb 100644
--- a/tests/federation/test_federation_media.py
+++ b/tests/federation/test_federation_media.py
@@ -36,10 +36,9 @@ from synapse.util import Clock
from tests import unittest
from tests.test_utils import SMALL_PNG
-from tests.unittest import override_config
-class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
+class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
@@ -65,9 +64,6 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
)
self.media_repo = hs.get_media_repository()
- @override_config(
- {"experimental_features": {"msc3916_authenticated_media_enabled": True}}
- )
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
@@ -82,7 +78,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# test with a text file
channel = self.make_signed_federation_request(
"GET",
- f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
+ f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(200, channel.code)
@@ -106,7 +102,8 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# check that the text file and expected value exist
found_file = any(
- "\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field
+ "\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream"
+ in field
for field in stripped
)
self.assertTrue(found_file)
@@ -124,7 +121,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# test with an image file
channel = self.make_signed_federation_request(
"GET",
- f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
+ f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(200, channel.code)
@@ -149,25 +146,3 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# check that the png file exists and matches what was uploaded
found_file = any(SMALL_PNG in field for field in stripped_bytes)
self.assertTrue(found_file)
-
- @override_config(
- {"experimental_features": {"msc3916_authenticated_media_enabled": False}}
- )
- def test_disable_config(self) -> None:
- content = io.BytesIO(b"file_to_stream")
- content_uri = self.get_success(
- self.media_repo.create_content(
- "text/plain",
- "test_upload",
- content,
- 46,
- UserID.from_string("@user_id:whatever.org"),
- )
- )
- channel = self.make_signed_federation_request(
- "GET",
- f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
- )
- self.pump()
- self.assertEqual(404, channel.code)
- self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 8dd4521b18..713a798703 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -24,7 +24,14 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventContentFields,
+ EventTypes,
+ JoinRules,
+ Membership,
+ RoomTypes,
+)
from synapse.api.room_versions import RoomVersions
from synapse.handlers.sliding_sync import SlidingSyncConfig
from synapse.rest import admin
@@ -63,6 +70,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
+ self.storage_controllers = hs.get_storage_controllers()
def test_no_rooms(self) -> None:
"""
@@ -90,10 +98,13 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
before_room_token = self.event_sources.get_current_token()
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(room_id, user1_id, tok=user1_tok)
after_room_token = self.event_sources.get_current_token()
@@ -106,6 +117,15 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
self.assertEqual(room_id_results.keys(), {room_id})
+ # It should be pointing to the join event (latest membership event in the
+ # from/to range)
+ self.assertEqual(
+ room_id_results[room_id].event_id,
+ join_response["event_id"],
+ )
+ # We should be considered `newly_joined` because we joined during the token
+ # range
+ self.assertEqual(room_id_results[room_id].newly_joined, True)
def test_get_already_joined_room(self) -> None:
"""
@@ -113,8 +133,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(room_id, user1_id, tok=user1_tok)
after_room_token = self.event_sources.get_current_token()
@@ -127,6 +150,14 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
self.assertEqual(room_id_results.keys(), {room_id})
+ # It should be pointing to the join event (latest membership event in the
+ # from/to range)
+ self.assertEqual(
+ room_id_results[room_id].event_id,
+ join_response["event_id"],
+ )
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id].newly_joined, False)
def test_get_invited_banned_knocked_room(self) -> None:
"""
@@ -142,14 +173,18 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Setup the invited room (user2 invites user1 to the room)
invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invited_room_id, targ=user1_id, tok=user2_tok)
+ invite_response = self.helper.invite(
+ invited_room_id, targ=user1_id, tok=user2_tok
+ )
# Setup the ban room (user2 bans user1 from the room)
ban_room_id = self.helper.create_room_as(
user2_id, tok=user2_tok, is_public=True
)
self.helper.join(ban_room_id, user1_id, tok=user1_tok)
- self.helper.ban(ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+ ban_response = self.helper.ban(
+ ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok
+ )
# Setup the knock room (user1 knocks on the room)
knock_room_id = self.helper.create_room_as(
@@ -162,13 +197,19 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
tok=user2_tok,
)
# User1 knocks on the room
- channel = self.make_request(
+ knock_channel = self.make_request(
"POST",
"/_matrix/client/r0/knock/%s" % (knock_room_id,),
b"{}",
user1_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(knock_channel.code, 200, knock_channel.result)
+ knock_room_membership_state_event = self.get_success(
+ self.storage_controllers.state.get_current_state_event(
+ knock_room_id, EventTypes.Member, user1_id
+ )
+ )
+ assert knock_room_membership_state_event is not None
after_room_token = self.event_sources.get_current_token()
@@ -189,6 +230,25 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
knock_room_id,
},
)
+        # It should be pointing to the respective membership event (latest
+ # membership event in the from/to range)
+ self.assertEqual(
+ room_id_results[invited_room_id].event_id,
+ invite_response["event_id"],
+ )
+ self.assertEqual(
+ room_id_results[ban_room_id].event_id,
+ ban_response["event_id"],
+ )
+ self.assertEqual(
+ room_id_results[knock_room_id].event_id,
+ knock_room_membership_state_event.event_id,
+ )
+        # We should *NOT* be `newly_joined` because we were not joined at the time
+ # of the `to_token`.
+ self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
+ self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
def test_get_kicked_room(self) -> None:
"""
@@ -206,7 +266,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
self.helper.join(kick_room_id, user1_id, tok=user1_tok)
# Kick user1 from the room
- self.helper.change_membership(
+ kick_response = self.helper.change_membership(
room=kick_room_id,
src=user2_id,
targ=user1_id,
@@ -229,6 +289,14 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# The kicked room should show up
self.assertEqual(room_id_results.keys(), {kick_room_id})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[kick_room_id].event_id,
+ kick_response["event_id"],
+ )
+        # We should *NOT* be `newly_joined` because we were not joined at the time
+ # of the `to_token`.
+ self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
def test_forgotten_rooms(self) -> None:
"""
@@ -329,7 +397,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Leave during the from_token/to_token range (newly_left)
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
after_room2_token = self.event_sources.get_current_token()
@@ -343,6 +411,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Only the newly_left room should show up
self.assertEqual(room_id_results.keys(), {room_id2})
+        # It should be pointing to the latest membership event in the from/to range but
+        # the `event_id` is `None` because when we left the room, the server also left
+        # the room (no other local users remain in it), which is a quirk of the
+        # `current_state_delta_stream` table that we source things from.
+ self.assertEqual(
+ room_id_results[room_id2].event_id,
+ None, # _leave_response2["event_id"],
+ )
+ # We should *NOT* be `newly_joined` because we are instead `newly_left`
+ self.assertEqual(room_id_results[room_id2].newly_joined, False)
def test_no_joins_after_to_token(self) -> None:
"""
@@ -351,16 +429,19 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
before_room1_token = self.event_sources.get_current_token()
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
- # Room join after after our `to_token` shouldn't show up
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- _ = room_id2
+ # Room join after our `to_token` shouldn't show up
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -371,6 +452,13 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response1["event_id"],
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
def test_join_during_range_and_left_room_after_to_token(self) -> None:
"""
@@ -380,15 +468,18 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
before_room1_token = self.event_sources.get_current_token()
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Leave the room after we already have our tokens
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -401,6 +492,20 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We should still see the room because we were joined during the
# from_token/to_token time period.
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "leave_response": leave_response["event_id"],
+ }
+ ),
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
def test_join_before_range_and_left_room_after_to_token(self) -> None:
"""
@@ -410,13 +515,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Leave the room after we already have our tokens
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -428,6 +536,20 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We should still see the room because we were joined before the `from_token`
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "leave_response": leave_response["event_id"],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
def test_kicked_before_range_and_left_after_to_token(self) -> None:
"""
@@ -444,9 +566,9 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
kick_room_id = self.helper.create_room_as(
user2_id, tok=user2_tok, is_public=True
)
- self.helper.join(kick_room_id, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(kick_room_id, user1_id, tok=user1_tok)
# Kick user1 from the room
- self.helper.change_membership(
+ kick_response = self.helper.change_membership(
room=kick_room_id,
src=user2_id,
targ=user1_id,
@@ -463,8 +585,8 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
#
# We have to join before we can leave (leave -> leave isn't a valid transition
# or at least it doesn't work in Synapse, 403 forbidden)
- self.helper.join(kick_room_id, user1_id, tok=user1_tok)
- self.helper.leave(kick_room_id, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(kick_room_id, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -476,6 +598,22 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# We shouldn't see the room because it was forgotten
self.assertEqual(room_id_results.keys(), {kick_room_id})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[kick_room_id].event_id,
+ kick_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ "kick_response": kick_response["event_id"],
+ "join_response2": join_response2["event_id"],
+ "leave_response": leave_response["event_id"],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we were kicked
+ self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
"""
@@ -494,14 +632,14 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room during the from/to range
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join and leave the room after we already have our tokens
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -513,6 +651,22 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ leave_response1["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ "leave_response1": leave_response1["event_id"],
+ "join_response2": join_response2["event_id"],
+ "leave_response2": leave_response2["event_id"],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we left during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
"""
@@ -531,13 +685,13 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room during the from/to range
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join the room after we already have our tokens
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -549,11 +703,26 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ leave_response1["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ "leave_response1": leave_response1["event_id"],
+ "join_response2": join_response2["event_id"],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we left during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
def test_no_from_token(self) -> None:
"""
Test that if we don't provide a `from_token`, we get all the rooms that we we're
- joined to up to the `to_token`.
+ joined up to the `to_token`.
Providing `from_token` only really has the effect that it adds `newly_left`
rooms to the response.
@@ -569,7 +738,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join room1
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join and leave the room2 before the `to_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
@@ -590,6 +759,14 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Only rooms we were joined to before the `to_token` should show up
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response1["event_id"],
+ )
+ # We should *NOT* be `newly_joined` because there is no `from_token` to
+ # define a "live" range to compare against
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
def test_from_token_ahead_of_to_token(self) -> None:
"""
@@ -609,7 +786,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join room1 before `before_room_token`
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join and leave the room2 before `before_room_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
@@ -651,6 +828,13 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# There won't be any newly_left rooms because the `from_token` is ahead of the
# `to_token` and that range will give no membership changes to check.
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response1["event_id"],
+ )
+ # We should *NOT* be `newly_joined` because we joined `room1` before either of the tokens
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
"""
@@ -741,16 +925,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join, leave, join back to the room before the from/to range
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Leave and Join the room multiple times after we already have our tokens
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -762,6 +946,24 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Room should show up because it was newly_left and joined during the from/to range
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response2["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ "leave_response1": leave_response1["event_id"],
+ "join_response2": join_response2["event_id"],
+ "leave_response2": leave_response2["event_id"],
+ "join_response3": join_response3["event_id"],
+ "leave_response3": leave_response3["event_id"],
+ }
+ ),
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
def test_join_leave_multiple_times_before_range_and_after_to_token(
self,
@@ -781,16 +983,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join, leave, join back to the room before the from/to range
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Leave and Join the room multiple times after we already have our tokens
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
@@ -802,6 +1004,24 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Room should show up because we were joined before the from/to range
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response2["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ "leave_response1": leave_response1["event_id"],
+ "join_response2": join_response2["event_id"],
+ "leave_response2": leave_response2["event_id"],
+ "join_response3": join_response3["event_id"],
+ "leave_response3": leave_response3["event_id"],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
def test_invite_before_range_and_join_leave_after_to_token(
self,
@@ -821,24 +1041,495 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Invited to the room before the token
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ invite_response = self.helper.invite(
+ room_id1, src=user2_id, targ=user1_id, tok=user2_tok
+ )
after_room1_token = self.event_sources.get_current_token()
# Join and leave the room after we already have our tokens
+        join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=after_room1_token,
+ to_token=after_room1_token,
+ )
+ )
+
+ # Room should show up because we were invited before the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ invite_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "invite_response": invite_response["event_id"],
+                    "join_response": join_response["event_id"],
+ "leave_response": leave_response["event_id"],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we were only invited before the
+ # token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+
+ def test_join_and_display_name_changes_in_token_range(
+ self,
+ ) -> None:
+ """
+ Test that we point to the correct membership event within the from/to range even
+ if there are multiple `join` membership events in a row indicating
+ `displayname`/`avatar_url` updates.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # Update the displayname during the token range
+ displayname_change_during_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Update the displayname after the token range
+ displayname_change_after_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname after token range",
+ },
+ tok=user1_tok,
+ )
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_room1_token,
+ to_token=after_room1_token,
+ )
+ )
+
+ # Room should show up because we were joined during the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ displayname_change_during_token_range_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
+ "event_id"
+ ],
+ "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
+
+ def test_display_name_changes_in_token_range(
+ self,
+ ) -> None:
+ """
+ Test that we point to the correct membership event within the from/to range even
+        if there are `displayname`/`avatar_url` updates.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Update the displayname during the token range
+ displayname_change_during_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_change1_token = self.event_sources.get_current_token()
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=after_room1_token,
+ to_token=after_change1_token,
+ )
+ )
+
+ # Room should show up because we were joined during the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ displayname_change_during_token_range_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+
+ def test_display_name_changes_before_and_after_token_range(
+ self,
+ ) -> None:
+ """
+ Test that we point to the correct membership event even though there are no
+        membership events in the from/to range but there are `displayname`/`avatar_url`
+ changes before/after the token range.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # Update the displayname before the token range
+ displayname_change_before_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+                "displayname": "displayname before token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Update the displayname after the token range
+ displayname_change_after_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname after token range",
+ },
+ tok=user1_tok,
+ )
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=after_room1_token,
+ to_token=after_room1_token,
+ )
+ )
+
+ # Room should show up because we were joined before the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ displayname_change_before_token_range_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "displayname_change_before_token_range_response": displayname_change_before_token_range_response[
+ "event_id"
+ ],
+ "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ # We should *NOT* be `newly_joined` because we joined before the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, False)
+
+ def test_display_name_changes_leave_after_token_range(
+ self,
+ ) -> None:
+ """
+ Test that we point to the correct membership event within the from/to range even
+ if there are multiple `join` membership events in a row indicating
+ `displayname`/`avatar_url` updates and we leave after the `to_token`.
+
+ See condition "1a)" comments in the `get_sync_room_ids_for_user()` method.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # Update the displayname during the token range
+ displayname_change_during_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Update the displayname after the token range
+ displayname_change_after_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname after token range",
+ },
+ tok=user1_tok,
+ )
+
+ # Leave after the token
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_room1_token,
+ to_token=after_room1_token,
+ )
+ )
+
+ # Room should show up because we were joined during the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ displayname_change_during_token_range_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
+ "event_id"
+ ],
+ "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
+
+ def test_display_name_changes_join_after_token_range(
+ self,
+ ) -> None:
+ """
+        Test that multiple `join` membership events (after the `to_token`) in a row
+        indicating `displayname`/`avatar_url` updates don't affect the results (we
+        joined after the token range so the room shouldn't show up).
+
+ See condition "1b)" comments in the `get_sync_room_ids_for_user()` method.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # Update the displayname after the token range
+ self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname after token range",
+ },
+ tok=user1_tok,
+ )
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_room1_token,
+ to_token=after_room1_token,
+ )
+ )
+
+ # Room shouldn't show up because we joined after the from/to range
+ self.assertEqual(room_id_results.keys(), set())
+
+ def test_newly_joined_with_leave_join_in_token_range(
+ self,
+ ) -> None:
+ """
+ Test that even though we're joined before the token range, if we leave and join
+ within the token range, it's still counted as `newly_joined`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Leave and join back during the token range
self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ after_more_changes_token = self.event_sources.get_current_token()
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=after_room1_token,
+ to_token=after_more_changes_token,
+ )
+ )
+
+ # Room should show up because we were joined during the from/to range
+ self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ join_response2["event_id"],
+ )
+        # We should be considered `newly_joined` because there is a non-join event in
+        # between our earlier join and our latest join within the token range.
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
+
+ def test_newly_joined_only_joins_during_token_range(
+ self,
+ ) -> None:
+ """
+ Test that a join and more joins caused by display name changes, all during the
+ token range, still count as `newly_joined`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+        # Join the room during the token range
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # Update the displayname during the token range (looks like another join)
+ displayname_change_during_token_range_response1 = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+ # Update the displayname during the token range (looks like another join)
+ displayname_change_during_token_range_response2 = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ room_id_results = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=before_room1_token,
to_token=after_room1_token,
)
)
- # Room should show up because we were invited before the from/to range
+        # Room should show up because we joined during the from/to range
self.assertEqual(room_id_results.keys(), {room_id1})
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ room_id_results[room_id1].event_id,
+ displayname_change_during_token_range_response2["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ "displayname_change_during_token_range_response1": displayname_change_during_token_range_response1[
+ "event_id"
+ ],
+ "displayname_change_during_token_range_response2": displayname_change_during_token_range_response2[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ # We should be `newly_joined` because we first joined during the token range
+ self.assertEqual(room_id_results[room_id1].newly_joined, True)
def test_multiple_rooms_are_not_confused(
self,
@@ -1363,6 +2054,211 @@ class FilterRoomsTestCase(HomeserverTestCase):
self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+ def test_filter_room_types(self) -> None:
+ """
+ Test `filter.room_types` for different room types
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Create an arbitrarily typed room
+ foo_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
+ }
+ },
+ )
+
+ after_rooms_token = self.event_sources.get_current_token()
+
+ # Get the rooms the user should be syncing with
+ sync_room_map = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
+ )
+ )
+
+ # Try finding only normal rooms
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {room_id})
+
+ # Try finding only spaces
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {space_room_id})
+
+ # Try finding normal rooms and spaces
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ room_types=[None, RoomTypes.SPACE]
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {room_id, space_room_id})
+
+ # Try finding an arbitrary room type
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ room_types=["org.matrix.foobarbaz"]
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {foo_room_id})
+
+ def test_filter_not_room_types(self) -> None:
+ """
+ Test `filter.not_room_types` for different room types
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Create an arbitrarily typed room
+ foo_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
+ }
+ },
+ )
+
+ after_rooms_token = self.event_sources.get_current_token()
+
+ # Get the rooms the user should be syncing with
+ sync_room_map = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
+ )
+ )
+
+ # Try finding *NOT* normal rooms
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(not_room_types=[None]),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {space_room_id, foo_room_id})
+
+ # Try finding *NOT* spaces
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ not_room_types=[RoomTypes.SPACE]
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {room_id, foo_room_id})
+
+ # Try finding *NOT* normal rooms or spaces
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ not_room_types=[None, RoomTypes.SPACE]
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {foo_room_id})
+
+ # Test how it behaves when we have both `room_types` and `not_room_types`.
+ # `not_room_types` should win.
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ room_types=[None], not_room_types=[None]
+ ),
+ after_rooms_token,
+ )
+ )
+
+ # Nothing matches because nothing is both a normal room and not a normal room
+ self.assertEqual(filtered_room_map.keys(), set())
+
+        # Test how it behaves when `room_types` and `not_room_types` only partially
+        # overlap. `not_room_types` still wins for the overlapping type, but the other
+        # `room_types` entry still matches.
+ filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ room_types=[None, RoomTypes.SPACE], not_room_types=[None]
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(filtered_room_map.keys(), {space_room_id})
+
class SortRoomsTestCase(HomeserverTestCase):
"""
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index a98091d711..721917f957 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -37,18 +37,155 @@ from synapse.http.client import (
BlocklistingAgentWrapper,
BlocklistingReactorWrapper,
BodyExceededMaxSize,
+ MultipartResponse,
_DiscardBodyWithMaxSizeProtocol,
+ _MultipartParserProtocol,
read_body_with_max_size,
+ read_multipart_response,
)
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
+class ReadMultipartResponseTests(TestCase):
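+    # Multipart bodies using boundary 6067d4698f8d40a0a794ea7d7379d53a: a JSON part ({})
+    # followed by a text/plain file part containing "file_to_stream" (split across
+    # data1/data2 so tests can feed it to the parser in two chunks), plus a variant
+    # whose second part is a redirect given via a Location header.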
+ data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
+ data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+ redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+ def _build_multipart_response(
+ self, response_length: Union[int, str], max_length: int
+ ) -> Tuple[
+ BytesIO,
+ "Deferred[MultipartResponse]",
+ _MultipartParserProtocol,
+ ]:
+        """Start reading the body; returns the result buffer, deferred and protocol"""
+ response = Mock(length=response_length)
+ result = BytesIO()
+ boundary = "6067d4698f8d40a0a794ea7d7379d53a"
+ deferred = read_multipart_response(response, result, boundary, max_length)
+
+ # Fish the protocol out of the response.
+ protocol = response.deliverBody.call_args[0][0]
+ protocol.transport = Mock()
+
+ return result, deferred, protocol
+
+ def _assert_error(
+ self,
+ deferred: "Deferred[MultipartResponse]",
+ protocol: _MultipartParserProtocol,
+ ) -> None:
+ """Ensure that the expected error is received."""
+ assert isinstance(deferred.result, Failure)
+ self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+ assert protocol.transport is not None
+ # type-ignore: presumably abortConnection has been replaced with a Mock.
+ protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
+
+ def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
+ """Ensure that the error in the Deferred is handled gracefully."""
+ called = [False]
+
+ def errback(f: Failure) -> None:
+ called[0] = True
+
+ deferred.addErrback(errback)
+ self.assertTrue(called[0])
+
+ def test_parse_file(self) -> None:
+ """
+ Check that a multipart response containing a file is properly parsed
+ into the json/file parts, and the json and file are properly captured
+ """
+ result, deferred, protocol = self._build_multipart_response(249, 250)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+ protocol.dataReceived(self.data2)
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
+
+ self.assertEqual(multipart_response.json, b"{}")
+ self.assertEqual(result.getvalue(), b"file_to_stream")
+ self.assertEqual(multipart_response.length, len(b"file_to_stream"))
+ self.assertEqual(multipart_response.content_type, b"text/plain")
+ self.assertEqual(
+ multipart_response.disposition, b"inline; filename=test_upload"
+ )
+
+ def test_parse_redirect(self) -> None:
+ """
+        Check that a multipart response containing a redirect is properly parsed and the
+        redirect URL is returned
+ """
+ result, deferred, protocol = self._build_multipart_response(249, 250)
+
+ # Start sending data.
+ protocol.dataReceived(self.redirect_data)
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
+
+ self.assertEqual(multipart_response.json, b"{}")
+ self.assertEqual(result.getvalue(), b"")
+ self.assertEqual(
+ multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
+ )
+
+ def test_too_large(self) -> None:
+ """A response which is too large raises an exception."""
+ result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+
+ self.assertEqual(result.getvalue(), b"file_")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_additional_data(self) -> None:
+ """A connection can receive data after being closed."""
+ result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+ self._assert_error(deferred, protocol)
+
+ # More data might have come in.
+ protocol.dataReceived(self.data2)
+
+ self.assertEqual(result.getvalue(), b"file_")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_content_length(self) -> None:
+ """The body shouldn't be read (at all) if the Content-Length header is too large."""
+ result, deferred, protocol = self._build_multipart_response(250, 1)
+
+ # Deferred shouldn't be called yet.
+ self.assertFalse(deferred.called)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ # The data is never consumed.
+ self.assertEqual(result.getvalue(), b"")
+
+
class ReadBodyWithMaxSizeTests(TestCase):
- def _build_response(
- self, length: Union[int, str] = UNKNOWN_LENGTH
- ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
+ def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
+ BytesIO,
+ "Deferred[int]",
+ _DiscardBodyWithMaxSizeProtocol,
+ ]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 46d20ce775..024086b775 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -129,7 +129,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
@attr.s(auto_attribs=True, slots=True, frozen=True)
-class _TestImage:
+class TestImage:
"""An image for testing thumbnailing with the expected results
Attributes:
@@ -158,7 +158,7 @@ class _TestImage:
is_inline: bool = True
-small_png = _TestImage(
+small_png = TestImage(
SMALL_PNG,
b"image/png",
b".png",
@@ -175,7 +175,7 @@ small_png = _TestImage(
),
)
-small_png_with_transparency = _TestImage(
+small_png_with_transparency = TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
@@ -188,7 +188,7 @@ small_png_with_transparency = _TestImage(
# different versions of Pillow.
)
-small_lossless_webp = _TestImage(
+small_lossless_webp = TestImage(
unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
),
@@ -196,7 +196,7 @@ small_lossless_webp = _TestImage(
b".webp",
)
-empty_file = _TestImage(
+empty_file = TestImage(
b"",
b"image/gif",
b".gif",
@@ -204,7 +204,7 @@ empty_file = _TestImage(
unable_to_thumbnail=True,
)
-SVG = _TestImage(
+SVG = TestImage(
b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@@ -236,7 +236,7 @@ urls = [
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
class MediaRepoTests(unittest.HomeserverTestCase):
servlets = [media.register_servlets]
- test_image: ClassVar[_TestImage]
+ test_image: ClassVar[TestImage]
hijack_auth = True
user_id = "@test:user"
url: ClassVar[str]
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 4927e45446..6fc4600c41 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -28,7 +28,7 @@ from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
-from synapse.rest.client import login
+from synapse.rest.client import login, media
from synapse.server import HomeServer
from synapse.util import Clock
@@ -255,6 +255,238 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
+class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks running multiple media repos work correctly using autheticated media paths"""
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ media.register_servlets,
+ ]
+
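+    # A multipart/mixed body with two parts separated by the boundary
+    # 6067d4698f8d40a0a794ea7d7379d53a: a JSON metadata part ({}) followed by a
+    # text/plain part carrying the file bytes ("file_to_stream").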
+ file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ self.reactor.lookups["example.com"] = "1.2.3.4"
+
+ def default_config(self) -> dict:
+ conf = super().default_config()
+ conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ return conf
+
+ def make_worker_hs(
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
+ ) -> HomeServer:
+ worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
+ # Force the media paths onto the replication resource.
+ worker_hs.get_media_repository_resource().register_servlets(
+ self._hs_to_site[worker_hs].resource, worker_hs
+ )
+ return worker_hs
+
+ def _get_media_req(
+ self, hs: HomeServer, target: str, media_id: str
+ ) -> Tuple[FakeChannel, Request]:
+ """Request some remote media from the given HS by calling the download
+ API.
+
+ This then triggers an outbound request from the HS to the target.
+
+ Returns:
+ The channel for the *client* request and the *outbound* request for
+ the media which the caller should respond to.
+ """
+ channel = make_request(
+ self.reactor,
+ self._hs_to_site[hs],
+ "GET",
+ f"/_matrix/client/v1/media/download/{target}/{media_id}",
+ shorthand=False,
+ access_token=self.access_token,
+ await_result=False,
+ )
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+ # build the test server
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_protocol = wrap_server_factory_for_tls(
+ server_factory, self.reactor, sanlist=[b"DNS:example.com"]
+ ).buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(
+ request.path,
+ f"/_matrix/federation/v1/media/download/{media_id}".encode(),
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+ )
+
+ return channel, request
+
+ def test_basic(self) -> None:
+ """Test basic fetching of remote media from a single worker."""
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+ channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+ request.setResponseCode(200)
+ request.responseHeaders.setRawHeaders(
+ b"Content-Type",
+ ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+ )
+ request.write(self.file_data)
+ request.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], b"file_to_stream")
+
+ def test_download_simple_file_race(self) -> None:
+ """Test that fetching remote media from two different processes at the
+ same time works.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_media()
+
+ # Make two requests without responding to the outbound media requests.
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+ # Respond to the first outbound media request and check that the client
+ # request is successful
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(
+ b"Content-Type",
+ ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+ )
+ request1.write(self.file_data)
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], b"file_to_stream")
+
+ # Now respond to the second with the same content.
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(
+ b"Content-Type",
+ ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+ )
+ request2.write(self.file_data)
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], b"file_to_stream")
+
+ # We expect only one new file to have been persisted.
+ self.assertEqual(start_count + 1, self._count_remote_media())
+
+ def test_download_image_race(self) -> None:
+ """Test that fetching remote *images* from two different processes at
+ the same time works.
+
+ This checks that races generating thumbnails are handled correctly.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_thumbnails()
+
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(
+ b"Content-Type",
+ ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+ )
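+        # Multipart body for an image response: a JSON metadata part followed by an
+        # image/png part; the PNG bytes and the closing boundary are written separately below.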
+ img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n"
+ request1.write(img_data)
+ request1.write(SMALL_PNG)
+ request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], SMALL_PNG)
+
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(
+ b"Content-Type",
+ ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
+ )
+ request2.write(img_data)
+ request2.write(SMALL_PNG)
+ request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], SMALL_PNG)
+
+ # We expect only three new thumbnails to have been persisted.
+ self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+ def _count_remote_media(self) -> int:
+ """Count the number of files in our remote media directory."""
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_content"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+ def _count_remote_thumbnails(self) -> int:
+ """Count the number of files in our remote thumbnails directory."""
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index be4a289ec1..6b5af2dbb6 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -19,31 +19,54 @@
#
#
import base64
+import io
import json
import os
import re
-from typing import Any, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
+from unittest.mock import MagicMock, Mock, patch
+from urllib import parse
from urllib.parse import quote, urlencode
+from parameterized import parameterized_class
+
+from twisted.internet import defer
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource
+from synapse.api.errors import HttpResponseException
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.http.client import MultipartResponse
+from synapse.http.types import QueryParams
+from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
from synapse.rest import admin
from synapse.rest.client import login, media
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
-from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.media.test_media_storage import (
+ SVG,
+ TestImage,
+ empty_file,
+ small_lossless_webp,
+ small_png,
+ small_png_with_transparency,
+)
+from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
@@ -1607,3 +1630,583 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
self.assertEqual(
channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
)
+
+
+class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ media.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.repo = hs.get_media_repository()
+ self.client = hs.get_federation_http_client()
+ self.store = hs.get_datastores().main
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # mock actually reading file body
+ def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred:
+ d: Deferred = defer.Deferred()
+ d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
+ return d
+
+ def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred:
+ d: Deferred = defer.Deferred()
+ d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
+ return d
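+    # Both helpers resolve immediately with a MultipartResponse whose declared length
+    # stands in for the streamed body, so no bytes are actually read in these tests.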
+
+ @patch(
+ "synapse.http.matrixfederationclient.read_multipart_response",
+ read_multipart_response_30MiB,
+ )
+ def test_download_ratelimit_default(self) -> None:
+ """
+        Test remote media download ratelimiting against the default configuration - a
+        500MiB bucket and an 87KiB/second drain rate
+ """
+
+ # mock out actually sending the request, returns a 30MiB response
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 31457280
+ resp.headers = Headers(
+ {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+ )
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # first request should go through
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abc",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 200
+
+ # next 15 should go through
+ for i in range(15):
+ channel2 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/download/remote.org/abc{i}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel2.code == 200
+
+ # 17th will hit ratelimit
+ channel3 = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcd",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel3.code == 429
+
+ # however, a request from a different IP will go through
+ channel4 = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcde",
+ shorthand=False,
+ client_ip="187.233.230.159",
+ access_token=self.tok,
+ )
+ assert channel4.code == 200
+
+        # at 87KiB/s it should take about 2 minutes for enough to drain from the bucket that another
+ # 30MiB download is authorized - The last download was blocked at 503,316,480.
+ # The next download will be authorized when bucket hits 492,830,720
+ # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
+ # needs to drain before another download will be authorized, that will take ~=
+ # 2 minutes (10,485,760/89,088/60)
+ self.reactor.pump([2.0 * 60.0])
+
+ # enough has drained and next request goes through
+ channel5 = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcdef",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel5.code == 200
+
+ @override_config(
+ {
+ "remote_media_download_per_second": "50M",
+ "remote_media_download_burst_count": "50M",
+ }
+ )
+ @patch(
+ "synapse.http.matrixfederationclient.read_multipart_response",
+ read_multipart_response_50MiB,
+ )
+ def test_download_rate_limit_config(self) -> None:
+ """
+ Test that download rate limit config options are correctly picked up and applied
+ """
+
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 52428800
+ resp.headers = Headers(
+ {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+ )
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # first request should go through
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abc",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 200
+
+ # immediate second request should fail
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcd",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 429
+
+ # advance half a second
+ self.reactor.pump([0.5])
+
+ # request still fails
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcde",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 429
+
+ # advance another half second
+ self.reactor.pump([0.5])
+
+ # enough has drained from bucket and request is successful
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcdef",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 200
+
+ @patch(
+ "synapse.http.matrixfederationclient.read_multipart_response",
+ read_multipart_response_30MiB,
+ )
+ def test_download_ratelimit_max_size_sub(self) -> None:
+ """
+ Test that if no content-length is provided, the default max size is applied instead
+ """
+
+ # mock out actually sending the request
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = UNKNOWN_LENGTH
+ resp.headers = Headers(
+ {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+ )
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # ten requests should go through using the max size (500MB/50MB)
+ for i in range(10):
+ channel2 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/download/remote.org/abc{i}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel2.code == 200
+
+ # eleventh will hit ratelimit
+ channel3 = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/remote.org/abcd",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel3.code == 429
+
+ def test_file_download(self) -> None:
+ content = io.BytesIO(b"file_to_stream")
+ content_uri = self.get_success(
+ self.repo.create_content(
+ "text/plain",
+ "test_upload",
+ content,
+ 46,
+ UserID.from_string("@user_id:whatever.org"),
+ )
+ )
+ # test with a text file
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/download/test/{content_uri.media_id}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.pump()
+ self.assertEqual(200, channel.code)
+
+
+test_images = [
+ small_png,
+ small_png_with_transparency,
+ small_lossless_webp,
+ empty_file,
+ SVG,
+]
+input_values = [(x,) for x in test_images]
+
+
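+# parameterized_class below generates one DownloadTestCase subclass per entry in
+# input_values, binding `test_image` to the corresponding TestImage.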
+@parameterized_class(("test_image",), input_values)
+class DownloadTestCase(unittest.HomeserverTestCase):
+ test_image: ClassVar[TestImage]
+ servlets = [
+ media.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.fetches: List[
+ Tuple[
+ "Deferred[Any]",
+ str,
+ str,
+ Optional[QueryParams],
+ ]
+ ] = []
+
+ def federation_get_file(
+ destination: str,
+ path: str,
+ output_stream: BinaryIO,
+ download_ratelimiter: Ratelimiter,
+ ip_address: Any,
+ max_size: int,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
+ ignore_backoff: bool = False,
+ follow_redirects: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]":
+ """A mock for MatrixFederationHttpClient.federation_get_file."""
+
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ def write_err(f: Failure) -> Failure:
+ f.trap(HttpResponseException)
+ output_stream.write(f.value.response)
+ return f
+
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = (
+ Deferred()
+ )
+ self.fetches.append((d, destination, path, args))
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallbacks(write_to, write_err)
+ return make_deferred_yieldable(d_after_callback)
+
+ def get_file(
+ destination: str,
+ path: str,
+ output_stream: BinaryIO,
+ download_ratelimiter: Ratelimiter,
+ ip_address: Any,
+ max_size: int,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
+ ignore_backoff: bool = False,
+ follow_redirects: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+ """A mock for MatrixFederationHttpClient.get_file."""
+
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ def write_err(f: Failure) -> Failure:
+ f.trap(HttpResponseException)
+ output_stream.write(f.value.response)
+ return f
+
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
+ self.fetches.append((d, destination, path, args))
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallbacks(write_to, write_err)
+ return make_deferred_yieldable(d_after_callback)
+
+ # Mock out the homeserver's MatrixFederationHttpClient
+ client = Mock()
+ client.federation_get_file = federation_get_file
+ client.get_file = get_file
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+
+ config = self.default_config()
+ config["media_store_path"] = self.media_store_path
+ config["max_image_pixels"] = 2000000
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+ config["media_storage_providers"] = [provider_config]
+ config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
+
+ hs = self.setup_test_homeserver(config=config, federation_http_client=client)
+
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.media_repo = hs.get_media_repository()
+
+ self.remote = "example.com"
+ self.media_id = "12345"
+
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ def _req(
+ self, content_disposition: Optional[bytes], include_content_type: bool = True
+ ) -> FakeChannel:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
+ shorthand=False,
+ await_result=False,
+ access_token=self.tok,
+ )
+ self.pump()
+
+ # We've made one fetch, to example.com, using the federation media URL
+ self.assertEqual(len(self.fetches), 1)
+ self.assertEqual(self.fetches[0][1], "example.com")
+ self.assertEqual(
+ self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id
+ )
+ self.assertEqual(
+ self.fetches[0][3],
+ {"timeout_ms": "20000"},
+ )
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ }
+
+ if include_content_type:
+ headers[b"Content-Type"] = [self.test_image.content_type]
+
+ if content_disposition:
+ headers[b"Content-Disposition"] = [content_disposition]
+
+ self.fetches[0][0].callback(
+ (self.test_image.data, (len(self.test_image.data), headers, b"{}"))
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ return channel
+
+ def test_handle_missing_content_type(self) -> None:
+ channel = self._req(
+ b"attachment; filename=out" + self.test_image.extension,
+ include_content_type=False,
+ )
+ headers = channel.headers
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
+ )
+
+ def test_disposition_filename_ascii(self) -> None:
+ """
+ If the filename is filename=<ascii> then Synapse will decode it as an
+ ASCII string, and use filename= in the response.
+ """
+ channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [
+ (b"inline" if self.test_image.is_inline else b"attachment")
+ + b"; filename=out"
+ + self.test_image.extension
+ ],
+ )
+
+ def test_disposition_filenamestar_utf8escaped(self) -> None:
+ """
+        If the filename is filename*=utf-8''<utf8 escaped> then Synapse will
+ correctly decode it as the UTF-8 string, and use filename* in the
+ response.
+ """
+ filename = parse.quote("\u2603".encode()).encode("ascii")
+ channel = self._req(
+ b"attachment; filename*=utf-8''" + filename + self.test_image.extension
+ )
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [
+ (b"inline" if self.test_image.is_inline else b"attachment")
+ + b"; filename*=utf-8''"
+ + filename
+ + self.test_image.extension
+ ],
+ )
+
+ def test_disposition_none(self) -> None:
+ """
+ If there is no filename, Content-Disposition should only
+ be a disposition type.
+ """
+ channel = self._req(None)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [b"inline" if self.test_image.is_inline else b"attachment"],
+ )
+
+ def test_x_robots_tag_header(self) -> None:
+ """
+ Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+ to not index, archive, or follow links in media.
+ """
+ channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+ self.assertEqual(
+ headers.getRawHeaders(b"X-Robots-Tag"),
+ [b"noindex, nofollow, noarchive, noimageindex"],
+ )
+
+ def test_cross_origin_resource_policy_header(self) -> None:
+ """
+ Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
+ allowing web clients to embed media from the downloads API.
+ """
+ channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+ headers = channel.headers
+
+ self.assertEqual(
+ headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
+ [b"cross-origin"],
+ )
+
+ def test_unknown_federation_endpoint(self) -> None:
+ """
+        Test that if the download request to the remote federation endpoint returns a 404,
+        we fall back to the _matrix/media endpoint
+ """
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
+ shorthand=False,
+ await_result=False,
+ access_token=self.tok,
+ )
+ self.pump()
+
+        # We've made one fetch, to example.com, using the federation media URL
+ self.assertEqual(len(self.fetches), 1)
+ self.assertEqual(self.fetches[0][1], "example.com")
+ self.assertEqual(
+ self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
+ )
+
+ # The result which says the endpoint is unknown.
+ unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
+ self.fetches[0][0].errback(
+ HttpResponseException(404, "NOT FOUND", unknown_endpoint)
+ )
+
+ self.pump()
+
+ # There should now be another request to the _matrix/media/v3/download URL.
+ self.assertEqual(len(self.fetches), 2)
+ self.assertEqual(self.fetches[1][1], "example.com")
+ self.assertEqual(
+ self.fetches[1][2],
+ f"/_matrix/media/v3/download/example.com/{self.media_id}",
+ )
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ }
+
+ self.fetches[1][0].callback(
+ (self.test_image.data, (len(self.test_image.data), headers))
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 12c11f342c..966c622e14 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -31,12 +31,13 @@ from synapse.api.constants import (
AccountDataTypes,
EventContentFields,
EventTypes,
+ HistoryVisibility,
ReceiptTypes,
RelationTypes,
)
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.util import Clock
from tests import unittest
@@ -1326,7 +1327,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
def test_sync_list(self) -> None:
"""
- Test that room IDs show up in the Sliding Sync lists
+ Test that room IDs show up in the Sliding Sync `lists`
"""
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(alice_user_id, "correcthorse")
@@ -1425,15 +1426,13 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.await_result(timeout_ms=200)
self.assertEqual(channel.code, 200, channel.json_body)
- # We expect the `next_pos` in the result to be the same as what we requested
+ # We expect the next `pos` in the result to be the same as what we requested
# with because we weren't able to find anything new yet.
- self.assertEqual(
- channel.json_body["next_pos"], future_position_token_serialized
- )
+ self.assertEqual(channel.json_body["pos"], future_position_token_serialized)
def test_filter_list(self) -> None:
"""
- Test that filters apply to lists
+ Test that filters apply to `lists`
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1564,7 +1563,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
def test_sort_list(self) -> None:
"""
- Test that the lists are sorted by `stream_ordering`
+ Test that the `lists` are sorted by `stream_ordering`
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -1618,3 +1617,1067 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
],
channel.json_body["lists"]["foo-list"],
)
+
+ def test_sliced_windows(self) -> None:
+ """
+ Test that the `lists` `ranges` are sliced correctly. Both sides of each range
+ are inclusive.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ _room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+
+ # Make the Sliding Sync request for a single room
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 0]],
+ "required_state": [
+ ["m.room.join_rules", ""],
+ ["m.room.history_visibility", ""],
+ ["m.space.child", "*"],
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Make sure it has the foo-list we requested
+ self.assertListEqual(
+ list(channel.json_body["lists"].keys()),
+ ["foo-list"],
+ channel.json_body["lists"].keys(),
+ )
+ # Make sure the list is sorted in the way we expect
+ self.assertListEqual(
+ list(channel.json_body["lists"]["foo-list"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 0],
+ "room_ids": [room_id3],
+ }
+ ],
+ channel.json_body["lists"]["foo-list"],
+ )
+
+ # Make the Sliding Sync request for the first two rooms
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["m.room.join_rules", ""],
+ ["m.room.history_visibility", ""],
+ ["m.space.child", "*"],
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Make sure it has the foo-list we requested
+ self.assertListEqual(
+ list(channel.json_body["lists"].keys()),
+ ["foo-list"],
+ channel.json_body["lists"].keys(),
+ )
+ # Make sure the list is sorted in the way we expect
+ self.assertListEqual(
+ list(channel.json_body["lists"]["foo-list"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 1],
+ "room_ids": [room_id3, room_id2],
+ }
+ ],
+ channel.json_body["lists"]["foo-list"],
+ )
+
+ def test_rooms_limited_initial_sync(self) -> None:
+ """
+ Test that we mark `rooms` as `limited=True` when we saturate the `timeline_limit`
+ on initial sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity1", tok=user2_tok)
+ self.helper.send(room_id1, "activity2", tok=user2_tok)
+ event_response3 = self.helper.send(room_id1, "activity3", tok=user2_tok)
+ event_pos3 = self.get_success(
+ self.store.get_position_for_event(event_response3["event_id"])
+ )
+ event_response4 = self.helper.send(room_id1, "activity4", tok=user2_tok)
+ event_pos4 = self.get_success(
+ self.store.get_position_for_event(event_response4["event_id"])
+ )
+ event_response5 = self.helper.send(room_id1, "activity5", tok=user2_tok)
+ user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 3,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We expect to saturate the `timeline_limit` (there are more than 3 messages in the room)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ True,
+ channel.json_body["rooms"][room_id1],
+ )
+ # Check to make sure the latest events are returned
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ event_response4["event_id"],
+ event_response5["event_id"],
+ user1_join_response["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+
+ # Check to make sure the `prev_batch` points at the right place
+ prev_batch_token = self.get_success(
+ StreamToken.from_string(
+ self.store, channel.json_body["rooms"][room_id1]["prev_batch"]
+ )
+ )
+ prev_batch_room_stream_token_serialized = self.get_success(
+ prev_batch_token.room_key.to_string(self.store)
+ )
+ # If we use the `prev_batch` token to look backwards, we should see `event3`
+ # next so make sure the token encompasses it
+ self.assertEqual(
+ event_pos3.persisted_after(prev_batch_token.room_key),
+ False,
+ f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be >= event_pos3={self.get_success(event_pos3.to_room_stream_token().to_string(self.store))}",
+ )
+ # If we use the `prev_batch` token to look backwards, we shouldn't see `event4`
+ # anymore since it was just returned in this response.
+ self.assertEqual(
+ event_pos4.persisted_after(prev_batch_token.room_key),
+ True,
+ f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be < event_pos4={self.get_success(event_pos4.to_room_stream_token().to_string(self.store))}",
+ )
+
+ # With no `from_token` (initial sync), it's all historical since there is no
+ # "live" range
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 0,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_not_limited_initial_sync(self) -> None:
+ """
+ Test that we mark `rooms` as `limited=False` when there are no more events to
+ paginate to.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity1", tok=user2_tok)
+ self.helper.send(room_id1, "activity2", tok=user2_tok)
+ self.helper.send(room_id1, "activity3", tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ timeline_limit = 100
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": timeline_limit,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # The timeline should be `limited=False` because we have all of the events (no
+ # more to paginate to)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ False,
+ channel.json_body["rooms"][room_id1],
+ )
+ expected_number_of_events = 9
+ # We're just looking to make sure we got all of the events before hitting the `timeline_limit`
+ self.assertEqual(
+ len(channel.json_body["rooms"][room_id1]["timeline"]),
+ expected_number_of_events,
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+ self.assertLessEqual(expected_number_of_events, timeline_limit)
+
+ # With no `from_token` (initial sync), it's all historical since there is no
+ # "live" token range.
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 0,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_incremental_sync(self) -> None:
+ """
+ Test `rooms` data during an incremental sync after an initial sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.send(room_id1, "activity before initial sync1", tok=user2_tok)
+
+ # Make an initial Sliding Sync request to grab a token. This is also a sanity
+ # check that we can go from initial to incremental sync.
+ sync_params = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ sync_params,
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_pos = channel.json_body["pos"]
+
+ # Send some events but don't send enough to saturate the `timeline_limit`.
+ # We want to later test that we only get the new events since the `next_pos`
+ event_response2 = self.helper.send(room_id1, "activity after2", tok=user2_tok)
+ event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
+
+ # Make an incremental Sliding Sync request (what we're trying to test)
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint + f"?pos={next_pos}",
+ sync_params,
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We only expect to see the new events since the last sync which isn't enough to
+ # fill up the `timeline_limit`.
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ False,
+ f'Our `timeline_limit` was {sync_params["lists"]["foo-list"]["timeline_limit"]} '
+ + f'and {len(channel.json_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
+ + str(channel.json_body["rooms"][room_id1]),
+ )
+ # Check to make sure the latest events are returned
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ event_response2["event_id"],
+ event_response3["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+
+ # All events are "live"
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 2,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_newly_joined_incremental_sync(self) -> None:
+ """
+        Test that when we make an incremental sync with a `newly_joined` room in `rooms`,
+        we are able to see some historical events before the `from_token`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity before token1", tok=user2_tok)
+ event_response2 = self.helper.send(
+ room_id1, "activity before token2", tok=user2_tok
+ )
+
+ from_token = self.event_sources.get_current_token()
+
+ # Join the room after the `from_token` which will make us consider this room as
+ # `newly_joined`.
+ user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send some events but don't send enough to saturate the `timeline_limit`.
+ # We want to later test that we only get the new events since the `next_pos`
+ event_response3 = self.helper.send(
+ room_id1, "activity after token3", tok=user2_tok
+ )
+ event_response4 = self.helper.send(
+ room_id1, "activity after token4", tok=user2_tok
+ )
+
+ # The `timeline_limit` is set to 4 so we can at least see one historical event
+ # before the `from_token`. We should see historical events because this is a
+ # `newly_joined` room.
+ timeline_limit = 4
+ # Make an incremental Sliding Sync request (what we're trying to test)
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": timeline_limit,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We should see the new events and the rest should be filled with historical
+ # events which will make us `limited=True` since there are more to paginate to.
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ True,
+ f"Our `timeline_limit` was {timeline_limit} "
+ + f'and {len(channel.json_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
+ + str(channel.json_body["rooms"][room_id1]),
+ )
+ # Check to make sure that the "live" and historical events are returned
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ event_response2["event_id"],
+ user1_join_response["event_id"],
+ event_response3["event_id"],
+ event_response4["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+
+ # Only events after the `from_token` are "live" (join, event3, event4)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 3,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_invite_shared_history_initial_sync(self) -> None:
+ """
+ Test that `rooms` we are invited to have some stripped `invite_state` during an
+ initial sync.
+
+        This is an `invite` room so we should only have `stripped_state` (no `timeline`)
+        but we also shouldn't see any timeline events because the history visibility is
+        `shared` and we haven't joined the room yet.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user1 = UserID.from_string(user1_id)
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user2 = UserID.from_string(user2_id)
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Ensure we're testing with a room with `shared` history visibility which means
+ # history visible until you actually join the room.
+ history_visibility_response = self.helper.get_state(
+ room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
+ )
+ self.assertEqual(
+ history_visibility_response.get("history_visibility"),
+ HistoryVisibility.SHARED,
+ )
+
+ self.helper.send(room_id1, "activity before1", tok=user2_tok)
+ self.helper.send(room_id1, "activity before2", tok=user2_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity after3", tok=user2_tok)
+ self.helper.send(room_id1, "activity after4", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 3,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # `timeline` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("timeline"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("num_live"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("limited"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("prev_batch"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # We should have some `stripped_state` so the potential joiner can identify the
+ # room (we don't care about the order).
+ self.assertCountEqual(
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ [
+ {
+ "content": {"creator": user2_id, "room_version": "10"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.create",
+ },
+ {
+ "content": {"join_rule": "public"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.join_rules",
+ },
+ {
+ "content": {"displayname": user2.localpart, "membership": "join"},
+ "sender": user2_id,
+ "state_key": user2_id,
+ "type": "m.room.member",
+ },
+ {
+ "content": {"displayname": user1.localpart, "membership": "invite"},
+ "sender": user2_id,
+ "state_key": user1_id,
+ "type": "m.room.member",
+ },
+ ],
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ )
+
+ def test_rooms_invite_shared_history_incremental_sync(self) -> None:
+ """
+ Test that `rooms` we are invited to have some stripped `invite_state` during an
+ incremental sync.
+
+        This is an `invite` room so we should only have `stripped_state` (no `timeline`)
+        but we also shouldn't see any timeline events because the history visibility is
+        `shared` and we haven't joined the room yet.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user1 = UserID.from_string(user1_id)
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user2 = UserID.from_string(user2_id)
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Ensure we're testing with a room with `shared` history visibility which means
+ # history visible until you actually join the room.
+ history_visibility_response = self.helper.get_state(
+ room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
+ )
+ self.assertEqual(
+ history_visibility_response.get("history_visibility"),
+ HistoryVisibility.SHARED,
+ )
+
+ self.helper.send(room_id1, "activity before invite1", tok=user2_tok)
+ self.helper.send(room_id1, "activity before invite2", tok=user2_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity after invite3", tok=user2_tok)
+ self.helper.send(room_id1, "activity after invite4", tok=user2_tok)
+
+ from_token = self.event_sources.get_current_token()
+
+ self.helper.send(room_id1, "activity after token5", tok=user2_tok)
+ self.helper.send(room_id1, "activity after toekn6", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 3,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # `timeline` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("timeline"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("num_live"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("limited"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("prev_batch"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # We should have some `stripped_state` so the potential joiner can identify the
+ # room (we don't care about the order).
+ self.assertCountEqual(
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ [
+ {
+ "content": {"creator": user2_id, "room_version": "10"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.create",
+ },
+ {
+ "content": {"join_rule": "public"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.join_rules",
+ },
+ {
+ "content": {"displayname": user2.localpart, "membership": "join"},
+ "sender": user2_id,
+ "state_key": user2_id,
+ "type": "m.room.member",
+ },
+ {
+ "content": {"displayname": user1.localpart, "membership": "invite"},
+ "sender": user2_id,
+ "state_key": user1_id,
+ "type": "m.room.member",
+ },
+ ],
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ )
+
+ def test_rooms_invite_world_readable_history_initial_sync(self) -> None:
+ """
+ Test that `rooms` we are invited to have some stripped `invite_state` during an
+ initial sync.
+
+        This is an `invite` room so we should only have `stripped_state` (no `timeline`)
+        but depending on the semantics we decide, we could potentially see some
+        historical events before/after the `from_token` because the history is
+        `world_readable`. Same situation for events after the `from_token` if the
+        history visibility was set to `invited`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user1 = UserID.from_string(user1_id)
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user2 = UserID.from_string(user2_id)
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "preset": "public_chat",
+ "initial_state": [
+ {
+ "content": {
+ "history_visibility": HistoryVisibility.WORLD_READABLE
+ },
+ "state_key": "",
+ "type": EventTypes.RoomHistoryVisibility,
+ }
+ ],
+ },
+ )
+ # Ensure we're testing with a room with `world_readable` history visibility
+ # which means events are visible to anyone even without membership.
+ history_visibility_response = self.helper.get_state(
+ room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
+ )
+ self.assertEqual(
+ history_visibility_response.get("history_visibility"),
+ HistoryVisibility.WORLD_READABLE,
+ )
+
+ self.helper.send(room_id1, "activity before1", tok=user2_tok)
+ self.helper.send(room_id1, "activity before2", tok=user2_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity after3", tok=user2_tok)
+ self.helper.send(room_id1, "activity after4", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ # Large enough to see the latest events and before the invite
+ "timeline_limit": 4,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # `timeline` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("timeline"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("num_live"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("limited"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("prev_batch"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # We should have some `stripped_state` so the potential joiner can identify the
+ # room (we don't care about the order).
+ self.assertCountEqual(
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ [
+ {
+ "content": {"creator": user2_id, "room_version": "10"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.create",
+ },
+ {
+ "content": {"join_rule": "public"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.join_rules",
+ },
+ {
+ "content": {"displayname": user2.localpart, "membership": "join"},
+ "sender": user2_id,
+ "state_key": user2_id,
+ "type": "m.room.member",
+ },
+ {
+ "content": {"displayname": user1.localpart, "membership": "invite"},
+ "sender": user2_id,
+ "state_key": user1_id,
+ "type": "m.room.member",
+ },
+ ],
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ )
+
+ def test_rooms_invite_world_readable_history_incremental_sync(self) -> None:
+ """
+ Test that `rooms` we are invited to have some stripped `invite_state` during an
+ incremental sync.
+
+        This is an `invite` room so we should only have `stripped_state` (no `timeline`)
+        but depending on the semantics we decide, we could potentially see some
+        historical events before/after the `from_token` because the history is
+        `world_readable`. Same situation for events after the `from_token` if the
+        history visibility was set to `invited`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user1 = UserID.from_string(user1_id)
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user2 = UserID.from_string(user2_id)
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "preset": "public_chat",
+ "initial_state": [
+ {
+ "content": {
+ "history_visibility": HistoryVisibility.WORLD_READABLE
+ },
+ "state_key": "",
+ "type": EventTypes.RoomHistoryVisibility,
+ }
+ ],
+ },
+ )
+ # Ensure we're testing with a room with `world_readable` history visibility
+ # which means events are visible to anyone even without membership.
+ history_visibility_response = self.helper.get_state(
+ room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
+ )
+ self.assertEqual(
+ history_visibility_response.get("history_visibility"),
+ HistoryVisibility.WORLD_READABLE,
+ )
+
+ self.helper.send(room_id1, "activity before invite1", tok=user2_tok)
+ self.helper.send(room_id1, "activity before invite2", tok=user2_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity after invite3", tok=user2_tok)
+ self.helper.send(room_id1, "activity after invite4", tok=user2_tok)
+
+ from_token = self.event_sources.get_current_token()
+
+ self.helper.send(room_id1, "activity after token5", tok=user2_tok)
+ self.helper.send(room_id1, "activity after toekn6", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ # Large enough to see the latest events and before the invite
+ "timeline_limit": 4,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # `timeline` is omitted for `invite` rooms with `stripped_state`
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("timeline"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("num_live"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("limited"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
+ self.assertIsNone(
+ channel.json_body["rooms"][room_id1].get("prev_batch"),
+ channel.json_body["rooms"][room_id1],
+ )
+ # We should have some `stripped_state` so the potential joiner can identify the
+ # room (we don't care about the order).
+ self.assertCountEqual(
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ [
+ {
+ "content": {"creator": user2_id, "room_version": "10"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.create",
+ },
+ {
+ "content": {"join_rule": "public"},
+ "sender": user2_id,
+ "state_key": "",
+ "type": "m.room.join_rules",
+ },
+ {
+ "content": {"displayname": user2.localpart, "membership": "join"},
+ "sender": user2_id,
+ "state_key": user2_id,
+ "type": "m.room.member",
+ },
+ {
+ "content": {"displayname": user1.localpart, "membership": "invite"},
+ "sender": user2_id,
+ "state_key": user1_id,
+ "type": "m.room.member",
+ },
+ ],
+ channel.json_body["rooms"][room_id1]["invite_state"],
+ )
+
+ def test_rooms_ban_initial_sync(self) -> None:
+ """
+ Test that `rooms` we are banned from in an initial sync only allow us to see
+ timeline events up to the ban event.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity before1", tok=user2_tok)
+ self.helper.send(room_id1, "activity before2", tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
+ event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok)
+ user1_ban_response = self.helper.ban(
+ room_id1, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+
+ self.helper.send(room_id1, "activity after5", tok=user2_tok)
+ self.helper.send(room_id1, "activity after6", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 3,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We should see events before the ban but not after
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ event_response3["event_id"],
+ event_response4["event_id"],
+ user1_ban_response["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+ # No "live" events in an initial sync (no `from_token` to define the "live"
+ # range)
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 0,
+ channel.json_body["rooms"][room_id1],
+ )
+ # There are more events to paginate to
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ True,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_ban_incremental_sync1(self) -> None:
+ """
+ Test that `rooms` we are banned from during the next incremental sync only
+ allow us to see timeline events up to the ban event.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity before1", tok=user2_tok)
+ self.helper.send(room_id1, "activity before2", tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
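+ # Grab a token before the ban so the Sliding Sync request below is an
+ # incremental sync over the range that includes the ban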
+ from_token = self.event_sources.get_current_token()
+
+ event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
+ event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok)
+ # The ban is within the token range (between the `from_token` and the sliding
+ # sync request)
+ user1_ban_response = self.helper.ban(
+ room_id1, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+
+ self.helper.send(room_id1, "activity after5", tok=user2_tok)
+ self.helper.send(room_id1, "activity after6", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 4,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We should see events before the ban but not after
+ self.assertEqual(
+ [
+ event["event_id"]
+ for event in channel.json_body["rooms"][room_id1]["timeline"]
+ ],
+ [
+ event_response3["event_id"],
+ event_response4["event_id"],
+ user1_ban_response["event_id"],
+ ],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+ # All live events in the incremental sync
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 3,
+ channel.json_body["rooms"][room_id1],
+ )
+ # There aren't any more events to paginate to in this range
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ False,
+ channel.json_body["rooms"][room_id1],
+ )
+
+ def test_rooms_ban_incremental_sync2(self) -> None:
+ """
+ Test that `rooms` we are banned from before the incremental sync don't return
+ any events in the timeline.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.send(room_id1, "activity before1", tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ self.helper.send(room_id1, "activity after2", tok=user2_tok)
+ # The ban is before we get our `from_token`
+ self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ self.helper.send(room_id1, "activity after3", tok=user2_tok)
+
+ from_token = self.event_sources.get_current_token()
+
+ self.helper.send(room_id1, "activity after4", tok=user2_tok)
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 4,
+ }
+ }
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Nothing to see for this banned user in the room in the token range
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["timeline"],
+ [],
+ channel.json_body["rooms"][room_id1]["timeline"],
+ )
+ # No events returned in the timeline so nothing is "live"
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["num_live"],
+ 0,
+ channel.json_body["rooms"][room_id1],
+ )
+ # There aren't any more events to paginate to in this range
+ self.assertEqual(
+ channel.json_body["rooms"][room_id1]["limited"],
+ False,
+ channel.json_body["rooms"][room_id1],
+ )
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index f0ba40a1f1..e43140720d 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -261,9 +261,9 @@ class RestHelper:
targ: str,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
- ) -> None:
+ ) -> JsonDict:
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
- self.change_membership(
+ return self.change_membership(
room=room,
src=src,
targ=targ,
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index fe1e873e15..aad46b1b44 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -21,20 +21,32 @@
import logging
from typing import List, Tuple
+from unittest.mock import AsyncMock, patch
from immutabledict import immutabledict
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import Direction, EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
from synapse.api.filtering import Filter
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import FrozenEventV3
+from synapse.federation.federation_client import SendJoinResult
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, PersistedEventPosition, RoomStreamToken
+from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
+from synapse.types import (
+ JsonDict,
+ PersistedEventPosition,
+ RoomStreamToken,
+ UserID,
+ create_requester,
+)
from synapse.util import Clock
-from tests.unittest import HomeserverTestCase
+from tests.test_utils.event_injection import create_event
+from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
logger = logging.getLogger(__name__)
@@ -543,3 +555,859 @@ class GetLastEventInRoomBeforeStreamOrderingTestCase(HomeserverTestCase):
}
),
)
+
+
+class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
+ """
+ Test `get_current_state_delta_membership_changes_for_user(...)`
+ """
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.state_handler = self.hs.get_state_handler()
+ persistence = hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
+
+ def test_returns_membership_events(self) -> None:
+ """
+ A basic test that a membership event in the token range is returned for the user.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_pos = self.get_success(
+ self.store.get_position_for_event(join_response["event_id"])
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_room1_token.room_key,
+ to_key=after_room1_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=join_response["event_id"],
+ event_pos=join_pos,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ )
+ ],
+ )
+
+ def test_server_left_room_after_us(self) -> None:
+ """
+ Test that when probing over part of the DAG where the server left the room *after
+ us*, we still see the join and leave changes.
+
+ This is to make sure we play nicely with this behavior: When the server leaves a
+ room, it will insert new rows with `event_id = null` into the
+ `current_state_delta_stream` table for all current state.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "power_level_content_override": {
+ "users": {
+ user2_id: 100,
+ # Allow user1 to send state in the room
+ user1_id: 100,
+ }
+ }
+ },
+ )
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_pos1 = self.get_success(
+ self.store.get_position_for_event(join_response1["event_id"])
+ )
+ # Make sure that random other non-member state that happens to have a `state_key`
+ # matching the user ID doesn't mess with things.
+ self.helper.send_state(
+ room_id1,
+ event_type="foobarbazdummy",
+ state_key=user1_id,
+ body={"foo": "bar"},
+ tok=user1_tok,
+ )
+ # User1 should leave the room first
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_pos1 = self.get_success(
+ self.store.get_position_for_event(leave_response1["event_id"])
+ )
+
+ # User2 should also leave the room (everyone has left the room which means the
+ # server is no longer in the room).
+ self.helper.leave(room_id1, user2_id, tok=user2_tok)
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Get the membership changes for the user.
+ #
+ # At this point, the `current_state_delta_stream` table should look like the
+ # following. When the server leaves a room, it will insert new rows with
+ # `event_id = null` for all current state.
+ #
+ # | stream_id | room_id | type | state_key | event_id | prev_event_id |
+ # |-----------|----------|-----------------------------|----------------|----------|---------------|
+ # | 2 | !x:test | 'm.room.create' | '' | $xxx | None |
+ # | 3 | !x:test | 'm.room.member' | '@user2:test' | $aaa | None |
+ # | 4 | !x:test | 'm.room.history_visibility' | '' | $xxx | None |
+ # | 4 | !x:test | 'm.room.join_rules' | '' | $xxx | None |
+ # | 4 | !x:test | 'm.room.power_levels' | '' | $xxx | None |
+ # | 7 | !x:test | 'm.room.member' | '@user1:test' | $ooo | None |
+ # | 8 | !x:test | 'foobarbazdummy' | '@user1:test' | $xxx | None |
+ # | 9 | !x:test | 'm.room.member' | '@user1:test' | $ppp | $ooo |
+ # | 10 | !x:test | 'foobarbazdummy' | '@user1:test' | None | $xxx |
+ # | 10 | !x:test | 'm.room.create' | '' | None | $xxx |
+ # | 10 | !x:test | 'm.room.history_visibility' | '' | None | $xxx |
+ # | 10 | !x:test | 'm.room.join_rules' | '' | None | $xxx |
+ # | 10 | !x:test | 'm.room.member' | '@user1:test' | None | $ppp |
+ # | 10 | !x:test | 'm.room.member' | '@user2:test' | None | $aaa |
+ # | 10 | !x:test | 'm.room.power_levels' | '' | None | $xxx |
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_room1_token.room_key,
+ to_key=after_room1_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=join_response1["event_id"],
+ event_pos=join_pos1,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ ),
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=leave_response1["event_id"],
+ event_pos=leave_pos1,
+ membership="leave",
+ sender=user1_id,
+ prev_event_id=join_response1["event_id"],
+ prev_event_pos=join_pos1,
+ prev_membership="join",
+ prev_sender=user1_id,
+ ),
+ ],
+ )
+
+ def test_server_left_room_after_us_later(self) -> None:
+ """
+ Test that when the user leaves the room and then, sometime later, everyone else
+ leaves the room (causing the server to leave the room), we shouldn't see any
+ membership changes.
+
+ This is to make sure we play nicely with this behavior: When the server leaves a
+ room, it will insert new rows with `event_id = null` into the
+ `current_state_delta_stream` table for all current state.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ # User1 should leave the room first
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ after_user1_leave_token = self.event_sources.get_current_token()
+
+ # User2 should also leave the room (everyone has left the room which means the
+ # server is no longer in the room).
+ self.helper.leave(room_id1, user2_id, tok=user2_tok)
+
+ after_server_leave_token = self.event_sources.get_current_token()
+
+ # Join another room as user1 just to advance the stream_ordering and bust
+ # `_membership_stream_cache`
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
+
+ # Get the membership changes for the user.
+ #
+ # At this point, the `current_state_delta_stream` table should look like the
+ # following. When the server leaves a room, it will insert new rows with
+ # `event_id = null` for all current state.
+ #
+ # TODO: Add DB rows to better see what's going on.
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=after_user1_leave_token.room_key,
+ to_key=after_server_leave_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [],
+ )
+
+ def test_we_cause_server_left_room(self) -> None:
+ """
+ Test that when probing over part of the DAG where the user leaves the room
+ causing the server to leave the room (because we were the last local user in the
+ room), we still see the join and leave changes.
+
+ This is to make sure we play nicely with this behavior: When the server leaves a
+ room, it will insert new rows with `event_id = null` into the
+ `current_state_delta_stream` table for all current state.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "power_level_content_override": {
+ "users": {
+ user2_id: 100,
+ # Allow user1 to send state in the room
+ user1_id: 100,
+ }
+ }
+ },
+ )
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_pos1 = self.get_success(
+ self.store.get_position_for_event(join_response1["event_id"])
+ )
+ # Make sure that random other non-member state that happens to have a `state_key`
+ # matching the user ID doesn't mess with things.
+ self.helper.send_state(
+ room_id1,
+ event_type="foobarbazdummy",
+ state_key=user1_id,
+ body={"foo": "bar"},
+ tok=user1_tok,
+ )
+
+ # User2 should leave the room first.
+ self.helper.leave(room_id1, user2_id, tok=user2_tok)
+
+ # User1 (the person we're testing with) should also leave the room (everyone has
+ # left the room which means the server is no longer in the room).
+ leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ leave_pos1 = self.get_success(
+ self.store.get_position_for_event(leave_response1["event_id"])
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Get the membership changes for the user.
+ #
+ # At this point, the `current_state_delta_stream` table should look like the
+ # following. When the server leaves a room, it will insert new rows with
+ # `event_id = null` for all current state.
+ #
+ # | stream_id | room_id | type | state_key | event_id | prev_event_id |
+ # |-----------|-----------|-----------------------------|---------------|----------|---------------|
+ # | 2 | '!x:test' | 'm.room.create' | '' | '$xxx' | None |
+ # | 3 | '!x:test' | 'm.room.member' | '@user2:test' | '$aaa' | None |
+ # | 4 | '!x:test' | 'm.room.history_visibility' | '' | '$xxx' | None |
+ # | 4 | '!x:test' | 'm.room.join_rules' | '' | '$xxx' | None |
+ # | 4 | '!x:test' | 'm.room.power_levels' | '' | '$xxx' | None |
+ # | 7 | '!x:test' | 'm.room.member' | '@user1:test' | '$ooo' | None |
+ # | 8 | '!x:test' | 'foobarbazdummy' | '@user1:test' | '$xxx' | None |
+ # | 9 | '!x:test' | 'm.room.member' | '@user2:test' | '$bbb' | '$aaa' |
+ # | 10 | '!x:test' | 'foobarbazdummy' | '@user1:test' | None | '$xxx' |
+ # | 10 | '!x:test' | 'm.room.create' | '' | None | '$xxx' |
+ # | 10 | '!x:test' | 'm.room.history_visibility' | '' | None | '$xxx' |
+ # | 10 | '!x:test' | 'm.room.join_rules' | '' | None | '$xxx' |
+ # | 10 | '!x:test' | 'm.room.member' | '@user1:test' | None | '$ooo' |
+ # | 10 | '!x:test' | 'm.room.member' | '@user2:test' | None | '$bbb' |
+ # | 10 | '!x:test' | 'm.room.power_levels' | '' | None | '$xxx' |
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_room1_token.room_key,
+ to_key=after_room1_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=join_response1["event_id"],
+ event_pos=join_pos1,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ ),
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=None, # leave_response1["event_id"],
+ event_pos=leave_pos1,
+ membership="leave",
+ sender=None, # user1_id,
+ prev_event_id=join_response1["event_id"],
+ prev_event_pos=join_pos1,
+ prev_membership="join",
+ prev_sender=user1_id,
+ ),
+ ],
+ )
+
+ def test_different_user_membership_persisted_in_same_batch(self) -> None:
+ """
+ Test batch of membership events from different users being processed at once.
+ This will result in all of the memberships being stored in the
+ `current_state_delta_stream` table with the same `stream_ordering` even though
+ the individual events have different `stream_ordering`s.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ _user4_tok = self.login(user4_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ # User2 is just the designated person to create the room (we do this across the
+ # tests to be consistent)
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # Persist the user1, user3, and user4 join events in the same batch so they all
+ # end up in the `current_state_delta_stream` table with the same
+ # stream_ordering.
+ join_event3, join_event_context3 = self.get_success(
+ create_event(
+ self.hs,
+ sender=user3_id,
+ type=EventTypes.Member,
+ state_key=user3_id,
+ content={"membership": "join"},
+ room_id=room_id1,
+ )
+ )
+ # We want to put user1 in the middle of the batch. This way, regardless of how the
+ # implementation inserts rows into `current_state_delta_stream` (whether it uses the
+ # minimum or maximum stream position of the batch), we will still catch bugs.
+ join_event1, join_event_context1 = self.get_success(
+ create_event(
+ self.hs,
+ sender=user1_id,
+ type=EventTypes.Member,
+ state_key=user1_id,
+ content={"membership": "join"},
+ room_id=room_id1,
+ )
+ )
+ join_event4, join_event_context4 = self.get_success(
+ create_event(
+ self.hs,
+ sender=user4_id,
+ type=EventTypes.Member,
+ state_key=user4_id,
+ content={"membership": "join"},
+ room_id=room_id1,
+ )
+ )
+ self.get_success(
+ self.persistence.persist_events(
+ [
+ (join_event3, join_event_context3),
+ (join_event1, join_event_context1),
+ (join_event4, join_event_context4),
+ ]
+ )
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # Get the membership changes for the user.
+ #
+ # At this point, the `current_state_delta_stream` table should look like (notice
+ # those three memberships at the end with `stream_id=7` because we persisted
+ # them in the same batch):
+ #
+ # | stream_id | room_id | type | state_key | event_id | prev_event_id |
+ # |-----------|-----------|----------------------------|------------------|----------|---------------|
+ # | 2 | '!x:test' | 'm.room.create' | '' | '$xxx' | None |
+ # | 3 | '!x:test' | 'm.room.member' | '@user2:test' | '$xxx' | None |
+ # | 4 | '!x:test' | 'm.room.history_visibility'| '' | '$xxx' | None |
+ # | 4 | '!x:test' | 'm.room.join_rules' | '' | '$xxx' | None |
+ # | 4 | '!x:test' | 'm.room.power_levels' | '' | '$xxx' | None |
+ # | 7 | '!x:test' | 'm.room.member' | '@user3:test' | '$xxx' | None |
+ # | 7 | '!x:test' | 'm.room.member' | '@user1:test' | '$xxx' | None |
+ # | 7 | '!x:test' | 'm.room.member' | '@user4:test' | '$xxx' | None |
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_room1_token.room_key,
+ to_key=after_room1_token.room_key,
+ )
+ )
+
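+ # `join_event3` was persisted first in the batch, so its position corresponds to
+ # the minimum stream position shared by the whole batch in
+ # `current_state_delta_stream`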
+ join_pos3 = self.get_success(
+ self.store.get_position_for_event(join_event3.event_id)
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=join_event1.event_id,
+ # Ideally, this would be `join_pos1` (to match the `event_id`) but
+ # when events are persisted in a batch, they are all stored in the
+ # `current_state_delta_stream` table with the minimum
+ # `stream_ordering` from the batch.
+ event_pos=join_pos3,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ ),
+ ],
+ )
+
+ def test_state_reset(self) -> None:
+ """
+ Test a state reset scenario where the user gets removed from the room (when
+ there is no corresponding leave event)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_pos1 = self.get_success(
+ self.store.get_position_for_event(join_response1["event_id"])
+ )
+
+ before_reset_token = self.event_sources.get_current_token()
+
+ # Send another state event to make a position for the state reset to happen at
+ dummy_state_response = self.helper.send_state(
+ room_id1,
+ event_type="foobarbaz",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ dummy_state_pos = self.get_success(
+ self.store.get_position_for_event(dummy_state_response["event_id"])
+ )
+
+ # Mock a state reset removing the membership for user1 in the current state
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="current_state_events",
+ keyvalues={
+ "room_id": room_id1,
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ },
+ desc="state reset user in current_state_delta_stream",
+ )
+ )
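+ # Insert a delta row with `event_id = None` (as a server-side state reset would)
+ # into `current_state_delta_stream`, pointing back at user1's join as the
+ # previous event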
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="current_state_delta_stream",
+ values={
+ "stream_id": dummy_state_pos.stream,
+ "room_id": room_id1,
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ "event_id": None,
+ "prev_event_id": join_response1["event_id"],
+ "instance_name": dummy_state_pos.instance_name,
+ },
+ desc="state reset user in current_state_delta_stream",
+ )
+ )
+
+ # Manually bust the cache since we're just manually messing with the database
+ # and not causing an actual state reset.
+ self.store._membership_stream_cache.entity_has_changed(
+ user1_id, dummy_state_pos.stream
+ )
+
+ after_reset_token = self.event_sources.get_current_token()
+
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_reset_token.room_key,
+ to_key=after_reset_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=None,
+ event_pos=dummy_state_pos,
+ membership="leave",
+ sender=None, # user1_id,
+ prev_event_id=join_response1["event_id"],
+ prev_event_pos=join_pos1,
+ prev_membership="join",
+ prev_sender=user1_id,
+ ),
+ ],
+ )
+
+ def test_excluded_room_ids(self) -> None:
+ """
+ Test that the `excluded_room_ids` option excludes changes from the specified rooms.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ before_room1_token = self.event_sources.get_current_token()
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ join_pos1 = self.get_success(
+ self.store.get_position_for_event(join_response1["event_id"])
+ )
+
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
+ join_pos2 = self.get_success(
+ self.store.get_position_for_event(join_response2["event_id"])
+ )
+
+ after_room1_token = self.event_sources.get_current_token()
+
+ # First test that the room is returned without the `excluded_room_ids` option
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_room1_token.room_key,
+ to_key=after_room1_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=join_response1["event_id"],
+ event_pos=join_pos1,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ ),
+ CurrentStateDeltaMembership(
+ room_id=room_id2,
+ event_id=join_response2["event_id"],
+ event_pos=join_pos2,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ ),
+ ],
+ )
+
+ # Then test that `excluded_room_ids` excludes room2 as expected
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_room1_token.room_key,
+ to_key=after_room1_token.room_key,
+ excluded_room_ids=[room_id2],
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=join_response1["event_id"],
+ event_pos=join_pos1,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ )
+ ],
+ )
+
+
+class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(
+ FederatingHomeserverTestCase
+):
+ """
+ Test `get_current_state_delta_membership_changes_for_user(...)` when joining remote federated rooms.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
+ self.store = self.hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.room_member_handler = hs.get_room_member_handler()
+
+ def test_remote_join(self) -> None:
+ """
+ Test remote join where the first rows in `current_state_delta_stream` will just
+ be the state when you joined the remote room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ before_join_token = self.event_sources.get_current_token()
+
+ intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
+
+ # Remotely join a room on another homeserver.
+ #
+ # To do this we have to mock the responses from the remote homeserver. We also
+ # patch out a bunch of event checks on our end.
+ create_event_source = {
+ "auth_events": [],
+ "content": {
+ "creator": f"@creator:{self.OTHER_SERVER_NAME}",
+ "room_version": self.hs.config.server.default_room_version.identifier,
+ },
+ "depth": 0,
+ "origin_server_ts": 0,
+ "prev_events": [],
+ "room_id": intially_unjoined_room_id,
+ "sender": f"@creator:{self.OTHER_SERVER_NAME}",
+ "state_key": "",
+ "type": EventTypes.Create,
+ }
+ self.add_hashes_and_signatures_from_other_server(
+ create_event_source,
+ self.hs.config.server.default_room_version,
+ )
+ create_event = FrozenEventV3(
+ create_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+ creator_join_event_source = {
+ "auth_events": [create_event.event_id],
+ "content": {
+ "membership": "join",
+ },
+ "depth": 1,
+ "origin_server_ts": 1,
+ "prev_events": [],
+ "room_id": intially_unjoined_room_id,
+ "sender": f"@creator:{self.OTHER_SERVER_NAME}",
+ "state_key": f"@creator:{self.OTHER_SERVER_NAME}",
+ "type": EventTypes.Member,
+ }
+ self.add_hashes_and_signatures_from_other_server(
+ creator_join_event_source,
+ self.hs.config.server.default_room_version,
+ )
+ creator_join_event = FrozenEventV3(
+ creator_join_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+
+ # Our local user is going to remote join the room
+ join_event_source = {
+ "auth_events": [create_event.event_id],
+ "content": {"membership": "join"},
+ "depth": 1,
+ "origin_server_ts": 100,
+ "prev_events": [creator_join_event.event_id],
+ "sender": user1_id,
+ "state_key": user1_id,
+ "room_id": intially_unjoined_room_id,
+ "type": EventTypes.Member,
+ }
+ add_hashes_and_signatures(
+ self.hs.config.server.default_room_version,
+ join_event_source,
+ self.hs.hostname,
+ self.hs.signing_key,
+ )
+ join_event = FrozenEventV3(
+ join_event_source,
+ self.hs.config.server.default_room_version,
+ {},
+ None,
+ )
+
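+ # Mock out the federation client's `make_membership_event` (the `/make_join`
+ # request to the remote server), returning our prepared join event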
+ mock_make_membership_event = AsyncMock(
+ return_value=(
+ self.OTHER_SERVER_NAME,
+ join_event,
+ self.hs.config.server.default_room_version,
+ )
+ )
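+ # Mock out `send_join` (the `/send_join` request), returning the room state and
+ # auth chain as of the join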
+ mock_send_join = AsyncMock(
+ return_value=SendJoinResult(
+ join_event,
+ self.OTHER_SERVER_NAME,
+ state=[create_event, creator_join_event],
+ auth_chain=[create_event, creator_join_event],
+ partial_state=False,
+ servers_in_room=frozenset(),
+ )
+ )
+
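+ # Patch the federation client methods and the event-auth checks so the remote
+ # join can complete without a real remote homeserver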
+ with patch.object(
+ self.room_member_handler.federation_handler.federation_client,
+ "make_membership_event",
+ mock_make_membership_event,
+ ), patch.object(
+ self.room_member_handler.federation_handler.federation_client,
+ "send_join",
+ mock_send_join,
+ ), patch(
+ "synapse.event_auth._is_membership_change_allowed",
+ return_value=None,
+ ), patch(
+ "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+ return_value=None,
+ ):
+ self.get_success(
+ self.room_member_handler.update_membership(
+ requester=create_requester(user1_id),
+ target=UserID.from_string(user1_id),
+ room_id=initially_unjoined_room_id,
+ action=Membership.JOIN,
+ remote_room_hosts=[self.OTHER_SERVER_NAME],
+ )
+ )
+
+ after_join_token = self.event_sources.get_current_token()
+
+ # Get the membership changes for the user.
+ #
+ # At this point, the `current_state_delta_stream` table should look like the
+ # following. Notice that all of the events are at the same `stream_id` because
+ # the current state starts out where we remotely joined:
+ #
+ # | stream_id | room_id | type | state_key | event_id | prev_event_id |
+ # |-----------|------------------------------|-----------------|------------------------------|----------|---------------|
+ # | 2 | '!example:other.example.com' | 'm.room.member' | '@user1:test' | '$xxx' | None |
+ # | 2 | '!example:other.example.com' | 'm.room.create' | '' | '$xxx' | None |
+ # | 2 | '!example:other.example.com' | 'm.room.member' | '@creator:other.example.com' | '$xxx' | None |
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_join_token.room_key,
+ to_key=after_join_token.room_key,
+ )
+ )
+
+ join_pos = self.get_success(
+ self.store.get_position_for_event(join_event.event_id)
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=initially_unjoined_room_id,
+ event_id=join_event.event_id,
+ event_pos=join_pos,
+ membership="join",
+ sender=user1_id,
+ prev_event_id=None,
+ prev_event_pos=None,
+ prev_membership=None,
+ prev_sender=None,
+ ),
+ ],
+ )