diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6351326fff..fc2a6c569b 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -20,7 +20,7 @@
#
import urllib.parse
-from typing import Dict
+from typing import Dict, cast
from parameterized import parameterized
@@ -30,8 +30,9 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
-from synapse.rest.client import login, room
+from synapse.rest.client import login, media, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -60,6 +61,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
+ media.register_servlets,
room.register_servlets,
]
@@ -74,7 +76,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it."""
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id}",
shorthand=False,
access_token=admin_user_tok,
)
@@ -131,7 +133,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_name_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_name_and_media_id}",
shorthand=False,
access_token=non_admin_user_tok,
)
@@ -226,10 +228,25 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_3 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
+ server_and_media_id_3 = response_3["content_uri"][6:]
+
+ # Remove the hash from the media to simulate historic media.
+ self.get_success(
+ self.hs.get_datastores().main.update_local_media(
+ media_id=server_and_media_id_3.split("/")[1],
+ media_type="image/png",
+ upload_name=None,
+ media_length=123,
+ user_id=UserID.from_string(non_admin_user),
+ # Hack to force some media to have no hash.
+ sha256=cast(str, None),
+ )
+ )
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -243,12 +260,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.pump(1.0)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
- channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
+ channel.json_body, {"num_quarantined": 3}, "Expected 3 quarantined items"
)
# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_3)
def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True)
@@ -295,7 +313,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id_2}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id_2}",
shorthand=False,
access_token=non_admin_user_tok,
)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index a88c77bd19..531162a6e9 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -27,7 +27,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.handlers.device import DeviceHandler
-from synapse.rest.client import login
+from synapse.rest.client import devices, login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -299,6 +299,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
class DevicesRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
+ devices.register_servlets,
login.register_servlets,
]
@@ -390,15 +391,63 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
+ @unittest.override_config(
+ {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+ )
def test_get_devices(self) -> None:
"""
Tests that a normal lookup for devices is successfully
"""
# Create devices
number_devices = 5
- for _ in range(number_devices):
+ # we create 2 fewer devices in the loop, because we will create another
+ # login after the loop, and we will create a dehydrated device
+ for _ in range(number_devices - 2):
self.login("user", "pass")
+ other_user_token = self.login("user", "pass")
+ dehydrated_device_url = (
+ "/_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device"
+ )
+ content = {
+ "device_data": {
+ "algorithm": "m.dehydration.v1.olm",
+ },
+ "device_id": "dehydrated_device",
+ "initial_device_display_name": "foo bar",
+ "device_keys": {
+ "user_id": "@user:test",
+ "device_id": "dehydrated_device",
+ "valid_until_ts": "80",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "@user:test": {"<algorithm>:<device_id>": "<signature_base64>"}
+ },
+ },
+ "fallback_keys": {
+ "alg1:device1": "f4llb4ckk3y",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": "true",
+ "key": "f4llb4ckk3y",
+ "signatures": {
+ "@user:test": {"<algorithm>:<device_id>": "<key_base64>"}
+ },
+ },
+ },
+ "one_time_keys": {"alg1:k1": "0net1m3k3y"},
+ }
+ self.make_request(
+ "PUT",
+ dehydrated_device_url,
+ access_token=other_user_token,
+ content=content,
+ )
+
# Get devices
channel = self.make_request(
"GET",
@@ -410,13 +459,22 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
- # Check that all fields are available
+ # Check that all fields are available, and that the dehydrated device is marked as dehydrated
+ found_dehydrated = False
for d in channel.json_body["devices"]:
self.assertIn("user_id", d)
self.assertIn("device_id", d)
self.assertIn("display_name", d)
self.assertIn("last_seen_ip", d)
self.assertIn("last_seen_ts", d)
+ if d["device_id"] == "dehydrated_device":
+ self.assertTrue(d.get("dehydrated"))
+ found_dehydrated = True
+ else:
+ # Either the field is not present, or set to False
+ self.assertFalse(d.get("dehydrated"))
+
+ self.assertTrue(found_dehydrated)
class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index feb410a11d..6047ce1f4a 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -378,6 +378,41 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
+ def test_filter_against_event_sender(self) -> None:
+ """
+ Tests filtering by the sender of the reported event
+ """
+ # first grab all the reports
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # filter out set of report ids of events sent by one of the users
+ locally_filtered_report_ids = set()
+ for event_report in channel.json_body["event_reports"]:
+ if event_report["sender"] == self.other_user:
+ locally_filtered_report_ids.add(event_report["id"])
+
+ # grab the report ids by sender and compare to filtered report ids
+ channel = self.make_request(
+ "GET",
+ f"{self.url}?event_sender_user_id={self.other_user}",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertEqual(channel.json_body["total"], len(locally_filtered_report_ids))
+
+ event_reports = channel.json_body["event_reports"]
+ server_filtered_report_ids = set()
+ for event_report in event_reports:
+ server_filtered_report_ids.add(event_report["id"])
+ self.assertIncludes(
+ locally_filtered_report_ids, server_filtered_report_ids, exact=True
+ )
+
def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
"""Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c2015774a1..d5ae3345f5 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -96,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index f378165513..da0e9749aa 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,8 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import SMALL_PNG
+from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
+from tests.unittest import override_config
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
@@ -126,6 +127,7 @@ class DeleteMediaByIDTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
+ @override_config({"enable_authenticated_media": False})
def test_delete_media(self) -> None:
"""
Tests that delete a media is successfully
@@ -371,6 +373,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_date(self) -> None:
"""
Tests that media is not deleted if it is newer than `before_ts`
@@ -408,6 +411,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_size(self) -> None:
"""
Tests that media is not deleted if its size is smaller than or equal
@@ -443,6 +447,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_user_avatar(self) -> None:
"""
Tests that we do not delete media if is used as a user avatar
@@ -487,6 +492,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_room_avatar(self) -> None:
"""
Tests that we do not delete media if it is used as a room avatar
@@ -592,23 +598,27 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
class QuarantineMediaByIDTestCase(_AdminMediaTests):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.server_name = hs.hostname
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
+ def upload_media_and_return_media_id(self, data: bytes) -> str:
# Upload some media into the room
response = self.helper.upload_media(
- SMALL_PNG,
+ data,
tok=self.admin_user_tok,
expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
- self.media_id = server_and_media_id.split("/")[1]
+ return server_and_media_id.split("/")[1]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.server_name = hs.hostname
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+ self.media_id = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_2 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_3 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_other = self.upload_media_and_return_media_id(SMALL_CMYK_JPEG)
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
@@ -680,6 +690,52 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
+ def test_quarantine_media_match_hash(self) -> None:
+ """
+ Tests that quarantining removes all media with the same hash
+ """
+
+ media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
+ # quarantining
+ channel = self.make_request(
+ "POST",
+ self.url % ("quarantine", self.server_name, self.media_id),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body)
+
+ # Test that ALL similar media was quarantined.
+ for media in [self.media_id, self.media_id_2, self.media_id_3]:
+ media_info = self.get_success(self.store.get_local_media(media))
+ assert media_info is not None
+ self.assertTrue(media_info.quarantined_by)
+
+ # Test that other media was not.
+ media_info = self.get_success(self.store.get_local_media(self.media_id_other))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
+ # remove from quarantine
+ channel = self.make_request(
+ "POST",
+ self.url % ("unquarantine", self.server_name, self.media_id),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body)
+
+ # Test that ALL similar media is now reset.
+ for media in [self.media_id, self.media_id_2, self.media_id_3]:
+ media_info = self.get_success(self.store.get_local_media(media))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
def test_quarantine_protected_media(self) -> None:
"""
Tests that quarantining from protected media fails
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 95ed736451..e22dfcba1b 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -369,6 +369,47 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
+ def test_invited_users_not_joined_to_new_room(self) -> None:
+ """
+ Test that when a new room id is provided, users who are only invited
+        but have not joined the original room are not moved to the new room.
+ """
+ invitee = self.register_user("invitee", "pass")
+
+ self.helper.invite(
+ self.room_id, self.other_user, invitee, tok=self.other_user_tok
+ )
+
+ # verify that user is invited
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/rooms/{self.room_id}/members?membership=invite",
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ invite = channel.json_body["chunk"][0]
+ self.assertEqual(invite["state_key"], invitee)
+
+ # shutdown room
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ {"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(len(channel.json_body["kicked_users"]), 2)
+
+ # joined member is moved to new room but invited user is not
+ users_in_room = self.get_success(
+ self.store.get_users_in_room(channel.json_body["new_room_id"])
+ )
+ self.assertNotIn(invitee, users_in_room)
+ self.assertIn(self.other_user, users_in_room)
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
def test_shutdown_room_consent(self) -> None:
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
@@ -758,6 +799,8 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"])
+ self.assertEqual(self.room_id, channel.json_body["results"][0]["room_id"])
+ self.assertEqual(self.room_id, channel.json_body["results"][1]["room_id"])
delete_ids = {delete_id1, delete_id2}
self.assertTrue(channel.json_body["results"][0]["delete_id"] in delete_ids)
delete_ids.remove(channel.json_body["results"][0]["delete_id"])
@@ -777,6 +820,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+ self.assertEqual(self.room_id, channel.json_body["results"][0]["room_id"])
# get status after more than clearing time for all tasks
self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
@@ -1237,6 +1281,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(
delete_id, channel_room_id.json_body["results"][0]["delete_id"]
)
+ self.assertEqual(
+ self.room_id, channel_room_id.json_body["results"][0]["room_id"]
+ )
# get information by delete_id
channel_delete_id = self.make_request(
@@ -1249,6 +1296,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
channel_delete_id.code,
msg=channel_delete_id.json_body,
)
+ self.assertEqual(self.room_id, channel_delete_id.json_body["room_id"])
# test values that are the same in both responses
for content in [
@@ -1282,6 +1330,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_list_rooms(self) -> None:
"""Test that we can list rooms"""
# Create 3 test rooms
@@ -1311,7 +1360,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Check that response json body contains a "rooms" key
self.assertTrue(
"rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
+ msg="Response body does not contain a 'rooms' key",
)
# Check that 3 rooms were returned
@@ -1795,6 +1844,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_filter_public_rooms(self) -> None:
self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok, is_public=True
@@ -1872,6 +1922,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(1, response.json_body["total_rooms"])
self.assertEqual(1, len(response.json_body["rooms"]))
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
@@ -2035,6 +2086,52 @@ class RoomTestCase(unittest.HomeserverTestCase):
# the create_room already does the right thing, so no need to verify that we got
# the state events it created.
+ def test_room_state_param(self) -> None:
+ """Test that filtering by state event type works when requesting state"""
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/rooms/{room_id}/state?type=m.room.member",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ state = channel.json_body["state"]
+ # only one member has joined so there should be one membership event
+ self.assertEqual(1, len(state))
+ event = state[0]
+ self.assertEqual(event["type"], "m.room.member")
+ self.assertEqual(event["state_key"], self.admin_user)
+
+ def test_room_state_param_empty(self) -> None:
+        """Test that passing an empty string as the state filter param returns all state events"""
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/rooms/{room_id}/state?type=",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ state = channel.json_body["state"]
+ self.assertEqual(5, len(state))
+
+ def test_room_state_param_not_in_room(self) -> None:
+ """
+ Test that passing a state filter param for a state event not in the room
+ returns no state events
+ """
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/rooms/{room_id}/state?type=m.room.custom",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ state = channel.json_body["state"]
+ self.assertEqual(0, len(state))
+
def _set_canonical_alias(
self, room_id: str, test_alias: str, admin_user_tok: str
) -> None:
@@ -3050,7 +3147,7 @@ PURGE_TABLES = [
"pusher_throttle",
"room_account_data",
"room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups",
"state_groups_state",
"federation_inbound_events_staging",
]
diff --git a/tests/rest/admin/test_scheduled_tasks.py b/tests/rest/admin/test_scheduled_tasks.py
new file mode 100644
index 0000000000..9654e9322b
--- /dev/null
+++ b/tests/rest/admin/test_scheduled_tasks.py
@@ -0,0 +1,192 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+#
+from typing import Mapping, Optional, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+from synapse.server import HomeServer
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class ScheduledTasksAdminApiTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+ self._task_scheduler = hs.get_task_scheduler()
+
+ # create and schedule a few tasks
+ async def _test_task(
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return TaskStatus.ACTIVE, None, None
+
+ async def _finished_test_task(
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return TaskStatus.COMPLETE, None, None
+
+ async def _failed_test_task(
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return TaskStatus.FAILED, None, "Everything failed"
+
+ self._task_scheduler.register_action(_test_task, "test_task")
+ self.get_success(
+ self._task_scheduler.schedule_task("test_task", resource_id="test")
+ )
+
+ self._task_scheduler.register_action(_finished_test_task, "finished_test_task")
+ self.get_success(
+ self._task_scheduler.schedule_task(
+ "finished_test_task", resource_id="finished_task"
+ )
+ )
+
+ self._task_scheduler.register_action(_failed_test_task, "failed_test_task")
+ self.get_success(
+ self._task_scheduler.schedule_task(
+ "failed_test_task", resource_id="failed_task"
+ )
+ )
+
+    def check_scheduled_tasks_response(self, scheduled_tasks: list[Mapping]) -> list:
+ result = []
+ for task in scheduled_tasks:
+ if task["resource_id"] == "test":
+ self.assertEqual(task["status"], TaskStatus.ACTIVE)
+ self.assertEqual(task["action"], "test_task")
+ result.append(task)
+ if task["resource_id"] == "finished_task":
+ self.assertEqual(task["status"], TaskStatus.COMPLETE)
+ self.assertEqual(task["action"], "finished_test_task")
+ result.append(task)
+ if task["resource_id"] == "failed_task":
+ self.assertEqual(task["status"], TaskStatus.FAILED)
+ self.assertEqual(task["action"], "failed_test_task")
+ result.append(task)
+
+ return result
+
+ def test_requester_is_not_admin(self) -> None:
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ self.register_user("user", "pass", admin=False)
+ other_user_tok = self.login("user", "pass")
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks",
+ content={},
+ access_token=other_user_tok,
+ )
+
+ self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_scheduled_tasks(self) -> None:
+ """
+ Test that endpoint returns scheduled tasks.
+ """
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+
+ # make sure we got back all the scheduled tasks
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+ self.assertEqual(len(found_tasks), 3)
+
+ def test_filtering_scheduled_tasks(self) -> None:
+ """
+ Test that filtering the scheduled tasks response via query params works as expected.
+ """
+ # filter via job_status
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?job_status=active",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+
+ # only the active task should have been returned
+ self.assertEqual(len(found_tasks), 1)
+ self.assertEqual(found_tasks[0]["status"], "active")
+
+ # filter via action_name
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?action_name=test_task",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+
+ # only test_task should have been returned
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+ self.assertEqual(len(found_tasks), 1)
+ self.assertEqual(found_tasks[0]["action"], "test_task")
+
+ # filter via max_timestamp
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?max_timestamp=0",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+
+ # none should have been returned
+ self.assertEqual(len(found_tasks), 0)
+
+ # filter via resource id
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?resource_id=failed_task",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+
+ # only the task with the matching resource id should have been returned
+ self.assertEqual(len(found_tasks), 1)
+ self.assertEqual(found_tasks[0]["resource_id"], "failed_task")
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2a1e42bbc8..150caeeee2 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -531,9 +531,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# simulate a change in server config after a server restart.
new_display_name = "new display name"
- self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = (
- new_display_name
- )
+ self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = new_display_name
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.make_request(
@@ -577,9 +575,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# simulate a change in server config after a server restart.
new_avatar_url = "test/new-url"
- self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = (
- new_avatar_url
- )
+ self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = new_avatar_url
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.make_request(
@@ -692,9 +688,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# simulate a change in server config after a server restart.
new_avatar_url = "test/new-url"
- self.server_notices_manager._config.servernotices.server_notices_room_avatar_url = (
- new_avatar_url
- )
+ self.server_notices_manager._config.servernotices.server_notices_room_avatar_url = new_avatar_url
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.make_request(
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5f60e19e56..07ec49c4e5 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -82,7 +82,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
If parameters are invalid, an error is returned.
"""
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 16bb4349f5..412718f06c 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -21,9 +21,12 @@
import hashlib
import hmac
+import json
import os
+import time
import urllib.parse
from binascii import unhexlify
+from http import HTTPStatus
from typing import Dict, List, Optional
from unittest.mock import AsyncMock, Mock, patch
@@ -33,7 +36,13 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
-from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
+from synapse.api.constants import (
+ ApprovalNoticeMedium,
+ EventContentFields,
+ EventTypes,
+ LoginType,
+ UserTypes,
+)
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
@@ -42,6 +51,7 @@ from synapse.rest.client import (
devices,
login,
logout,
+ media,
profile,
register,
room,
@@ -54,7 +64,9 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import SMALL_PNG
+from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config
@@ -316,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_extra_user_type(self) -> None:
+ """
+ Check that the extra user type can be used when registering a user.
+ """
+
+ def nonce_mac(user_type: str) -> tuple[str, str]:
+ """
+ Get a nonce and the expected HMAC for that nonce.
+ """
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii")
+ + b"\x00alice\x00abc123\x00notadmin\x00"
+ + user_type.encode("ascii")
+ )
+ want_mac_str = want_mac.hexdigest()
+
+ return nonce, want_mac_str
+
+ nonce, mac = nonce_mac("extra1")
+ # Valid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra1",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ nonce, mac = nonce_mac("extra3")
+ # Invalid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra3",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
+
def test_displayname(self) -> None:
"""
Test that displayname of new user is set
@@ -715,7 +782,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
@@ -1174,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase):
not_user_types=["custom"],
)
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_filter_not_user_types_with_extra(self) -> None:
+ """Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
+
+ regular_user_id = self.register_user("normalo", "secret")
+
+ extra1_user_id = self.register_user("extra1", "secret")
+ self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id),
+ {"user_type": "extra1"},
+ access_token=self.admin_user_tok,
+ )
+
+ def test_user_type(
+ expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
+ ) -> None:
+ """Runs a test for the not_user_types param
+ Args:
+ expected_user_ids: Ids of the users that are expected to be returned
+ not_user_types: List of values for the not_user_types param
+ """
+
+ user_type_query = ""
+
+ if not_user_types is not None:
+ user_type_query = "&".join(
+ [f"not_user_type={u}" for u in not_user_types]
+ )
+
+ test_url = f"{self.url}?{user_type_query}"
+ channel = self.make_request(
+ "GET",
+ test_url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code)
+ self.assertEqual(channel.json_body["total"], len(expected_user_ids))
+ self.assertEqual(
+ expected_user_ids,
+ [u["name"] for u in channel.json_body["users"]],
+ )
+
+ # Request without user_types → all users expected
+ test_user_type([self.admin_user, extra1_user_id, regular_user_id])
+
+ # Request and exclude extra1 user type
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1"],
+ )
+
+ # Request and exclude extra1 and extra2 user types
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1", "extra2"],
+ )
+
+ # Request and exclude the empty user type → only the extra1 user is expected
+ test_user_type([extra1_user_id], not_user_types=[""])
+
+ # Request and exclude an unregistered type → expect all users
+ test_user_type(
+ [self.admin_user, extra1_user_id, regular_user_id],
+ not_user_types=["extra3"],
+ )
+
def test_erasure_status(self) -> None:
# Create a new user.
user_id = self.register_user("eraseme", "eraseme")
@@ -1411,9 +1552,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
)
- self.get_success(
- self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
- )
def test_no_auth(self) -> None:
"""
@@ -1500,7 +1638,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
self.assertFalse(channel.json_body["erased"])
@@ -1525,7 +1662,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
self.assertTrue(channel.json_body["erased"])
@@ -1568,7 +1704,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
@@ -1592,7 +1727,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
@@ -1622,7 +1756,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
@@ -1646,7 +1779,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
@@ -1817,25 +1949,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
- # threepids not valid
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": {"medium": "email", "wrong_address": "id"}},
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": {"address": "value"}},
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
def test_get_user(self) -> None:
"""
Test a simple get of a user.
@@ -1890,8 +2003,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1906,8 +2017,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
@@ -1945,9 +2054,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
@@ -1969,8 +2075,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertFalse(channel.json_body["admin"])
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
@@ -2062,123 +2166,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
- @override_config(
- {
- "email": {
- "enable_notifs": True,
- "notif_for_new_users": True,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_notif_for_new_users(self) -> None:
- """
- Check that a new regular user is created successfully and
- got an email pusher.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- # Note that the given email is not in canonical form.
- "threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-
- pushers = list(
- self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
- )
- self.assertEqual(len(pushers), 1)
- self.assertEqual("@bob:test", pushers[0].user_name)
-
- @override_config(
- {
- "email": {
- "enable_notifs": False,
- "notif_for_new_users": False,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_no_notif_for_new_users(self) -> None:
- """
- Check that a new regular user is created successfully and
- got not an email pusher.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-
- pushers = list(
- self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
- )
- self.assertEqual(len(pushers), 0)
-
- @override_config(
- {
- "email": {
- "enable_notifs": True,
- "notif_for_new_users": True,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None:
- """
- Check that a new regular user is created successfully when they have a msisdn
- threepid and email notif_for_new_users is set to True.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- "threepids": [{"medium": "msisdn", "address": "1234567890"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
-
def test_set_password(self) -> None:
"""
Test setting a new password for another user.
@@ -2222,89 +2209,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
-
- def test_set_threepid(self) -> None:
- """
- Test setting threepid for an other user.
- """
-
- # Add two threepids to user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={
- "threepids": [
- {"medium": "email", "address": "bob1@bob.bob"},
- {"medium": "email", "address": "bob2@bob.bob"},
- ],
- },
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- # result does not always have the same sort order, therefore it becomes sorted
- sorted_result = sorted(
- channel.json_body["threepids"], key=lambda k: k["address"]
- )
- self.assertEqual("email", sorted_result[0]["medium"])
- self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
- self.assertEqual("email", sorted_result[1]["medium"])
- self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
- self._check_fields(channel.json_body)
-
- # Set a new and remove a threepid
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={
- "threepids": [
- {"medium": "email", "address": "bob2@bob.bob"},
- {"medium": "email", "address": "bob3@bob.bob"},
- ],
- },
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
- self._check_fields(channel.json_body)
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_other_user,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
- self._check_fields(channel.json_body)
-
- # Remove threepids
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": []},
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
- self._check_fields(channel.json_body)
-
- def test_set_duplicate_threepid(self) -> None:
"""
Test setting the same threepid for a second user.
First user loses and second user gets mapping of this threepid.
@@ -2328,9 +2232,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"])
self._check_fields(channel.json_body)
# Add threepids to other user
@@ -2347,9 +2248,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
self._check_fields(channel.json_body)
# Add two new threepids to other user
@@ -2369,15 +2267,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# other user has this two threepids
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- # result does not always have the same sort order, therefore it becomes sorted
- sorted_result = sorted(
- channel.json_body["threepids"], key=lambda k: k["address"]
- )
- self.assertEqual("email", sorted_result[0]["medium"])
- self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
- self.assertEqual("email", sorted_result[1]["medium"])
- self.assertEqual("bob3@bob.bob", sorted_result[1]["address"])
self._check_fields(channel.json_body)
# first_user has no threepid anymore
@@ -2388,7 +2277,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
def test_set_external_id(self) -> None:
@@ -2623,9 +2511,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
)
- self.get_success(
- self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
- )
# Get user
channel = self.make_request(
@@ -2637,7 +2522,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2652,7 +2536,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2671,7 +2554,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2965,22 +2847,18 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
- def test_set_user_type(self) -> None:
- """
- Test changing user type.
- """
-
- # Set to support type
+ def set_user_type(self, user_type: Optional[str]) -> None:
+ # Set to user_type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"user_type": UserTypes.SUPPORT},
+ content={"user_type": user_type},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
# Get user
channel = self.make_request(
@@ -2991,30 +2869,44 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
+
+ def test_set_user_type(self) -> None:
+ """
+ Test changing user type.
+ """
+
+ # Set to support type
+ self.set_user_type(UserTypes.SUPPORT)
# Change back to a regular user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"user_type": None},
- )
+ self.set_user_type(None)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ @override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}})
+ def test_set_user_type_with_extras(self) -> None:
+ """
+ Test changing user type with extra_user_types configured.
+ """
- # Get user
+ # Check that we can still set to support type
+ self.set_user_type(UserTypes.SUPPORT)
+
+ # Check that we can set to an extra user type
+ self.set_user_type("extra2")
+
+ # Change back to a regular user
+ self.set_user_type(None)
+
+ # Try setting to an invalid type
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"user_type": "extra3"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
def test_accidental_deactivation_prevention(self) -> None:
"""
@@ -3204,7 +3096,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content: Content dictionary to check
"""
self.assertIn("displayname", content)
- self.assertIn("threepids", content)
self.assertIn("avatar_url", content)
self.assertIn("admin", content)
self.assertIn("deactivated", content)
@@ -3217,6 +3108,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("consent_ts", content)
self.assertIn("external_ids", content)
self.assertIn("last_seen_ts", content)
+ self.assertIn("suspended", content)
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", content)
@@ -3513,6 +3405,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ media.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -3692,7 +3585,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "DELETE"])
def test_invalid_parameter(self, method: str) -> None:
"""If parameters are invalid, an error is returned."""
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
method,
self.url + "?order_by=bar",
@@ -3887,9 +3780,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
image_data1 = SMALL_PNG
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
image_data2 = unhexlify(
- b"47494638376101000100800100000000"
- b"ffffff2c00000000010001000002024c"
- b"01003b"
+ b"47494638376101000100800100000000ffffff2c00000000010001000002024c01003b"
)
# Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
image_data3 = unhexlify(
@@ -4019,7 +3910,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts`
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id}",
shorthand=False,
access_token=user_token,
)
@@ -4858,100 +4749,6 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
)
-class UsersByThreePidTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.get_success(
- self.store.user_add_threepid(
- self.other_user, "email", "user@email.com", 1, 1
- )
- )
- self.get_success(
- self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1)
- )
-
- def test_no_auth(self) -> None:
- """Try to look up a user without authentication."""
- url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
-
- channel = self.make_request(
- "GET",
- url,
- )
-
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- def test_medium_does_not_exist(self) -> None:
- """Tests that both a lookup for a medium that does not exist and a user that
- doesn't exist with that third party ID returns a 404"""
- # test for unknown medium
- url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- # test for unknown user with a known medium
- url = "/_synapse/admin/v1/threepid/email/users/unknown"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_success(self) -> None:
- """Tests a successful medium + address lookup"""
- # test for email medium with encoded value of user@email.com
- url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(
- {"user_id": self.other_user},
- channel.json_body,
- )
-
- # test for msidn medium with encoded value of +1-12345678
- url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(
- {"user_id": self.other_user},
- channel.json_body,
- )
-
-
class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -5024,7 +4821,6 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
- @override_config({"experimental_features": {"msc3823_account_suspension": True}})
def test_suspend_user(self) -> None:
# test that suspending user works
channel = self.make_request(
@@ -5089,3 +4885,766 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase):
res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
self.assertEqual(True, res5)
+
+
+class UserRedactionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ self.store = hs.get_datastores().main
+
+ self.spam_checker = hs.get_module_api_callbacks().spam_checker
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+ Test that request to redact events in all rooms user is member of is successful
+ """
+ # join rooms, send some messages
+ originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ originals.append(res["event_id"])
+
+ # redact all events in all rooms
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matched = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matched.append(event_id)
+ self.assertEqual(len(matched), len(originals))
+
+ def test_redact_messages_specific_rooms(self) -> None:
+ """
+ Test that request to redact events in specified rooms user is member of is successful
+ """
+
+ originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ originals.append(res["event_id"])
+
+ # redact messages in rooms 1 and 3
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # messages in requested rooms are redacted
+ for rm in [self.rm1, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matches = []
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matches.append((event_id, event))
+            # we redacted 16 events per room (the join event plus 15 messages)
+ self.assertEqual(len(matches), 16)
+
+ channel = self.make_request(
+ "GET", f"rooms/{self.rm2}/messages?limit=50", access_token=self.admin_tok
+ )
+ self.assertEqual(channel.code, 200)
+
+ # messages in remaining room are not
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ self.fail("found redaction in room 2")
+
+ def test_redact_status(self) -> None:
+ rm2_originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ if rm == self.rm2:
+ rm2_originals.append(join["event_id"])
+ for i in range(5):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ if rm == self.rm2:
+ rm2_originals.append(res["event_id"])
+
+ # redact messages in rooms 1 and 3
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel2.code, 200)
+ self.assertEqual(channel2.json_body.get("status"), "complete")
+ self.assertEqual(channel2.json_body.get("failed_redactions"), {})
+
+ # mock that will cause persisting the redaction events to fail
+ async def check_event_for_spam(event: str) -> str:
+ return "spam"
+
+ self.spam_checker.check_event_for_spam = check_event_for_spam # type: ignore
+
+ channel3 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm2]},
+ access_token=self.admin_tok,
+ )
+        self.assertEqual(channel3.code, 200)
+ id = channel3.json_body.get("redact_id")
+
+ channel4 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel4.code, 200)
+ self.assertEqual(channel4.json_body.get("status"), "complete")
+ failed_redactions = channel4.json_body.get("failed_redactions")
+ assert failed_redactions is not None
+ matched = []
+ for original in rm2_originals:
+ if failed_redactions.get(original) is not None:
+ matched.append(original)
+ self.assertEqual(len(matched), len(rm2_originals))
+
+ def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
+ originals1 = []
+ originals2 = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ if rm in [self.rm1, self.rm3]:
+ originals1.append(join["event_id"])
+ else:
+ originals2.append(join["event_id"])
+ for i in range(5):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ if rm in [self.rm1, self.rm3]:
+ originals1.append(res["event_id"])
+ else:
+ originals2.append(res["event_id"])
+
+ # kick user from rooms 1 and 3
+ for r in [self.rm1, self.rm3]:
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{r}/kick",
+ content={"reason": "being a bummer", "user_id": self.bad_user},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # redact messages in room 1 and 3
+ channel1 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel1.code, 200)
+ id = channel1.json_body.get("redact_id")
+
+ # check that there were no failed redactions in room 1 and 3
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel2.code, 200)
+ self.assertEqual(channel2.json_body.get("status"), "complete")
+ failed_redactions = channel2.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ # double check
+ for rm in [self.rm1, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel3 = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel3.code, 200)
+
+ matches = []
+ for event in channel3.json_body["chunk"]:
+ for event_id in originals1:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matches.append((event_id, event))
+            # we redacted 6 events per room (the join event plus 5 messages)
+ self.assertEqual(len(matches), 6)
+
+ # ban user from room 2
+ channel4 = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{self.rm2}/ban",
+ content={"reason": "being a bummer", "user_id": self.bad_user},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result)
+
+        # make a request to redact all of the user's messages
+ channel5 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel5.code, 200)
+ id2 = channel5.json_body.get("redact_id")
+
+ # check that there were no failed redactions in room 2
+ channel6 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id2}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel6.code, 200)
+ self.assertEqual(channel6.json_body.get("status"), "complete")
+ failed_redactions = channel6.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ # double check messages in room 2 were redacted
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel7 = self.make_request(
+ "GET",
+ f"rooms/{self.rm2}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel7.code, 200)
+
+ matches = []
+ for event in channel7.json_body["chunk"]:
+ for event_id in originals2:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ matches.append((event_id, event))
+        # we redacted 6 events (the join event plus 5 messages)
+ self.assertEqual(len(matches), 6)
+
+ def test_redactions_for_remote_user_succeed_with_admin_priv_in_room(self) -> None:
+ """
+ Test that if the admin requester has privileges in a room, redaction requests
+ succeed for a remote user
+ """
+
+ # inject some messages from remote user and collect event ids
+ original_message_ids = []
+ for i in range(5):
+ event = self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.rm1,
+ type="m.room.message",
+ sender="@remote:remote_server",
+ content={"msgtype": "m.text", "body": f"nefarious_chatter{i}"},
+ )
+ )
+ original_message_ids.append(event.event_id)
+
+ # send a request to redact a remote user's messages in a room.
+ # the server admin created this room and has admin privilege in room
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/user/@remote:remote_server/redact",
+ content={"rooms": [self.rm1]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ # check that there were no failed redactions
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body.get("status"), "complete")
+ failed_redactions = channel.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{self.rm1}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in original_message_ids:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ original_message_ids.remove(event_id)
+ break
+ # we originally sent 5 messages so 5 should be redacted
+ self.assertEqual(len(original_message_ids), 0)
+
+ def test_redact_redacts_encrypted_messages(self) -> None:
+ """
+ Test that user's encrypted messages are redacted
+ """
+ encrypted_room = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.helper.send_state(
+ encrypted_room,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=self.admin_tok,
+ )
+ # join room send some messages
+ originals = []
+ join = self.helper.join(encrypted_room, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for _ in range(15):
+ res = self.helper.send_event(
+ encrypted_room, "m.room.encrypted", {}, tok=self.bad_user_tok
+ )
+ originals.append(res["event_id"])
+
+ # redact user's events
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matched = []
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{encrypted_room}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ matched.append(event_id)
+ self.assertEqual(len(matched), len(originals))
+
+
+class UserRedactionBackgroundTaskTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ @override_config({"run_background_tasks_on": "worker1"})
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+ Test that redact task successfully runs when `run_background_tasks_on` is specified
+ """
+ self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+
+ # join rooms, send some messages
+ original_event_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ original_event_ids.add(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ original_event_ids.add(res["event_id"])
+
+ # redact all events in all rooms
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ timeout_s = 10
+ start_time = time.time()
+ redact_result = ""
+ while redact_result != "complete":
+ if start_time + timeout_s < time.time():
+ self.fail("Timed out waiting for redactions.")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ redact_result = channel2.json_body["status"]
+ if redact_result == "failed":
+ self.fail("Redaction task failed.")
+
+ redaction_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ redaction_ids.add(event["redacts"])
+
+ self.assertIncludes(redaction_ids, original_event_ids, exact=True)
+
+
+class GetInvitesFromUserTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ self.random_users = []
+ for i in range(4):
+ self.random_users.append(self.register_user(f"user{i}", f"pass{i}"))
+
+ self.room1 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+ self.room2 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+ self.room3 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_issuer": {"per_second": 1000, "burst_count": 1000}}}
+ )
+ def test_get_user_invite_count_new_invites_test_case(self) -> None:
+ """
+ Test that new invites that arrive after a provided timestamp are counted
+ """
+ # grab a current timestamp
+ before_invites_sent_ts = self.hs.get_clock().time_msec()
+
+ # bad user sends some invites
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok)
+
+ # fetch using timestamp, all should be returned
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+ # send some more invites, they should show up in addition to original 8 using same timestamp
+ for user in self.random_users:
+ self.helper.invite(
+ self.room3, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 12)
+
+ def test_get_user_invite_count_invites_before_ts_test_case(self) -> None:
+ """
+ Test that invites sent before provided ts are not counted
+ """
+ # bad user sends some invites
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok)
+
+ # add a msec between last invite and ts
+ after_invites_sent_ts = self.hs.get_clock().time_msec() + 1
+
+ # fetch invites with timestamp, none should be returned
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={after_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 0)
+
+ def test_user_invite_count_kick_ban_not_counted(self) -> None:
+ """
+ Test that kicks and bans are not counted in invite count
+ """
+ to_kick_user_id = self.register_user("kick_me", "pass")
+ to_kick_tok = self.login("kick_me", "pass")
+
+ self.helper.join(self.room1, to_kick_user_id, tok=to_kick_tok)
+
+ # grab a current timestamp
+ before_invites_sent_ts = self.hs.get_clock().time_msec()
+
+ # bad user sends some invites (8)
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(
+ room_id, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ # fetch using timestamp, all invites sent should be counted
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+ # send a kick and some bans and make sure these aren't counted against invite total
+ for user in self.random_users:
+ self.helper.ban(
+ self.room1, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/kick",
+ content={"user_id": to_kick_user_id},
+ access_token=self.bad_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+
+class GetCumulativeJoinedRoomCountForUserTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ def test_user_cumulative_joined_room_count(self) -> None:
+ """
+ Tests proper count returned from /cumulative_joined_room_count endpoint
+ """
+ # Create rooms and join, grab timestamp before room creation
+ before_room_creation_timestamp = self.hs.get_clock().time_msec()
+
+ joined_rooms = []
+ for _ in range(3):
+ room = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.helper.join(
+ room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok
+ )
+ joined_rooms.append(room)
+
+ # get a timestamp after room creation and join, add a msec between last join and ts
+ after_room_creation = self.hs.get_clock().time_msec() + 1
+
+ # Get rooms using this timestamp, there should be none since all rooms were created and joined
+ # before provided timestamp
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(after_room_creation)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["cumulative_joined_room_count"])
+
+ # fetch rooms with the older timestamp before they were created and joined, this should
+ # return the rooms
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
+
+ def test_user_joined_room_count_includes_left_and_banned_rooms(self) -> None:
+ """
+ Tests proper count returned from /joined_room_count endpoint when user has left
+ or been banned from joined rooms
+ """
+ # Create rooms and join, grab timestamp before room creation
+ before_room_creation_timestamp = self.hs.get_clock().time_msec()
+
+ joined_rooms = []
+ for _ in range(3):
+ room = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.helper.join(
+ room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok
+ )
+ joined_rooms.append(room)
+
+ # fetch rooms with the older timestamp before they were created and joined
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
+
+ # have the user banned from/leave the joined rooms
+ self.helper.ban(
+ joined_rooms[0],
+ src=self.admin,
+ targ=self.bad_user,
+ expect_code=200,
+ tok=self.admin_tok,
+ )
+ self.helper.change_membership(
+ joined_rooms[1],
+ src=self.bad_user,
+ targ=self.bad_user,
+ membership="leave",
+ expect_code=200,
+ tok=self.bad_user_tok,
+ )
+ self.helper.ban(
+ joined_rooms[2],
+ src=self.admin,
+ targ=self.bad_user,
+ expect_code=200,
+ tok=self.admin_tok,
+ )
+
+ # fetch the joined room count again, the number should remain the same as the collected joined rooms
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py
index 6863c32f7c..5b819103c2 100644
--- a/tests/rest/client/sliding_sync/test_connection_tracking.py
+++ b/tests/rest/client/sliding_sync/test_connection_tracking.py
@@ -13,7 +13,7 @@
#
import logging
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -28,6 +28,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase):
"""
Test connection tracking in the Sliding Sync API.
@@ -44,6 +58,8 @@ class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_rooms_required_state_incremental_sync_LIVE(self) -> None:
"""Test that we only get state updates in incremental sync for rooms
we've already seen (LIVE).
diff --git a/tests/rest/client/sliding_sync/test_extension_account_data.py b/tests/rest/client/sliding_sync/test_extension_account_data.py
index 3482a5f887..799fbb1856 100644
--- a/tests/rest/client/sliding_sync/test_extension_account_data.py
+++ b/tests/rest/client/sliding_sync/test_extension_account_data.py
@@ -11,8 +11,12 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
+import enum
import logging
+from parameterized import parameterized, parameterized_class
+from typing_extensions import assert_never
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +32,25 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+class TagAction(enum.Enum):
+ ADD = enum.auto()
+ REMOVE = enum.auto()
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
"""Tests for the account_data sliding sync extension"""
@@ -43,6 +66,8 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.account_data_handler = hs.get_account_data_handler()
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling the account_data extension works during an intitial sync,
@@ -62,18 +87,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
# Even though we don't have any global account data set, Synapse saves some
# default push rules for us.
{AccountDataTypes.PUSH_RULES},
exact=True,
)
+        # Push rules are a giant chunk of JSON data so we will just assume the value is correct if the key is here.
+ # global_account_data_map[AccountDataTypes.PUSH_RULES]
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -103,16 +133,19 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
# There has been no account data changes since the `from_token` so we shouldn't
# see any account data here.
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
set(),
exact=True,
)
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -147,16 +180,24 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# It should show us all of the global account data
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
{AccountDataTypes.PUSH_RULES, "org.matrix.foobarbaz"},
exact=True,
)
+        # Push rules are a giant chunk of JSON data so we will just assume the value is correct if the key is here.
+ # global_account_data_map[AccountDataTypes.PUSH_RULES]
+ self.assertEqual(
+ global_account_data_map["org.matrix.foobarbaz"], {"foo": "bar"}
+ )
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -202,17 +243,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
# Make an incremental Sliding Sync request with the account_data extension enabled
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
# We should only see the new global account data that happened after the `from_token`
{"org.matrix.doodardaz"},
exact=True,
)
+ self.assertEqual(
+ global_account_data_map["org.matrix.doodardaz"], {"doo": "dar"}
+ )
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -237,6 +284,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
# Create another room with some room account data
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
@@ -248,6 +304,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
# Make an initial Sliding Sync request with the account_data extension enabled
sync_body = {
@@ -276,21 +341,36 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
{room_id1},
exact=True,
)
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id1)
+ }
self.assertIncludes(
- {
- event["type"]
- for event in response_body["extensions"]["account_data"]
- .get("rooms")
- .get(room_id1)
- },
- {"org.matrix.roorarraz"},
+ account_data_map.keys(),
+ {"org.matrix.roorarraz", AccountDataTypes.TAG},
exact=True,
)
+ self.assertEqual(account_data_map["org.matrix.roorarraz"], {"roo": "rar"})
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG], {"tags": {"m.favourite": {}}}
+ )
- def test_room_account_data_incremental_sync(self) -> None:
+ @parameterized.expand(
+ [
+ ("add tags", TagAction.ADD),
+ ("remove tags", TagAction.REMOVE),
+ ]
+ )
+ def test_room_account_data_incremental_sync(
+ self, test_description: str, tag_action: TagAction
+ ) -> None:
"""
On incremental sync, we return all account data for a given room but only for
rooms that we request and are being returned in the Sliding Sync response.
+
+ (HaveSentRoomFlag.LIVE)
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -305,6 +385,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
# Create another room with some room account data
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
@@ -316,6 +405,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
sync_body = {
"lists": {},
@@ -351,6 +449,42 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ if tag_action == TagAction.ADD:
+ # Add another room tag
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Remove the room tag
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ )
+ )
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ )
+ )
+ else:
+ assert_never(tag_action)
# Make an incremental Sliding Sync request with the account_data extension enabled
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
@@ -365,17 +499,444 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
exact=True,
)
# We should only see the new room account data that happened after the `from_token`
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id1)
+ }
self.assertIncludes(
+ account_data_map.keys(),
+ {"org.matrix.roorarraz2", AccountDataTypes.TAG},
+ exact=True,
+ )
+ self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
+ if tag_action == TagAction.ADD:
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {"m.favourite": {}, "m.server_notice": {}}},
+ )
+ elif tag_action == TagAction.REMOVE:
+ # If we previously showed the client that the room has tags, when it no
+ # longer has tags, we need to show them an empty map.
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {}},
+ )
+ else:
+ assert_never(tag_action)
+
+ @parameterized.expand(
+ [
+ ("add tags", TagAction.ADD),
+ ("remove tags", TagAction.REMOVE),
+ ]
+ )
+ def test_room_account_data_incremental_sync_out_of_range_never(
+ self, test_description: str, tag_action: TagAction
+ ) -> None:
+ """Tests that we don't return account data for rooms that are out of
+ range, but then do send all account data once they're in range.
+
+ (initial/HaveSentRoomFlag.NEVER)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room and add some room account data
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Create another room with some room account data
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Now send a message into room1 so that it is at the top of the list
+ self.helper.send(room_id1, body="new event", tok=user1_tok)
+
+ # Make a SS request for only the top room.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ },
+ "extensions": {
+ "account_data": {
+ "enabled": True,
+ "lists": ["main"],
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Only room1 should be in the response since it's the latest room with activity
+ # and our range only includes 1 room.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # Add some other room account data
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ if tag_action == TagAction.ADD:
+ # Add another room tag
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Remove the room tag
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ )
+ )
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ )
+ )
+ else:
+ assert_never(tag_action)
+
+ # Move room2 into range.
+ self.helper.send(room_id2, body="new event", tok=user1_tok)
+
+ # Make an incremental Sliding Sync request with the account_data extension enabled
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ self.assertIsNotNone(response_body["extensions"]["account_data"].get("global"))
+ # We expect to see the account data of room2, as that has the most
+ # recent update.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+ # Since this is the first time we're seeing room2 down sync, we should see all
+ # room account data for it.
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id2)
+ }
+ expected_account_data_keys = {
+ "org.matrix.roorarraz",
+ "org.matrix.roorarraz2",
+ }
+ if tag_action == TagAction.ADD:
+ expected_account_data_keys.add(AccountDataTypes.TAG)
+ self.assertIncludes(
+ account_data_map.keys(),
+ expected_account_data_keys,
+ exact=True,
+ )
+ self.assertEqual(account_data_map["org.matrix.roorarraz"], {"roo": "rar"})
+ self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
+ if tag_action == TagAction.ADD:
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {"m.favourite": {}, "m.server_notice": {}}},
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Since we never told the client about the room tags, we don't need to say
+ # anything if there are no tags now (the client doesn't need an update).
+ self.assertIsNone(
+ account_data_map.get(AccountDataTypes.TAG),
+ account_data_map,
+ )
+ else:
+ assert_never(tag_action)
+
+ @parameterized.expand(
+ [
+ ("add tags", TagAction.ADD),
+ ("remove tags", TagAction.REMOVE),
+ ]
+ )
+ def test_room_account_data_incremental_sync_out_of_range_previously(
+ self, test_description: str, tag_action: TagAction
+ ) -> None:
+ """Tests that we don't return account data for rooms that fall out of
+ range, but then do send all account data that has changed once they're back in range.
+
+ (HaveSentRoomFlag.PREVIOUSLY)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room and add some room account data
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Create another room with some room account data
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Make an initial Sliding Sync request for only room1 and room2.
+ sync_body = {
+ "lists": {},
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ room_id2: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ },
+ "extensions": {
+ "account_data": {
+ "enabled": True,
+ "rooms": [room_id1, room_id2],
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Both rooms show up because we have a room subscription for each and they're
+ # requested in the `account_data` extension.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id1, room_id2},
+ exact=True,
+ )
+
+ # Add some other room account data
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ if tag_action == TagAction.ADD:
+ # Add another room tag
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Remove the room tag
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ )
+ )
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ )
+ )
+ else:
+ assert_never(tag_action)
+
+ # Make an incremental Sliding Sync request for just room1
+ response_body, from_token = self.do_sync(
{
- event["type"]
- for event in response_body["extensions"]["account_data"]
- .get("rooms")
- .get(room_id1)
+ **sync_body,
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ },
},
- {"org.matrix.roorarraz2"},
+ since=from_token,
+ tok=user1_tok,
+ )
+
+ # Only room1 shows up because we only have a room subscription for room1 now.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id1},
exact=True,
)
+ # Make an incremental Sliding Sync request for just room2 now
+ response_body, from_token = self.do_sync(
+ {
+ **sync_body,
+ "room_subscriptions": {
+ room_id2: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ },
+ },
+ since=from_token,
+ tok=user1_tok,
+ )
+
+ # Only room2 shows up because we only have a room subscription for room2 now.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+
+ self.assertIsNotNone(response_body["extensions"]["account_data"].get("global"))
+ # Check for room account data for room2
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+ # We should see any room account data updates for room2 since the last
+ # time we saw it down sync
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id2)
+ }
+ self.assertIncludes(
+ account_data_map.keys(),
+ {"org.matrix.roorarraz2", AccountDataTypes.TAG},
+ exact=True,
+ )
+ self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
+ if tag_action == TagAction.ADD:
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {"m.favourite": {}, "m.server_notice": {}}},
+ )
+ elif tag_action == TagAction.REMOVE:
+ # If we previously showed the client that the room has tags, when it no
+ # longer has tags, we need to show them an empty map.
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {}},
+ )
+ else:
+ assert_never(tag_action)
+
def test_wait_for_new_data(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive.
diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py
index 320f8c788f..7ce6592d8f 100644
--- a/tests/rest/client/sliding_sync/test_extension_e2ee.py
+++ b/tests/rest/client/sliding_sync/test_extension_e2ee.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -27,6 +29,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
"""Tests for the e2ee sliding sync extension"""
@@ -42,6 +58,8 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling e2ee extension works during an intitial sync, even if there
diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py
index 65fbac260e..6e7700b533 100644
--- a/tests/rest/client/sliding_sync/test_extension_receipts.py
+++ b/tests/rest/client/sliding_sync/test_extension_receipts.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +30,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
"""Tests for the receipts sliding sync extension"""
@@ -42,6 +58,8 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling the receipts extension works during an intitial sync,
@@ -677,3 +695,240 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
set(),
exact=True,
)
+
+ def test_receipts_incremental_sync_out_of_range(self) -> None:
+ """Tests that we don't return read receipts for rooms that fall out of
+ range, but then do send all read receipts once they're back in range.
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
+
+ # Send a message and read receipt into room2
+ event_response = self.helper.send(room_id2, body="new event", tok=user2_tok)
+ room2_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id2, room2_event_id, tok=user1_tok)
+
+ # Now send a message into room1 so that it is at the top of the list
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the top room.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # The receipt is in room2, but only room1 is returned, so we don't
+ # expect to get the receipt.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # Move room2 into range.
+ self.helper.send(room_id2, body="new event", tok=user2_tok)
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We expect to see the read receipt of room2, as that has the most
+ # recent update.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id2]
+ self.assertIncludes(
+ receipt["content"][room2_event_id][ReceiptTypes.READ].keys(),
+ {user1_id},
+ exact=True,
+ )
+
+ # Send a message into room1 to bump it to the top, but also send a
+ # receipt in room2
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+ self.helper.send_read_receipt(room_id2, room2_event_id, tok=user2_tok)
+
+ # We don't expect to see the new read receipt.
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # But if we send a new message into room2, we expect to get the missing receipts
+ self.helper.send(room_id2, body="new event", tok=user2_tok)
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+
+ # We should only see the new receipt
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id2]
+ self.assertIncludes(
+ receipt["content"][room2_event_id][ReceiptTypes.READ].keys(),
+ {user2_id},
+ exact=True,
+ )
+
+ def test_return_own_read_receipts(self) -> None:
+ """Test that we always send the user's own read receipts in initial
+ rooms, even if the receipts don't match events in the timeline.
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send a message and read receipts into room1
+ event_response = self.helper.send(room_id1, body="new event", tok=user2_tok)
+ room1_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user1_tok)
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user2_tok)
+
+ # Now send a message so the above message is not in the timeline.
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the latest message.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # We should get our own receipt in room1, even though it's not in the
+ # timeline limit.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # We should only see our read receipt, not the other user's.
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id1]
+ self.assertIncludes(
+ receipt["content"][room1_event_id][ReceiptTypes.READ].keys(),
+ {user1_id},
+ exact=True,
+ )
+
+ def test_read_receipts_expanded_timeline(self) -> None:
+ """Test that we get read receipts when we expand the timeline limit (`unstable_expanded_timeline`)."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send a message and read receipt into room1
+ event_response = self.helper.send(room_id1, body="new event", tok=user2_tok)
+ room1_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user2_tok)
+
+ # Now send a message so the above message is not in the timeline.
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the latest message.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # We shouldn't see user2's read receipt, as it's not in the timeline
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # Now do another request with a room subscription with an increased timeline limit
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 2,
+ }
+ }
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # Assert that we did actually get an expanded timeline
+ room_response = response_body["rooms"][room_id1]
+ self.assertNotIn("initial", room_response)
+ self.assertEqual(room_response["unstable_expanded_timeline"], True)
+
+ # We should now see user2's read receipt, as it's in the expanded timeline
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # We should see user2's read receipt, now that their event is in the timeline.
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id1]
+ self.assertIncludes(
+ receipt["content"][room1_event_id][ReceiptTypes.READ].keys(),
+ {user2_id},
+ exact=True,
+ )
diff --git a/tests/rest/client/sliding_sync/test_extension_to_device.py b/tests/rest/client/sliding_sync/test_extension_to_device.py
index f8500812ea..790abb739d 100644
--- a/tests/rest/client/sliding_sync/test_extension_to_device.py
+++ b/tests/rest/client/sliding_sync/test_extension_to_device.py
@@ -14,6 +14,8 @@
import logging
from typing import List
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +30,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
"""Tests for the to-device sliding sync extension"""
@@ -40,6 +56,7 @@ class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ super().prepare(reactor, clock, hs)
def _assert_to_device_response(
self, response_body: JsonDict, expected_messages: List[JsonDict]
diff --git a/tests/rest/client/sliding_sync/test_extension_typing.py b/tests/rest/client/sliding_sync/test_extension_typing.py
index 7f523e0f10..f87c3c8b17 100644
--- a/tests/rest/client/sliding_sync/test_extension_typing.py
+++ b/tests/rest/client/sliding_sync/test_extension_typing.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +30,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncTypingExtensionTestCase(SlidingSyncBase):
"""Tests for the typing notification sliding sync extension"""
@@ -41,6 +57,8 @@ class SlidingSyncTypingExtensionTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling the typing extension works during an intitial sync,
diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py
index 68f6661334..30230e5c4b 100644
--- a/tests/rest/client/sliding_sync/test_extensions.py
+++ b/tests/rest/client/sliding_sync/test_extensions.py
@@ -14,7 +14,7 @@
import logging
from typing import Literal
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from typing_extensions import assert_never
from twisted.test.proto_helpers import MemoryReactor
@@ -30,6 +30,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncExtensionsTestCase(SlidingSyncBase):
"""
Test general extensions behavior in the Sliding Sync API. Each extension has their
@@ -49,6 +63,8 @@ class SlidingSyncExtensionsTestCase(SlidingSyncBase):
self.storage_controllers = hs.get_storage_controllers()
self.account_data_handler = hs.get_account_data_handler()
+ super().prepare(reactor, clock, hs)
+
# Any extensions that use `lists`/`rooms` should be tested here
@parameterized.expand([("account_data",), ("receipts",), ("typing",)])
def test_extensions_lists_rooms_relevant_rooms(
@@ -120,19 +136,26 @@ class SlidingSyncExtensionsTestCase(SlidingSyncBase):
"foo-list": {
"ranges": [[0, 1]],
"required_state": [],
- "timeline_limit": 0,
+ # We set this to `1` because we're testing `receipts` which
+ # interact with the `timeline`. With receipts, when a room
+ # hasn't been sent down the connection before or it appears
+ # as `initial: true`, we only include receipts for events in
+ # the timeline to avoid bloating and blowing up the sync
+ # response as the number of users in the room increases.
+ # (this behavior is part of the spec)
+ "timeline_limit": 1,
},
# We expect this list range to include room5, room4, room3
"bar-list": {
"ranges": [[0, 2]],
"required_state": [],
- "timeline_limit": 0,
+ "timeline_limit": 1,
},
},
"room_subscriptions": {
room_id1: {
"required_state": [],
- "timeline_limit": 0,
+ "timeline_limit": 1,
}
},
}
diff --git a/tests/rest/client/sliding_sync/test_lists_filters.py b/tests/rest/client/sliding_sync/test_lists_filters.py
new file mode 100644
index 0000000000..c59f6aedc4
--- /dev/null
+++ b/tests/rest/client/sliding_sync/test_lists_filters.py
@@ -0,0 +1,1975 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+import logging
+
+from parameterized import parameterized_class
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ RoomTypes,
+)
+from synapse.api.room_versions import RoomVersions
+from synapse.events import StrippedStateEvent
+from synapse.rest.client import login, room, sync, tags
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
+
+logger = logging.getLogger(__name__)
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
+class SlidingSyncFiltersTestCase(SlidingSyncBase):
+ """
+ Test `filters` in the Sliding Sync API to make sure it includes/excludes rooms
+ correctly.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ tags.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.storage_controllers = hs.get_storage_controllers()
+ self.account_data_handler = hs.get_account_data_handler()
+
+ super().prepare(reactor, clock, hs)
+
+ def test_multiple_filters_and_multiple_lists(self) -> None:
+ """
+ Test that filters apply to `lists` in various scenarios.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a DM room
+ joined_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=True,
+ )
+ invited_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=False,
+ )
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ # Absence of filters does not imply "False" values
+ "all": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {},
+ },
+ # Test single truthy filter
+ "dms": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": True},
+ },
+ # Test single falsy filter
+ "non-dms": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": False},
+ },
+ # Test how multiple filters should stack (AND'd together)
+ "room-invites": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": False, "is_invite": True},
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # Make sure it has the lists we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"all", "dms", "non-dms", "room-invites"},
+ exact=True,
+ )
+
+ # Make sure the lists have the correct rooms
+ self.assertIncludes(
+ set(response_body["lists"]["all"]["ops"][0]["room_ids"]),
+ {
+ invite_room_id,
+ room_id,
+ invited_dm_room_id,
+ joined_dm_room_id,
+ },
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["dms"]["ops"][0]["room_ids"]),
+ {invited_dm_room_id, joined_dm_room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["non-dms"]["ops"][0]["room_ids"]),
+ {invite_room_id, room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["room-invites"]["ops"][0]["room_ids"]),
+ {invite_room_id},
+ exact=True,
+ )
+
+ def test_filters_regardless_of_membership_server_left_room(self) -> None:
+ """
+ Test that filters apply to rooms regardless of membership. We're also
+ compounding the problem by having all of the local users leave the room causing
+ our server to leave the room.
+
+ We want to make sure that if someone is filtering rooms, and leaves, you still
+ get that final update down sync that you left.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create an encrypted space room
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+
+ # Make an initial Sliding Sync request
+ sync_body = {
+ "lists": {
+ "all-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {
+ "is_encrypted": True,
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Make sure the response has the lists we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"all-list", "foo-list"},
+ )
+
+ # Make sure the lists have the correct rooms
+ self.assertIncludes(
+ set(response_body["lists"]["all-list"]["ops"][0]["room_ids"]),
+ {space_room_id, room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ # Everyone leaves the encrypted space room
+ self.helper.leave(space_room_id, user1_id, tok=user1_tok)
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+ # Make an incremental Sliding Sync request
+ sync_body = {
+ "lists": {
+ "all-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {
+ "is_encrypted": True,
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Make sure the response has the lists we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"all-list", "foo-list"},
+ exact=True,
+ )
+
+ # Make sure the lists have the correct rooms even though we `newly_left`
+ self.assertIncludes(
+ set(response_body["lists"]["all-list"]["ops"][0]["room_ids"]),
+ {space_room_id, room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_is_dm(self) -> None:
+ """
+ Test `filter.is_dm` for DM rooms
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a DM room
+ dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ )
+
+ # Try with `is_dm=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_dm": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {dm_room_id},
+ exact=True,
+ )
+
+ # Try with `is_dm=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_dm": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted(self) -> None:
+ """
+ Test `filters.is_encrypted` for encrypted rooms
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        # Only the room with the encryption state event should match `is_encrypted=True`
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        # Only the unencrypted room should match `is_encrypted=False`
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_server_left_room(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` against a room that everyone has left.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Leave the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+ # Leave the room
+ self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_server_left_room2(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` against a room that everyone has
+ left.
+
+ There is still someone local who is invited to the rooms but that doesn't affect
+ whether the server is participating in the room (users need to be joined).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ _user2_tok = self.login(user2_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Invite user2
+ self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+ # Invite user2
+ self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_after_we_left(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` against a room that was encrypted
+ after we left the room (make sure we don't just use the current state)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        # Join and then leave the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that will be encrypted
+ encrypted_after_we_left_room_id = self.helper.create_room_as(
+ user2_id, tok=user2_tok
+ )
+        # Join and then leave the room
+ self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
+ self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
+
+ # Encrypt the room after we've left
+ self.helper.send_state(
+ encrypted_after_we_left_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ if self.use_new_tables:
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+ else:
+ # Even though we left the room before it was encrypted, we still see it because
+ # someone else on our server is still participating in the room and we "leak"
+ # the current state to the left user. But we consider the room encryption status
+ # to not be a secret given it's often set at the start of the room and it's one
+ # of the stripped state events that is normally handed out.
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_after_we_left_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ if self.use_new_tables:
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, encrypted_after_we_left_room_id},
+ exact=True,
+ )
+ else:
+ # Even though we left the room before it was encrypted... (see comment above)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_with_remote_invite_room_no_stripped_state(
+ self,
+ ) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` filter against a remote invite
+ room without any `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room without any `unsigned.invite_room_state`
+ _remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id, None
+ )
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear because we can't figure out whether
+ # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear because we can't figure out whether
+ # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_with_remote_invite_encrypted_room(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` filter against a remote invite
+ encrypted room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # indicating that the room is encrypted.
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ ],
+ )
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear here because it is encrypted
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is encrypted
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_with_remote_invite_unencrypted_room(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` filter against a remote invite
+ unencrypted room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # but don't set any room encryption event.
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ # No room encryption event
+ ],
+ )
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is unencrypted
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear because it is unencrypted according to
+ # the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_updated(self) -> None:
+ """
+ Make sure we get rooms if the encrypted room status is updated for a joined room
+ (`filters.is_encrypted`)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # No rooms are encrypted yet
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Update the encryption status
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # We should see the room now because it's encrypted
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_invite_rooms(self) -> None:
+ """
+ Test `filters.is_invite` for rooms that the user has been invited to
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Try with `is_invite=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_invite": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {invite_room_id},
+ exact=True,
+ )
+
+ # Try with `is_invite=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_invite": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types(self) -> None:
+ """
+ Test `filters.room_types` for different room types
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Create an arbitrarily typed room
+ foo_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
+ }
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ # Try finding normal rooms and spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None, RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, space_room_id},
+ exact=True,
+ )
+
+ # Try finding an arbitrary room type
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": ["org.matrix.foobarbaz"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id},
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of room_types
+ # (we should find nothing)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ def test_filters_not_room_types(self) -> None:
+ """
+ Test `filters.not_room_types` for different room types
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Create an arbitrarily typed room
+ foo_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
+ }
+ },
+ )
+
+ # Try finding *NOT* normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id, foo_room_id},
+ exact=True,
+ )
+
+ # Try finding *NOT* spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, foo_room_id},
+ exact=True,
+ )
+
+ # Try finding *NOT* normal rooms or spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [None, RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id},
+ exact=True,
+ )
+
+ # Test how it behaves when we have both `room_types` and `not_room_types`.
+ # `not_room_types` should win.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ "not_room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # Nothing matches because nothing is both a normal room and not a normal room
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Test how it behaves when we have both `room_types` and `not_room_types`.
+ # `not_room_types` should win.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None, RoomTypes.SPACE],
+ "not_room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of not_room_types
+ # (we should find all of the rooms)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, foo_room_id, space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_server_left_room(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` against a room that everyone has left.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Leave the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Leave the room
+ self.helper.leave(space_room_id, user1_id, tok=user1_tok)
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filter_room_types_server_left_room2(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` against a room that everyone has left.
+
+ There is still someone local who is invited to the rooms but that doesn't affect
+ whether the server is participating in the room (users need to be joined).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ _user2_tok = self.login(user2_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Invite user2
+ self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Invite user2
+ self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(space_room_id, user1_id, tok=user1_tok)
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_with_remote_invite_room_no_stripped_state(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` filter against a remote invite
+ room without any `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room without any `unsigned.invite_room_state`
+ _remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id, None
+ )
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ # `remote_invite_room_id` should not appear because we can't figure out what
+ # room type it is (no stripped state, `unsigned.invite_room_state`)
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ # `remote_invite_room_id` should not appear because we can't figure out what
+ # room type it is (no stripped state, `unsigned.invite_room_state`)
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_with_remote_invite_space(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` filter against a remote invite
+ to a space room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state` indicating
+ # that it is a space room
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ # Specify that it is a space room
+ EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
+ },
+ ),
+ ],
+ )
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is a space room
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear here because it is a space room
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_with_remote_invite_normal_room(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` filter against a remote invite
+ to a normal room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # but the create event does not specify a room type (normal room)
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ # No room type means this is a normal room
+ },
+ ),
+ ],
+ )
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear here because it is a normal room
+ # according to the stripped state (no room type)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is a normal room
+ # according to the stripped state (no room type)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def _add_tag_to_room(
+ self, *, room_id: str, user_id: str, access_token: str, tag_name: str
+ ) -> None:
+ channel = self.make_request(
+ method="PUT",
+ path=f"/user/{user_id}/rooms/{room_id}/tags/{tag_name}",
+ content={},
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ def test_filters_tags(self) -> None:
+ """
+ Test `filters.tags` for rooms with given tags
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with no tags
+ self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create some rooms with tags
+ foo_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ bar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Create a room with multiple tags
+ foobar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Add the "foo" tag to the foo room
+ self._add_tag_to_room(
+ room_id=foo_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ # Add the "bar" tag to the bar room
+ self._add_tag_to_room(
+ room_id=bar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+ # Add both "foo" and "bar" tags to the foobar room
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+
+ # Try finding rooms with the "foo" tag
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["foo"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id, foobar_room_id},
+ exact=True,
+ )
+
+ # Try finding rooms with either "foo" or "bar" tags
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["foo", "bar"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id, bar_room_id, foobar_room_id},
+ exact=True,
+ )
+
+ # Try with a random tag we didn't add
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["flomp"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # No rooms should match
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of tags
+ # (we should find nothing)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ def test_filters_not_tags(self) -> None:
+ """
+ Test `filters.not_tags` for excluding rooms with given tags
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with no tags
+ untagged_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create some rooms with tags
+ foo_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ bar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Create a room with multiple tags
+ foobar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Add the "foo" tag to the foo room
+ self._add_tag_to_room(
+ room_id=foo_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ # Add the "bar" tag to the bar room
+ self._add_tag_to_room(
+ room_id=bar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+ # Add both "foo" and "bar" tags to the foobar room
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+
+ # Try finding rooms without the "foo" tag
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_tags": ["foo"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {untagged_room_id, bar_room_id},
+ exact=True,
+ )
+
+ # Try finding rooms without either "foo" or "bar" tags
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_tags": ["foo", "bar"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {untagged_room_id},
+ exact=True,
+ )
+
+ # Test how it behaves when we have both `tags` and `not_tags`.
+ # `not_tags` should win.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["foo"],
+ "not_tags": ["foo"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # Nothing matches because nothing is both tagged with "foo" and not tagged with "foo"
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of not_tags
+ # (we should find all of the rooms)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_tags": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {untagged_room_id, foo_room_id, bar_room_id, foobar_room_id},
+ exact=True,
+ )
diff --git a/tests/rest/client/sliding_sync/test_room_subscriptions.py b/tests/rest/client/sliding_sync/test_room_subscriptions.py
index cc17b0b354..285fdaaf78 100644
--- a/tests/rest/client/sliding_sync/test_room_subscriptions.py
+++ b/tests/rest/client/sliding_sync/test_room_subscriptions.py
@@ -14,6 +14,8 @@
import logging
from http import HTTPStatus
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -27,6 +29,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase):
"""
Test `room_subscriptions` in the Sliding Sync API.
@@ -43,6 +59,8 @@ class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_room_subscriptions_with_join_membership(self) -> None:
"""
Test `room_subscriptions` with a joined room should give us timeline and current
diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py
index f08ffaf674..882762ca29 100644
--- a/tests/rest/client/sliding_sync/test_rooms_invites.py
+++ b/tests/rest/client/sliding_sync/test_rooms_invites.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -27,6 +29,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase):
"""
Test to make sure the `rooms` response looks good for invites in the Sliding Sync API.
@@ -49,6 +65,8 @@ class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_rooms_invite_shared_history_initial_sync(self) -> None:
"""
Test that `rooms` we are invited to have some stripped `invite_state` during an
diff --git a/tests/rest/client/sliding_sync/test_rooms_meta.py b/tests/rest/client/sliding_sync/test_rooms_meta.py
index 04f11c0524..0a8b2c02c2 100644
--- a/tests/rest/client/sliding_sync/test_rooms_meta.py
+++ b/tests/rest/client/sliding_sync/test_rooms_meta.py
@@ -13,10 +13,12 @@
#
import logging
+from parameterized import parameterized, parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
@@ -28,6 +30,20 @@ from tests.test_utils.event_injection import create_event
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
"""
Test rooms meta info like name, avatar, joined_count, invited_count, is_dm,
@@ -44,11 +60,18 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ self.state_handler = self.hs.get_state_handler()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
+
+ super().prepare(reactor, clock, hs)
- def test_rooms_meta_when_joined(self) -> None:
+ def test_rooms_meta_when_joined_initial(self) -> None:
"""
- Test that the `rooms` `name` and `avatar` are included in the response and
- reflect the current state of the room when the user is joined to the room.
+ Test that the `rooms` `name` and `avatar` are included in the initial sync
+ response and reflect the current state of the room when the user is joined to
+ the room.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -85,6 +108,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Reflect the current state of the room
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super room",
@@ -107,6 +131,178 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body["rooms"][room_id1].get("is_dm"),
)
+ def test_rooms_meta_when_joined_incremental_no_change(self) -> None:
+ """
+ Test that the `rooms` `name` and `avatar` aren't included in an incremental sync
+ response if they haven't changed.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+ # Set the room avatar URL
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomAvatar,
+ {"url": "mxc://DUMMY_MEDIA_ID"},
+ tok=user2_tok,
+ )
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ # This needs to be set to one so the `RoomResult` isn't empty and
+ # the room comes down incremental sync when we send a new message.
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send a message to make the room come down sync
+ self.helper.send(room_id1, "message in room1", tok=user2_tok)
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # We should only see changed meta info (nothing changed so we shouldn't see any
+ # of these fields)
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "name",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "avatar",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("is_dm"),
+ )
+
+ @parameterized.expand(
+ [
+ ("in_required_state", True),
+ ("not_in_required_state", False),
+ ]
+ )
+ def test_rooms_meta_when_joined_incremental_with_state_change(
+ self, test_description: str, include_changed_state_in_required_state: bool
+ ) -> None:
+ """
+ Test that the `rooms` `name` and `avatar` are included in an incremental sync
+ response if they changed.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+ # Set the room avatar URL
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomAvatar,
+ {"url": "mxc://DUMMY_MEDIA_ID"},
+ tok=user2_tok,
+ )
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": (
+ [[EventTypes.Name, ""], [EventTypes.RoomAvatar, ""]]
+ # Conditionally include the changed state in the
+ # `required_state` to make sure whether we request it or not,
+ # the new room name still flows down to the client.
+ if include_changed_state_in_required_state
+ else []
+ ),
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Update the room name
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper room"},
+ tok=user2_tok,
+ )
+ # Update the room avatar URL
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomAvatar,
+ {"url": "mxc://DUMMY_MEDIA_ID_UPDATED"},
+ tok=user2_tok,
+ )
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # We should only see changed meta info (the room name and avatar)
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id1],
+ )
+ self.assertEqual(
+ response_body["rooms"][room_id1]["name"],
+ "my super duper room",
+ response_body["rooms"][room_id1],
+ )
+ self.assertEqual(
+ response_body["rooms"][room_id1]["avatar"],
+ "mxc://DUMMY_MEDIA_ID_UPDATED",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("is_dm"),
+ )
+
def test_rooms_meta_when_invited(self) -> None:
"""
Test that the `rooms` `name` and `avatar` are included in the response and
@@ -164,6 +360,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
# This should still reflect the current state of the room even when the user is
# invited.
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super duper room",
@@ -174,14 +371,17 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
"mxc://UPDATED_DUMMY_MEDIA_ID",
response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- 1,
+
+ # We don't give extra room information to invitees
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 1,
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
)
+
self.assertIsNone(
response_body["rooms"][room_id1].get("is_dm"),
)
@@ -242,6 +442,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Reflect the state of the room at the time of leaving
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super room",
@@ -252,15 +453,16 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
"mxc://DUMMY_MEDIA_ID",
response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- # FIXME: The actual number should be "1" (user2) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
+
+ # FIXME: We possibly want to return joined and invited counts for rooms
+ # you're banned from
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 0,
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
)
self.assertIsNone(
response_body["rooms"][room_id1].get("is_dm"),
@@ -316,6 +518,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
# Room1 has a name so we shouldn't see any `heroes` which the client would use
# the calculate the room name themselves.
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super room",
@@ -332,6 +535,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
)
# Room2 doesn't have a name so we should see `heroes` populated
+ self.assertEqual(response_body["rooms"][room_id2]["initial"], True)
self.assertIsNone(response_body["rooms"][room_id2].get("name"))
self.assertCountEqual(
[
@@ -403,6 +607,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Room2 doesn't have a name so we should see `heroes` populated
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertIsNone(response_body["rooms"][room_id1].get("name"))
self.assertCountEqual(
[
@@ -475,7 +680,8 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # Room2 doesn't have a name so we should see `heroes` populated
+ # Room doesn't have a name so we should see `heroes` populated
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertIsNone(response_body["rooms"][room_id1].get("name"))
self.assertCountEqual(
[
@@ -490,20 +696,175 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
[],
)
+ # FIXME: We possibly want to return joined and invited counts for rooms
+    # you're banned from
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
+ )
+
+ def test_rooms_meta_heroes_incremental_sync_no_change(self) -> None:
+ """
+ Test that the `rooms` `heroes` aren't included in an incremental sync
+ response if they haven't changed.
+
+ (when the room doesn't have a room name set)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+
+ room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room2",
+ },
+ )
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ # This needs to be set to one so the `RoomResult` isn't empty and
+ # the room comes down incremental sync when we send a new message.
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send a message to make the room come down sync
+ self.helper.send(room_id, "message in room", tok=user2_tok)
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # This is an incremental sync and the second time we have seen this room so it
+ # isn't `initial`
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id],
+ )
+ # Room shouldn't have a room name because we're testing the `heroes` field which
+        # will only have a chance to appear if the room doesn't have a name.
+ self.assertNotIn(
+ "name",
+ response_body["rooms"][room_id],
+ )
+ # No change to heroes
+ self.assertNotIn(
+ "heroes",
+ response_body["rooms"][room_id],
+ )
+ # No change to member counts
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id],
+ )
+ # We didn't request any state so we shouldn't see any `required_state`
+ self.assertNotIn(
+ "required_state",
+ response_body["rooms"][room_id],
+ )
+
+ def test_rooms_meta_heroes_incremental_sync_with_membership_change(self) -> None:
+ """
+ Test that the `rooms` `heroes` are included in an incremental sync response if
+ the membership has changed.
+
+ (when the room doesn't have a room name set)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room2",
+ },
+ )
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # User3 joins (membership change)
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # This is an incremental sync and the second time we have seen this room so it
+ # isn't `initial`
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id],
+ )
+ # Room shouldn't have a room name because we're testing the `heroes` field which
+        # will only have a chance to appear if the room doesn't have a name.
+ self.assertNotIn(
+ "name",
+ response_body["rooms"][room_id],
+ )
+ # Membership change so we should see heroes and membership counts
+ self.assertCountEqual(
+ [
+ hero["user_id"]
+ for hero in response_body["rooms"][room_id].get("heroes", [])
+ ],
+ # Heroes shouldn't include the user themselves (we shouldn't see user1)
+ [user2_id, user3_id],
+ )
self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- # FIXME: The actual number should be "1" (user2) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
+ response_body["rooms"][room_id]["joined_count"],
+ 3,
)
self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- # We shouldn't see user5 since they were invited after user1 was banned.
- #
- # FIXME: The actual number should be "1" (user3) but we currently don't
- # support this for rooms where the user has left/been banned.
+ response_body["rooms"][room_id]["invited_count"],
0,
)
+ # We didn't request any state so we shouldn't see any `required_state`
+ self.assertNotIn(
+ "required_state",
+ response_body["rooms"][room_id],
+ )
def test_rooms_bump_stamp(self) -> None:
"""
@@ -566,19 +927,17 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
)
# Make sure the list includes the rooms in the right order
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- # room1 sorts before room2 because it has the latest event (the
- # reaction)
- "room_ids": [room_id1, room_id2],
- }
- ],
+ self.assertEqual(
+ len(response_body["lists"]["foo-list"]["ops"]),
+ 1,
response_body["lists"]["foo-list"],
)
+ op = response_body["lists"]["foo-list"]["ops"][0]
+ self.assertEqual(op["op"], "SYNC")
+ self.assertEqual(op["range"], [0, 1])
+ # Note that we don't sort the rooms when the range includes all of the rooms, so
+ # we just assert that the rooms are included
+ self.assertIncludes(set(op["room_ids"]), {room_id1, room_id2}, exact=True)
# The `bump_stamp` for room1 should point at the latest message (not the
# reaction since it's not one of the `DEFAULT_BUMP_EVENT_TYPES`)
@@ -600,16 +959,16 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
Test that `bump_stamp` ignores backfilled events, i.e. events with a
negative stream ordering.
"""
-
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create a remote room
creator = "@user:other"
room_id = "!foo:other"
+ room_version = RoomVersions.V10
shared_kwargs = {
"room_id": room_id,
- "room_version": "10",
+ "room_version": room_version.identifier,
}
create_tuple = self.get_success(
@@ -618,6 +977,12 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
prev_event_ids=[],
type=EventTypes.Create,
state_key="",
+ content={
+ # The `ROOM_CREATOR` field could be removed if we used a room
+ # version > 10 (in favor of relying on `sender`)
+ EventContentFields.ROOM_CREATOR: creator,
+ EventContentFields.ROOM_VERSION: room_version.identifier,
+ },
sender=creator,
**shared_kwargs,
)
@@ -667,22 +1032,29 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
]
# Ensure the local HS knows the room version
- self.get_success(
- self.store.store_room(room_id, creator, False, RoomVersions.V10)
- )
+ self.get_success(self.store.store_room(room_id, creator, False, room_version))
# Persist these events as backfilled events.
- persistence = self.hs.get_storage_controllers().persistence
- assert persistence is not None
-
for event, context in remote_events_and_contexts:
- self.get_success(persistence.persist_event(event, context, backfilled=True))
+ self.get_success(
+ self.persistence.persist_event(event, context, backfilled=True)
+ )
- # Now we join the local user to the room
- join_tuple = self.get_success(
+ # Now we join the local user to the room. We want to make this feel as close to
+ # the real `process_remote_join()` as possible but we'd like to avoid some of
+ # the auth checks that would be done in the real code.
+ #
+ # FIXME: The test was originally written using this less-real
+ # `persist_event(...)` shortcut but it would be nice to use the real remote join
+ # process in a `FederatingHomeserverTestCase`.
+ flawed_join_tuple = self.get_success(
create_event(
self.hs,
prev_event_ids=[invite_tuple[0].event_id],
+ # This doesn't work correctly to create an `EventContext` that includes
+ # both of these state events. I assume it's because we're working on our
+ # local homeserver which has the remote state set as `outlier`. We have
+ # to create our own EventContext below to get this right.
auth_event_ids=[create_tuple[0].event_id, invite_tuple[0].event_id],
type=EventTypes.Member,
state_key=user1_id,
@@ -691,7 +1063,22 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
**shared_kwargs,
)
)
- self.get_success(persistence.persist_event(*join_tuple))
+ # We have to create our own context to get the state set correctly. If we use
+ # the `EventContext` from the `flawed_join_tuple`, the `current_state_events`
+ # table will only have the join event in it which should never happen in our
+ # real server.
+ join_event = flawed_join_tuple[0]
+ join_context = self.get_success(
+ self.state_handler.compute_event_context(
+ join_event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id
+ for e in [create_tuple[0], invite_tuple[0]]
+ },
+ partial_state=False,
+ )
+ )
+ self.get_success(self.persistence.persist_event(join_event, join_context))
# Doing an SS request should return a positive `bump_stamp`, even though
# the only event that matches the bump types has as negative stream
@@ -708,3 +1095,244 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
self.assertGreater(response_body["rooms"][room_id]["bump_stamp"], 0)
+
+ def test_rooms_bump_stamp_no_change_incremental(self) -> None:
+ """Test that the bump stamp is omitted if there has been no change"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ )
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 100,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Initial sync so we expect to see a bump stamp
+ self.assertIn("bump_stamp", response_body["rooms"][room_id1])
+
+ # Send an event that is not in the bump events list
+ self.helper.send_event(
+ room_id1, type="org.matrix.test", content={}, tok=user1_tok
+ )
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # There hasn't been a change to the bump stamps, so we ignore it
+ self.assertNotIn("bump_stamp", response_body["rooms"][room_id1])
+
+ def test_rooms_bump_stamp_change_incremental(self) -> None:
+ """Test that the bump stamp is included if there has been a change, even
+        if it's not in the timeline"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ )
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 2,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Initial sync so we expect to see a bump stamp
+ self.assertIn("bump_stamp", response_body["rooms"][room_id1])
+ first_bump_stamp = response_body["rooms"][room_id1]["bump_stamp"]
+
+ # Send a bump event at the start.
+ self.helper.send(room_id1, "test", tok=user1_tok)
+
+ # Send events that are not in the bump events list to fill the timeline
+ for _ in range(5):
+ self.helper.send_event(
+ room_id1, type="org.matrix.test", content={}, tok=user1_tok
+ )
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # There was a bump event in the timeline gap, so we should see the bump
+ # stamp be updated.
+ self.assertIn("bump_stamp", response_body["rooms"][room_id1])
+ second_bump_stamp = response_body["rooms"][room_id1]["bump_stamp"]
+
+ self.assertGreater(second_bump_stamp, first_bump_stamp)
+
+ def test_rooms_bump_stamp_invites(self) -> None:
+ """
+ Test that `bump_stamp` is present and points to the membership event,
+ and not later events, for non-joined rooms
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ )
+
+ # Invite user1 to the room
+ invite_response = self.helper.invite(room_id, user2_id, user1_id, tok=user2_tok)
+
+ # More messages happen after the invite
+ self.helper.send(room_id, "message in room1", tok=user2_tok)
+
+ # We expect the bump_stamp to match the invite.
+ invite_pos = self.get_success(
+ self.store.get_position_for_event(invite_response["event_id"])
+ )
+
+ # Doing an SS request should return a `bump_stamp` of the invite event,
+ # rather than the message that was sent after.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ self.assertEqual(
+ response_body["rooms"][room_id]["bump_stamp"], invite_pos.stream
+ )
+
+ def test_rooms_meta_is_dm(self) -> None:
+ """
+ Test `rooms` `is_dm` is correctly set for DM rooms.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a DM room
+ joined_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=True,
+ )
+ invited_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=False,
+ )
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # Ensure DM's are correctly marked
+ self.assertDictEqual(
+ {
+ room_id: room.get("is_dm")
+ for room_id, room in response_body["rooms"].items()
+ },
+ {
+ invite_room_id: None,
+ room_id: None,
+ invited_dm_room_id: True,
+ joined_dm_room_id: True,
+ },
+ )
+
+ def test_old_room_with_unknown_room_version(self) -> None:
+ """Test that an old room with unknown room version does not break
+ sync."""
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # We first create a standard room, then we'll change the room version in
+ # the DB.
+ room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ )
+
+ # Poke the database and update the room version to an unknown one.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+ # Invalidate method so that it returns the currently updated version
+ # instead of the cached version.
+ self.hs.get_datastores().main.get_room_version_id.invalidate((room_id,))
+
+ # For old unknown room versions we won't have an entry in this table
+ # (due to us skipping unknown room versions in the background update).
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ desc="delete_sliding_room",
+ )
+ )
+
+ # Also invalidate some caches to ensure we pull things from the DB.
+ self.store._events_stream_cache._entity_to_key.pop(room_id)
+ self.store._get_max_event_pos.invalidate((room_id,))
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py
index a13cad223f..ba46c5a93c 100644
--- a/tests/rest/client/sliding_sync/test_rooms_required_state.py
+++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py
@@ -11,16 +11,17 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
+import enum
import logging
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, JoinRules, Membership
from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import login, room, sync
+from synapse.rest.client import knock, login, room, sync
from synapse.server import HomeServer
from synapse.util import Clock
@@ -30,6 +31,31 @@ from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__)
+# Inherit from `str` so that they show up in the test description when we
+# `@parameterized.expand(...)` the first parameter
+class MembershipAction(str, enum.Enum):
+ INVITE = "invite"
+ JOIN = "join"
+ KNOCK = "knock"
+ LEAVE = "leave"
+ BAN = "ban"
+ KICK = "kick"
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
"""
Test `rooms.required_state` in the Sliding Sync API.
@@ -38,6 +64,7 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ knock.register_servlets,
room.register_servlets,
sync.register_servlets,
]
@@ -46,6 +73,8 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_rooms_no_required_state(self) -> None:
"""
Empty `rooms.required_state` should not return any state events in the room
@@ -191,8 +220,14 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
}
_, from_token = self.do_sync(sync_body, tok=user1_tok)
- # Reset the in-memory cache
- self.hs.get_sliding_sync_handler().connection_store._connections.clear()
+ # Reset the positions
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_connections",
+ keyvalues={"user_id": user1_id},
+ desc="clear_sliding_sync_connections_cache",
+ )
+ )
# Make the Sliding Sync request
channel = self.make_request(
@@ -359,10 +394,10 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
)
self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
- def test_rooms_required_state_lazy_loading_room_members(self) -> None:
+ def test_rooms_required_state_lazy_loading_room_members_initial_sync(self) -> None:
"""
- Test `rooms.required_state` returns people relevant to the timeline when
- lazy-loading room members, `["m.room.member","$LAZY"]`.
+ On initial sync, test `rooms.required_state` returns people relevant to the
+ timeline when lazy-loading room members, `["m.room.member","$LAZY"]`.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -410,6 +445,402 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
)
self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+ def test_rooms_required_state_lazy_loading_room_members_incremental_sync(
+ self,
+ ) -> None:
+ """
+ On incremental sync, test `rooms.required_state` returns people relevant to the
+ timeline when lazy-loading room members, `["m.room.member","$LAZY"]`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request with lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ self.helper.send(room_id1, "6", tok=user4_tok)
+
+ # Make an incremental Sliding Sync request
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # but since we've seen user2 in the last sync (and their membership hasn't
+ # changed), we should only see user4 here.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user4_id)],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ @parameterized.expand(
+ [
+ (MembershipAction.LEAVE,),
+ (MembershipAction.INVITE,),
+ (MembershipAction.KNOCK,),
+ (MembershipAction.JOIN,),
+ (MembershipAction.BAN,),
+ (MembershipAction.KICK,),
+ ]
+ )
+ def test_rooms_required_state_changed_membership_in_timeline_lazy_loading_room_members_incremental_sync(
+ self,
+ room_membership_action: str,
+ ) -> None:
+ """
+ On incremental sync, test `rooms.required_state` returns people relevant to the
+ timeline when lazy-loading room members, `["m.room.member","$LAZY"]` **including
+ changes to membership**.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ # If we're testing knocks, set the room to knock
+ if room_membership_action == MembershipAction.KNOCK:
+ self.helper.send_state(
+ room_id1,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=user2_tok,
+ )
+
+ # Join the test users to the room
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user4_id, tok=user2_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+ if room_membership_action in (
+ MembershipAction.LEAVE,
+ MembershipAction.BAN,
+ MembershipAction.JOIN,
+ ):
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ self.helper.join(room_id1, user5_id, tok=user5_tok)
+
+ # Send some messages to fill up the space
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request with lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ # The third event will be our membership event concerning user5
+ if room_membership_action == MembershipAction.LEAVE:
+ # User 5 leaves
+ self.helper.leave(room_id1, user5_id, tok=user5_tok)
+ elif room_membership_action == MembershipAction.INVITE:
+ # User 5 is invited
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ elif room_membership_action == MembershipAction.KNOCK:
+ # User 5 knocks
+ self.helper.knock(room_id1, user5_id, tok=user5_tok)
+ # The admin of the room accepts the knock
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ elif room_membership_action == MembershipAction.JOIN:
+ # Update the display name of user5 (causing a membership change)
+ self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user5_id,
+ body={
+ EventContentFields.MEMBERSHIP: Membership.JOIN,
+ EventContentFields.MEMBERSHIP_DISPLAYNAME: "quick changer",
+ },
+ tok=user5_tok,
+ )
+ elif room_membership_action == MembershipAction.BAN:
+ self.helper.ban(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ elif room_membership_action == MembershipAction.KICK:
+ # Kick user5 from the room
+ self.helper.change_membership(
+ room=room_id1,
+ src=user2_id,
+ targ=user5_id,
+ tok=user2_tok,
+ membership=Membership.LEAVE,
+ extra_data={
+ "reason": "Bad manners",
+ },
+ )
+ else:
+ raise AssertionError(
+ f"Unknown room_membership_action: {room_membership_action}"
+ )
+
+ # Make an incremental Sliding Sync request
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2, user4, and user5 sent events in the last 3 events we see in the
+ # `timeline`.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ # This appears because *some* membership in the room changed and the
+ # heroes are recalculated and is thrown in because we have it. But this
+ # is technically optional and not needed because we've already seen user2
+ # in the last sync (and their membership hasn't changed).
+ state_map[(EventTypes.Member, user2_id)],
+ # Appears because there is a message in the timeline from this user
+ state_map[(EventTypes.Member, user4_id)],
+ # Appears because there is a membership event in the timeline from this user
+ state_map[(EventTypes.Member, user5_id)],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ def test_rooms_required_state_expand_lazy_loading_room_members_incremental_sync(
+ self,
+ ) -> None:
+ """
+ Test that when we expand the `required_state` to include lazy-loading room
+ members, it returns people relevant to the timeline.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request *without* lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ self.helper.send(room_id1, "6", tok=user4_tok)
+
+ # Expand `required_state` and make an incremental Sliding Sync request *with*
+ # lazy-loading room members
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # and we haven't seen any membership before this sync so we should see both
+ # users.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[(EventTypes.Member, user4_id)],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "7", tok=user2_tok)
+ self.helper.send(room_id1, "8", tok=user4_tok)
+ self.helper.send(room_id1, "9", tok=user4_tok)
+
+ # Make another incremental Sliding Sync request
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # but since we've seen both memberships in the last sync, they shouldn't appear
+ # again.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1].get("required_state", []),
+ set(),
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ def test_rooms_required_state_expand_retract_expand_lazy_loading_room_members_incremental_sync(
+ self,
+ ) -> None:
+ """
+ Test that when we expand the `required_state` to include lazy-loading room
+ members, it returns people relevant to the timeline.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request *without* lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ self.helper.send(room_id1, "6", tok=user4_tok)
+
+ # Expand `required_state` and make an incremental Sliding Sync request *with*
+ # lazy-loading room members
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # and we haven't seen any membership before this sync so we should see both
+ # users because we're lazy-loading the room members.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[(EventTypes.Member, user4_id)],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user4_tok)
+
+ # Retract `required_state` and make an incremental Sliding Sync request
+ # requesting a few memberships
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.ME],
+ [EventTypes.Member, user2_id],
+ ]
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # We've seen user2's membership in the last sync so we shouldn't see it here
+ # even though it's requested. We should only see user1's membership.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user1_id)],
+ },
+ exact=True,
+ )
+
def test_rooms_required_state_me(self) -> None:
"""
Test `rooms.required_state` correctly handles $ME.
@@ -480,9 +911,10 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
@parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
- def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None:
+ def test_rooms_required_state_leave_ban_initial(self, stop_membership: str) -> None:
"""
- Test `rooms.required_state` should not return state past a leave/ban event.
+ Test `rooms.required_state` should not return state past a leave/ban event when
+ it's the first "initial" time the room is being sent down the connection.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -517,6 +949,13 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
body={"foo": "bar"},
tok=user2_tok,
)
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "bar"},
+ tok=user2_tok,
+ )
if stop_membership == Membership.LEAVE:
# User 1 leaves
@@ -525,6 +964,8 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
# User 1 is banned
self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ # Get the state_map before we change the state as this is the final state we
+ # expect User1 to be able to see
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
@@ -537,12 +978,36 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
body={"foo": "qux"},
tok=user2_tok,
)
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "qux"},
+ tok=user2_tok,
+ )
self.helper.leave(room_id1, user3_id, tok=user3_tok)
- # Make the Sliding Sync request with lazy loading for the room members
+ # Make an incremental Sliding Sync request
+ #
+ # Also expand the required state to include the `org.matrix.bar_state` event.
+ # This is just an extra complication of the test.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ["org.matrix.bar_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
- # Only user2 and user3 sent events in the 3 events we see in the `timeline`
+ # We should only see the state up to the leave/ban event
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
@@ -551,6 +1016,126 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
state_map[(EventTypes.Member, user2_id)],
state_map[(EventTypes.Member, user3_id)],
state_map[("org.matrix.foo_state", "")],
+ state_map[("org.matrix.bar_state", "")],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ @parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
+ def test_rooms_required_state_leave_ban_incremental(
+ self, stop_membership: str
+ ) -> None:
+ """
+ Test `rooms.required_state` should not return state past a leave/ban event on
+ incremental sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "bar"},
+ tok=user2_tok,
+ )
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ _, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ if stop_membership == Membership.LEAVE:
+ # User 1 leaves
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ elif stop_membership == Membership.BAN:
+ # User 1 is banned
+ self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Get the state_map before we change the state as this is the final state we
+ # expect User1 to be able to see
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Change the state after user 1 leaves
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "qux"},
+ tok=user2_tok,
+ )
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "qux"},
+ tok=user2_tok,
+ )
+ self.helper.leave(room_id1, user3_id, tok=user3_tok)
+
+ # Make an incremental Sliding Sync request
+ #
+ # Also expand the required state to include the `org.matrix.bar_state` event.
+ # This is just an extra complication of the test.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ["org.matrix.bar_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # User1 should only see the state up to the leave/ban event
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ # User1 should see their leave/ban membership
+ state_map[(EventTypes.Member, user1_id)],
+ state_map[("org.matrix.bar_state", "")],
+ # The commented out state events were already returned in the initial
+ # sync so we shouldn't see them again on the incremental sync. And we
+ # shouldn't see the state events that changed after the leave/ban event.
+ #
+ # state_map[(EventTypes.Create, "")],
+ # state_map[(EventTypes.Member, user2_id)],
+ # state_map[(EventTypes.Member, user3_id)],
+ # state_map[("org.matrix.foo_state", "")],
},
exact=True,
)
@@ -631,8 +1216,7 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
def test_rooms_required_state_partial_state(self) -> None:
"""
- Test partially-stated room are excluded unless `rooms.required_state` is
- lazy-loading room members.
+ Test partially-stated rooms are excluded if they require full state.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -649,13 +1233,63 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
mark_event_as_partial_state(self.hs, join_response2["event_id"], room_id2)
)
- # Make the Sliding Sync request (NOT lazy-loading room members)
+ # Make the Sliding Sync request with examples where `must_await_full_state()` is
+ # `False`
sync_body = {
"lists": {
- "foo-list": {
+ "no-state-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ "other-state-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 0,
+ },
+ "lazy-load-list": {
"ranges": [[0, 1]],
"required_state": [
[EventTypes.Create, ""],
+ # Lazy-load room members
+ [EventTypes.Member, StateValues.LAZY],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "local-members-only-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Own user ID
+ [EventTypes.Member, user1_id],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "me-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Own user ID
+ [EventTypes.Member, StateValues.ME],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "wildcard-type-local-state-key-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["*", user1_id],
+ # Not a user ID
+ ["*", "foobarbaz"],
+ # Not a user ID
+ ["*", "foo.bar.baz"],
+ # Not a user ID
+ ["*", "@foo"],
],
"timeline_limit": 0,
},
@@ -663,29 +1297,89 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # Make sure the list includes room1 but room2 is excluded because it's still
- # partially-stated
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id1],
+ # The list should include both rooms now because we don't need full state
+ for list_key in response_body["lists"].keys():
+ self.assertIncludes(
+ set(response_body["lists"][list_key]["ops"][0]["room_ids"]),
+ {room_id2, room_id1},
+ exact=True,
+ message=f"Expected all rooms to show up for list_key={list_key}. Response "
+ + str(response_body["lists"][list_key]),
+ )
+
+ # Take each of the list variants and apply them to room subscriptions to make
+ # sure the same rules apply
+ for list_key in sync_body["lists"].keys():
+ sync_body_for_subscriptions = {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
+ room_id2: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
}
- ],
- response_body["lists"]["foo-list"],
- )
+ }
+ response_body, _ = self.do_sync(sync_body_for_subscriptions, tok=user1_tok)
+
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {room_id2, room_id1},
+ exact=True,
+ message=f"Expected all rooms to show up for test_key={list_key}.",
+ )
- # Make the Sliding Sync request (with lazy-loading room members)
+ # =====================================================================
+
+ # Make the Sliding Sync request with examples where `must_await_full_state()` is
+ # `True`
sync_body = {
"lists": {
- "foo-list": {
+ "wildcard-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["*", "*"],
+ ],
+ "timeline_limit": 0,
+ },
+ "wildcard-type-remote-state-key-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["*", "@some:remote"],
+ # Not a user ID
+ ["*", "foobarbaz"],
+ # Not a user ID
+ ["*", "foo.bar.baz"],
+ # Not a user ID
+ ["*", "@foo"],
+ ],
+ "timeline_limit": 0,
+ },
+ "remote-member-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Own user ID
+ [EventTypes.Member, user1_id],
+ # Remote member
+ [EventTypes.Member, "@some:remote"],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "lazy-but-remote-member-list": {
"ranges": [[0, 1]],
"required_state": [
- [EventTypes.Create, ""],
# Lazy-load room members
[EventTypes.Member, StateValues.LAZY],
+ # Remote member
+ [EventTypes.Member, "@some:remote"],
],
"timeline_limit": 0,
},
@@ -693,15 +1387,302 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # The list should include both rooms now because we're lazy-loading room members
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id2, room_id1],
+ # Make sure the list includes room1 but room2 is excluded because it's still
+ # partially-stated
+ for list_key in response_body["lists"].keys():
+ self.assertIncludes(
+ set(response_body["lists"][list_key]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
+ message=f"Expected only fully-stated rooms to show up for list_key={list_key}. Response "
+ + str(response_body["lists"][list_key]),
+ )
+
+ # Take each of the list variants and apply them to room subscriptions to make
+ # sure the same rules apply
+ for list_key in sync_body["lists"].keys():
+ sync_body_for_subscriptions = {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
+ room_id2: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body_for_subscriptions, tok=user1_tok)
+
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {room_id1},
+ exact=True,
+ message=f"Expected only fully-stated rooms to show up for test_key={list_key}.",
+ )
+
+ def test_rooms_required_state_expand(self) -> None:
+ """Test that when we expand the required state argument we get the
+ expanded state, and not just the changes to the newly expanded state."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with a room name.
+ room_id1 = self.helper.create_room_as(
+ user1_id, tok=user1_tok, extra_content={"name": "Foo"}
+ )
+
+ # Only request the state event to begin with
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
}
- ],
- response_body["lists"]["foo-list"],
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
)
+
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the room name, even though there haven't been any
+ # changes.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # We should not see any state changes.
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
+
+ def test_rooms_required_state_expand_retract_expand(self) -> None:
+ """Test that when expanding, retracting and then expanding the required
+ state, we get the changes that happened."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with a room name.
+ room_id1 = self.helper.create_room_as(
+ user1_id, tok=user1_tok, extra_content={"name": "Foo"}
+ )
+
+ # Only request the state event to begin with
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the room name, even though there haven't been any
+ # changes.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ # Update the room name
+ self.helper.send_state(
+ room_id1, EventTypes.Name, {"name": "Bar"}, state_key="", tok=user1_tok
+ )
+
+ # Update the sliding sync requests to exclude the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should not see the updated room name in state (though it will be in
+ # the timeline).
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the *new* room name, even though there haven't been any
+ # changes.
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_expand_deduplicate(self) -> None:
+ """Test that when expanding, retracting and then expanding the required
+ state, we don't get the state down again if it hasn't changed"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with a room name.
+ room_id1 = self.helper.create_room_as(
+ user1_id, tok=user1_tok, extra_content={"name": "Foo"}
+ )
+
+ # Only request the state event to begin with
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the room name, even though there haven't been any
+ # changes.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to exclude the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should not see any state updates
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should not see the room name again, as we have already sent that
+ # down.
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
diff --git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py
index 2e9586ca73..535420209b 100644
--- a/tests/rest/client/sliding_sync/test_rooms_timeline.py
+++ b/tests/rest/client/sliding_sync/test_rooms_timeline.py
@@ -14,12 +14,15 @@
import logging
from typing import List, Optional
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
+from synapse.api.constants import EventTypes
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
-from synapse.types import StreamToken, StrSequence
+from synapse.types import StrSequence
from synapse.util import Clock
from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
@@ -27,6 +30,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
"""
Test `rooms.timeline` in the Sliding Sync API.
@@ -43,6 +60,8 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def _assertListEqual(
self,
actual_items: StrSequence,
@@ -130,16 +149,10 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
user2_tok = self.login(user2_id, "pass")
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity1", tok=user2_tok)
- self.helper.send(room_id1, "activity2", tok=user2_tok)
+ event_response1 = self.helper.send(room_id1, "activity1", tok=user2_tok)
+ event_response2 = self.helper.send(room_id1, "activity2", tok=user2_tok)
event_response3 = self.helper.send(room_id1, "activity3", tok=user2_tok)
- event_pos3 = self.get_success(
- self.store.get_position_for_event(event_response3["event_id"])
- )
event_response4 = self.helper.send(room_id1, "activity4", tok=user2_tok)
- event_pos4 = self.get_success(
- self.store.get_position_for_event(event_response4["event_id"])
- )
event_response5 = self.helper.send(room_id1, "activity5", tok=user2_tok)
user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
@@ -177,27 +190,23 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
)
# Check to make sure the `prev_batch` points at the right place
- prev_batch_token = self.get_success(
- StreamToken.from_string(
- self.store, response_body["rooms"][room_id1]["prev_batch"]
- )
- )
- prev_batch_room_stream_token_serialized = self.get_success(
- prev_batch_token.room_key.to_string(self.store)
+ prev_batch_token = response_body["rooms"][room_id1]["prev_batch"]
+
+ # If we use the `prev_batch` token to look backwards we should see
+ # `event3` and older next.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{room_id1}/messages?from={prev_batch_token}&dir=b&limit=3",
+ access_token=user1_tok,
)
- # If we use the `prev_batch` token to look backwards, we should see `event3`
- # next so make sure the token encompasses it
- self.assertEqual(
- event_pos3.persisted_after(prev_batch_token.room_key),
- False,
- f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be >= event_pos3={self.get_success(event_pos3.to_room_stream_token().to_string(self.store))}",
- )
- # If we use the `prev_batch` token to look backwards, we shouldn't see `event4`
- # anymore since it was just returned in this response.
- self.assertEqual(
- event_pos4.persisted_after(prev_batch_token.room_key),
- True,
- f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be < event_pos4={self.get_success(event_pos4.to_room_stream_token().to_string(self.store))}",
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertListEqual(
+ [
+ event_response3["event_id"],
+ event_response2["event_id"],
+ event_response1["event_id"],
+ ],
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
)
# With no `from_token` (initial sync), it's all historical since there is no
@@ -300,8 +309,8 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
self.assertEqual(
response_body["rooms"][room_id1]["limited"],
False,
- f'Our `timeline_limit` was {sync_body["lists"]["foo-list"]["timeline_limit"]} '
- + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
+ f"Our `timeline_limit` was {sync_body['lists']['foo-list']['timeline_limit']} "
+ + f"and {len(response_body['rooms'][room_id1]['timeline'])} events were returned in the timeline. "
+ str(response_body["rooms"][room_id1]),
)
# Check to make sure the latest events are returned
@@ -378,7 +387,7 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
response_body["rooms"][room_id1]["limited"],
True,
f"Our `timeline_limit` was {timeline_limit} "
- + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
+ + f"and {len(response_body['rooms'][room_id1]['timeline'])} events were returned in the timeline. "
+ str(response_body["rooms"][room_id1]),
)
# Check to make sure that the "live" and historical events are returned
@@ -573,3 +582,138 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
# Nothing to see for this banned user in the room in the token range
self.assertIsNone(response_body["rooms"].get(room_id1))
+
+ def test_increasing_timeline_range_sends_more_messages(self) -> None:
+ """
+ Test that increasing the timeline limit via room subscriptions sends the
+ room down with more messages in a limited sync.
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [[EventTypes.Create, ""]],
+ "timeline_limit": 1,
+ }
+ }
+ }
+
+ message_events = []
+ for _ in range(10):
+ resp = self.helper.send(room_id1, "msg", tok=user1_tok)
+ message_events.append(resp["event_id"])
+
+ # Make the first Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertEqual(room_response["initial"], True)
+ self.assertNotIn("unstable_expanded_timeline", room_response)
+ self.assertEqual(room_response["limited"], True)
+
+ # We only expect the last message at first
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=message_events[-1:],
+ message=str(room_response["timeline"]),
+ )
+
+ # We also expect to get the create event state.
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+ self._assertRequiredStateIncludes(
+ room_response["required_state"],
+ {state_map[(EventTypes.Create, "")]},
+ exact=True,
+ )
+
+ # Now do another request with a room subscription with an increased timeline limit
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 10,
+ }
+ }
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertNotIn("initial", room_response)
+ self.assertEqual(room_response["unstable_expanded_timeline"], True)
+ self.assertEqual(room_response["limited"], True)
+
+ # Now we expect all the messages
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=message_events,
+ message=str(room_response["timeline"]),
+ )
+
+ # We don't expect to get the room create down, as nothing has changed.
+ self.assertNotIn("required_state", room_response)
+
+ # Decreasing the timeline limit shouldn't resend any events
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ }
+
+ event_response = self.helper.send(room_id1, "msg", tok=user1_tok)
+ latest_event_id = event_response["event_id"]
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertNotIn("initial", room_response)
+ self.assertNotIn("unstable_expanded_timeline", room_response)
+ self.assertEqual(room_response["limited"], False)
+
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=[latest_event_id],
+ message=str(room_response["timeline"]),
+ )
+
+ # Increasing the limit to what it was before also should not resend any
+ # events
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 10,
+ }
+ }
+
+ event_response = self.helper.send(room_id1, "msg", tok=user1_tok)
+ latest_event_id = event_response["event_id"]
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertNotIn("initial", room_response)
+ self.assertNotIn("unstable_expanded_timeline", room_response)
+ self.assertEqual(room_response["limited"], False)
+
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=[latest_event_id],
+ message=str(room_response["timeline"]),
+ )
diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py
index cb7638c5ba..dcec5b4cf0 100644
--- a/tests/rest/client/sliding_sync/test_sliding_sync.py
+++ b/tests/rest/client/sliding_sync/test_sliding_sync.py
@@ -13,7 +13,9 @@
#
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
+from unittest.mock import AsyncMock
+from parameterized import parameterized, parameterized_class
from typing_extensions import assert_never
from twisted.test.proto_helpers import MemoryReactor
@@ -23,10 +25,15 @@ from synapse.api.constants import (
AccountDataTypes,
EventContentFields,
EventTypes,
+ JoinRules,
+ Membership,
RoomTypes,
)
-from synapse.events import EventBase
-from synapse.rest.client import devices, login, receipts, room, sync
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase, StrippedStateEvent, make_event_from_dict
+from synapse.events.snapshot import EventContext
+from synapse.handlers.sliding_sync import StateValues
+from synapse.rest.client import account_data, devices, login, receipts, room, sync
from synapse.server import HomeServer
from synapse.types import (
JsonDict,
@@ -40,6 +47,7 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.server import TimedOutException
+from tests.test_utils.event_injection import create_event
logger = logging.getLogger(__name__)
@@ -47,8 +55,25 @@ logger = logging.getLogger(__name__)
class SlidingSyncBase(unittest.HomeserverTestCase):
"""Base class for sliding sync test cases"""
+ # Flag as to whether to use the new sliding sync tables or not
+ #
+ # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ # foreground update for
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ # https://github.com/element-hq/synapse/issues/17623)
+ use_new_tables: bool = True
+
sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ # foreground update for
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ # https://github.com/element-hq/synapse/issues/17623)
+ hs.get_datastores().main.have_finished_sliding_sync_background_jobs = AsyncMock( # type: ignore[method-assign]
+ return_value=self.use_new_tables
+ )
+
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
@@ -122,6 +147,172 @@ class SlidingSyncBase(unittest.HomeserverTestCase):
message=str(actual_required_state),
)
+ def _add_new_dm_to_global_account_data(
+ self, source_user_id: str, target_user_id: str, target_room_id: str
+ ) -> None:
+ """
+ Helper to handle inserting a new DM for the source user into global account data
+ (handles all of the list merging).
+
+ Args:
+ source_user_id: The user ID of the DM mapping we're going to update
+ target_user_id: User ID of the person the DM is with
+ target_room_id: Room ID of the DM
+ """
+ store = self.hs.get_datastores().main
+
+ # Get the current DM map
+ existing_dm_map = self.get_success(
+ store.get_global_account_data_by_type_for_user(
+ source_user_id, AccountDataTypes.DIRECT
+ )
+ )
+ # Scrutinize the account data since it has no concrete type. We're just copying
+ # everything into a known type. It should be a mapping from user ID to a list of
+ # room IDs. Ignore anything else.
+ new_dm_map: Dict[str, List[str]] = {}
+ if isinstance(existing_dm_map, dict):
+ for user_id, room_ids in existing_dm_map.items():
+ if isinstance(user_id, str) and isinstance(room_ids, list):
+ for room_id in room_ids:
+ if isinstance(room_id, str):
+ new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
+ room_id
+ ]
+
+ # Add the new DM to the map
+ new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
+ target_room_id
+ ]
+ # Save the DM map to global account data
+ self.get_success(
+ store.add_account_data_for_user(
+ source_user_id,
+ AccountDataTypes.DIRECT,
+ new_dm_map,
+ )
+ )
+
+ def _create_dm_room(
+ self,
+ inviter_user_id: str,
+ inviter_tok: str,
+ invitee_user_id: str,
+ invitee_tok: str,
+ should_join_room: bool = True,
+ ) -> str:
+ """
+ Helper to create a DM room as the "inviter" and invite the "invitee" user to the
+ room. The "invitee" user also will join the room. The `m.direct` account data
+ will be set for both users.
+ """
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviter_user_id,
+ is_public=False,
+ tok=inviter_tok,
+ )
+ self.helper.invite(
+ room_id,
+ src=inviter_user_id,
+ targ=invitee_user_id,
+ tok=inviter_tok,
+ extra_data={"is_direct": True},
+ )
+ if should_join_room:
+ # Person that was invited joins the room
+ self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
+
+ # Mimic the client setting the room as a direct message in the global account
+ # data for both users.
+ self._add_new_dm_to_global_account_data(
+ invitee_user_id, inviter_user_id, room_id
+ )
+ self._add_new_dm_to_global_account_data(
+ inviter_user_id, invitee_user_id, room_id
+ )
+
+ return room_id
+
+ _remote_invite_count: int = 0
+
+ def _create_remote_invite_room_for_user(
+ self,
+ invitee_user_id: str,
+ unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
+ invite_room_id: Optional[str] = None,
+ ) -> str:
+ """
+ Create a fake invite for a remote room and persist it.
+
+ We don't have any state for these kinds of rooms and can only rely on the
+ stripped state included in the unsigned portion of the invite event to identify
+ the room.
+
+ Args:
+ invitee_user_id: The person being invited
+ unsigned_invite_room_state: List of stripped state events to assist the
+ receiver in identifying the room.
+ invite_room_id: Optional remote room ID to be invited to. When unset, we
+ will generate one.
+
+ Returns:
+ The room ID of the remote invite room
+ """
+ store = self.hs.get_datastores().main
+
+ if invite_room_id is None:
+ invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
+
+ invite_event_dict = {
+ "room_id": invite_room_id,
+ "sender": "@inviter:remote_server",
+ "state_key": invitee_user_id,
+ # Just keep advancing the depth
+ "depth": self._remote_invite_count,
+ "origin_server_ts": 1,
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.INVITE},
+ "auth_events": [],
+ "prev_events": [],
+ }
+ if unsigned_invite_room_state is not None:
+ serialized_stripped_state_events = []
+ for stripped_event in unsigned_invite_room_state:
+ serialized_stripped_state_events.append(
+ {
+ "type": stripped_event.type,
+ "state_key": stripped_event.state_key,
+ "sender": stripped_event.sender,
+ "content": stripped_event.content,
+ }
+ )
+
+ invite_event_dict["unsigned"] = {
+ "invite_room_state": serialized_stripped_state_events
+ }
+
+ invite_event = make_event_from_dict(
+ invite_event_dict,
+ room_version=RoomVersions.V10,
+ )
+ invite_event.internal_metadata.outlier = True
+ invite_event.internal_metadata.out_of_band_membership = True
+
+ self.get_success(
+ store.maybe_store_room_on_outlier_membership(
+ room_id=invite_room_id, room_version=invite_event.room_version
+ )
+ )
+ context = EventContext.for_outlier(self.hs.get_storage_controllers())
+ persist_controller = self.hs.get_storage_controllers().persistence
+ assert persist_controller is not None
+ self.get_success(persist_controller.persist_event(invite_event, context))
+
+ self._remote_invite_count += 1
+
+ return invite_room_id
+
def _bump_notifier_wait_for_events(
self,
user_id: str,
@@ -203,6 +394,20 @@ class SlidingSyncBase(unittest.HomeserverTestCase):
)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncTestCase(SlidingSyncBase):
"""
Tests regarding MSC3575 Sliding Sync `/sync` endpoint.
@@ -218,6 +423,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
sync.register_servlets,
devices.register_servlets,
receipts.register_servlets,
+ account_data.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -225,93 +431,11 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers()
self.account_data_handler = hs.get_account_data_handler()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
- def _add_new_dm_to_global_account_data(
- self, source_user_id: str, target_user_id: str, target_room_id: str
- ) -> None:
- """
- Helper to handle inserting a new DM for the source user into global account data
- (handles all of the list merging).
-
- Args:
- source_user_id: The user ID of the DM mapping we're going to update
- target_user_id: User ID of the person the DM is with
- target_room_id: Room ID of the DM
- """
-
- # Get the current DM map
- existing_dm_map = self.get_success(
- self.store.get_global_account_data_by_type_for_user(
- source_user_id, AccountDataTypes.DIRECT
- )
- )
- # Scrutinize the account data since it has no concrete type. We're just copying
- # everything into a known type. It should be a mapping from user ID to a list of
- # room IDs. Ignore anything else.
- new_dm_map: Dict[str, List[str]] = {}
- if isinstance(existing_dm_map, dict):
- for user_id, room_ids in existing_dm_map.items():
- if isinstance(user_id, str) and isinstance(room_ids, list):
- for room_id in room_ids:
- if isinstance(room_id, str):
- new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
- room_id
- ]
-
- # Add the new DM to the map
- new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
- target_room_id
- ]
- # Save the DM map to global account data
- self.get_success(
- self.store.add_account_data_for_user(
- source_user_id,
- AccountDataTypes.DIRECT,
- new_dm_map,
- )
- )
-
- def _create_dm_room(
- self,
- inviter_user_id: str,
- inviter_tok: str,
- invitee_user_id: str,
- invitee_tok: str,
- should_join_room: bool = True,
- ) -> str:
- """
- Helper to create a DM room as the "inviter" and invite the "invitee" user to the
- room. The "invitee" user also will join the room. The `m.direct` account data
- will be set for both users.
- """
-
- # Create a room and send an invite the other user
- room_id = self.helper.create_room_as(
- inviter_user_id,
- is_public=False,
- tok=inviter_tok,
- )
- self.helper.invite(
- room_id,
- src=inviter_user_id,
- targ=invitee_user_id,
- tok=inviter_tok,
- extra_data={"is_direct": True},
- )
- if should_join_room:
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
-
- # Mimic the client setting the room as a direct message in the global account
- # data for both users.
- self._add_new_dm_to_global_account_data(
- invitee_user_id, inviter_user_id, room_id
- )
- self._add_new_dm_to_global_account_data(
- inviter_user_id, invitee_user_id, room_id
- )
-
- return room_id
+ super().prepare(reactor, clock, hs)
def test_sync_list(self) -> None:
"""
@@ -512,288 +636,326 @@ class SlidingSyncTestCase(SlidingSyncBase):
# There should be no room sent down.
self.assertFalse(channel.json_body["rooms"])
- def test_filter_list(self) -> None:
+ def test_forgotten_up_to_date(self) -> None:
"""
- Test that filters apply to `lists`
+ Make sure we get up-to-date `forgotten` status for rooms
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
- # Create a DM room
- joined_dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- should_join_room=True,
- )
- invited_dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- should_join_room=False,
- )
-
- # Create a normal room
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
- # Create a room that user1 is invited to
- invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+ # User1 is banned from the room (was never in the room)
+ self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
- # Make the Sliding Sync request
sync_body = {
"lists": {
- # Absense of filters does not imply "False" values
- "all": {
+ "foo-list": {
"ranges": [[0, 99]],
"required_state": [],
- "timeline_limit": 1,
+ "timeline_limit": 0,
"filters": {},
},
- # Test single truthy filter
- "dms": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": True},
- },
- # Test single falsy filter
- "non-dms": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": False},
- },
- # Test how multiple filters should stack (AND'd together)
- "room-invites": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": False, "is_invite": True},
- },
}
}
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["all", "dms", "non-dms", "room-invites"],
- response_body["lists"].keys(),
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- # Make sure the lists have the correct rooms
- self.assertListEqual(
- list(response_body["lists"]["all"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [
- invite_room_id,
- room_id,
- invited_dm_room_id,
- joined_dm_room_id,
- ],
- }
- ],
- list(response_body["lists"]["all"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["dms"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invited_dm_room_id, joined_dm_room_id],
- }
- ],
- list(response_body["lists"]["dms"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["non-dms"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invite_room_id, room_id],
- }
- ],
- list(response_body["lists"]["non-dms"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["room-invites"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invite_room_id],
- }
- ],
- list(response_body["lists"]["room-invites"]),
+ # User1 forgets the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
)
+ self.assertEqual(channel.code, 200, channel.result)
- # Ensure DM's are correctly marked
- self.assertDictEqual(
- {
- room_id: room.get("is_dm")
- for room_id, room in response_body["rooms"].items()
- },
- {
- invite_room_id: None,
- room_id: None,
- invited_dm_room_id: True,
- joined_dm_room_id: True,
- },
+ # We should no longer see the forgotten room
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
)
- def test_filter_regardless_of_membership_server_left_room(self) -> None:
+ def test_rejoin_forgotten_room(self) -> None:
"""
- Test that filters apply to rooms regardless of membership. We're also
- compounding the problem by having all of the local users leave the room causing
- our server to leave the room.
-
- We want to make sure that if someone is filtering rooms, and leaves, you still
- get that final update down sync that you left.
+ Make sure we can see a forgotten room again if we rejoin (or any new membership
+ like an invite) (no longer forgotten)
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
- # Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ # User1 joins the room
self.helper.join(room_id, user1_id, tok=user1_tok)
- # Create an encrypted space room
- space_room_id = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- self.helper.send_state(
- space_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user2_tok,
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # We should see the room (like normal)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- self.helper.join(space_room_id, user1_id, tok=user1_tok)
- # Make an initial Sliding Sync request
+ # Leave and forget the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+ # User1 forgets the room
channel = self.make_request(
"POST",
- self.sync_endpoint,
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
- },
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Re-join the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # We should see the room again after re-joining
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_invited_to_forgotten_remote_room(self) -> None:
+ """
+ Make sure we can see a forgotten room again if we are invited again
+ (remote/federated out-of-band memberships)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote room invite (out-of-band membership)
+ room_id = self._create_remote_invite_room_for_user(user1_id, None)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
}
- },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # We should see the room (like normal)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Leave and forget the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+ # User1 forgets the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
access_token=user1_tok,
)
- self.assertEqual(channel.code, 200, channel.json_body)
- from_token = channel.json_body["pos"]
+ self.assertEqual(channel.code, 200, channel.result)
- # Make sure the response has the lists we requested
- self.assertListEqual(
- list(channel.json_body["lists"].keys()),
- ["all-list", "foo-list"],
- channel.json_body["lists"].keys(),
+ # Get invited to the room again (a fresh remote out-of-band invite,
+ # not a local join, to exercise the out-of-band membership path)
+ self._create_remote_invite_room_for_user(user1_id, None, invite_room_id=room_id)
+
+ # We should see the room again after being re-invited
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- # Make sure the lists have the correct rooms
- self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id, room_id],
+ def test_reject_remote_invite(self) -> None:
+ """Test that rejecting a remote invite comes down incremental sync"""
+
+ user_id = self.register_user("user1", "pass")
+ user_tok = self.login(user_id, "pass")
+
+ # Create a remote room invite (out-of-band membership)
+ room_id = "!room:remote.server"
+ self._create_remote_invite_room_for_user(user_id, None, room_id)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [(EventTypes.Member, StateValues.ME)],
+ "timeline_limit": 3,
}
- ],
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user_tok)
+ # We should see the room (like normal)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
- [
+
+ # Reject the remote room invite
+ self.helper.leave(room_id, user_id, tok=user_tok)
+
+ # Sync again after rejecting the invite
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user_tok)
+
+ # The fix to add the leave event to incremental sync when rejecting a remote
+ # invite relies on the new tables to work.
+ if self.use_new_tables:
+ # We should see the newly_left room
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+ # We should see the leave state for the room so clients don't end up with stuck
+ # invites
+ self.assertIncludes(
{
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id],
- }
- ],
+ (
+ state["type"],
+ state["state_key"],
+ state["content"].get("membership"),
+ )
+ for state in response_body["rooms"][room_id]["required_state"]
+ },
+ {(EventTypes.Member, user_id, Membership.LEAVE)},
+ exact=True,
+ )
+
+ def test_ignored_user_invites_initial_sync(self) -> None:
+ """
+ Make sure we ignore invites if they are from one of the `m.ignored_user_list` on
+ initial sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a room that user1 is already in
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a room that user2 is already in
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 is invited to room_id2
+ self.helper.invite(room_id2, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Sync once before we ignore to make sure the rooms can show up
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # room_id2 shows up because we haven't ignored the user yet
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1, room_id2},
+ exact=True,
)
- # Everyone leaves the encrypted space room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
- self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+ # User1 ignores user2
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/v3/user/{user1_id}/account_data/{AccountDataTypes.IGNORED_USER_LIST}",
+ content={"ignored_users": {user2_id: {}}},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Sync again (initial sync)
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # The invite for room_id2 should no longer show up because user2 is ignored
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
+ )
+
+ def test_ignored_user_invites_incremental_sync(self) -> None:
+ """
+ Make sure we ignore invites if they are from one of the `m.ignored_user_list` on
+ incremental sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a room that user1 is already in
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a room that user2 is already in
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Make an incremental Sliding Sync request
+ # User1 ignores user2
channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?pos={from_token}",
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
- },
- }
- },
+ "PUT",
+ f"/_matrix/client/v3/user/{user1_id}/account_data/{AccountDataTypes.IGNORED_USER_LIST}",
+ content={"ignored_users": {user2_id: {}}},
access_token=user1_tok,
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, 200, channel.result)
- # Make sure the lists have the correct rooms even though we `newly_left`
- self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id, room_id],
- }
- ],
+ # Initial sync
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # User1 only has membership in room_id1 at this point
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
)
- self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id],
- }
- ],
+
+ # User1 is invited to room_id2 after the initial sync
+ self.helper.invite(room_id2, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Sync again (incremental sync)
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ # The invite for room_id2 doesn't show up because user2 is ignored
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
)
def test_sort_list(self) -> None:
@@ -812,11 +974,11 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.helper.send(room_id1, "activity in room1", tok=user1_tok)
self.helper.send(room_id2, "activity in room2", tok=user1_tok)
- # Make the Sliding Sync request
+ # Make the Sliding Sync request where the range includes *some* of the rooms
sync_body = {
"lists": {
"foo-list": {
- "ranges": [[0, 99]],
+ "ranges": [[0, 1]],
"required_state": [],
"timeline_limit": 1,
}
@@ -825,25 +987,56 @@ class SlidingSyncTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
+ self.assertIncludes(
response_body["lists"].keys(),
+ {"foo-list"},
)
-
- # Make sure the list is sorted in the way we expect
+ # Make sure the list is sorted in the way we expect (we only sort when the range
+ # doesn't include all of the room)
self.assertListEqual(
list(response_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
- "range": [0, 99],
- "room_ids": [room_id2, room_id1, room_id3],
+ "range": [0, 1],
+ "room_ids": [room_id2, room_id1],
}
],
response_body["lists"]["foo-list"],
)
+ # Make the Sliding Sync request where the range includes *all* of the rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # Make sure it has the foo-list we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"foo-list"},
+ )
+ # Since the range includes all of the rooms, we don't sort the list
+ self.assertEqual(
+ len(response_body["lists"]["foo-list"]["ops"]),
+ 1,
+ response_body["lists"]["foo-list"],
+ )
+ op = response_body["lists"]["foo-list"]["ops"][0]
+ self.assertEqual(op["op"], "SYNC")
+ self.assertEqual(op["range"], [0, 99])
+ # Note that we don't sort the rooms when the range includes all of the rooms, so
+ # we just assert that the rooms are included
+ self.assertIncludes(
+ set(op["room_ids"]), {room_id1, room_id2, room_id3}, exact=True
+ )
+
def test_sliced_windows(self) -> None:
"""
Test that the `lists` `ranges` are sliced correctly. Both sides of each range
@@ -972,3 +1165,454 @@ class SlidingSyncTestCase(SlidingSyncBase):
# Make the Sliding Sync request
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+
+ def test_state_reset_room_comes_down_incremental_sync(self) -> None:
+ """Test that a room that we were state reset out of comes down
+ incremental sync"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ is_public=True,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+
+ # Create an event for us to point back to for the state reset
+ event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ }
+
+ # Make the Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Make sure we see room1
+ self.assertIncludes(set(response_body["rooms"].keys()), {room_id1}, exact=True)
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+
+ # Trigger a state reset
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
+
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+ state_map_at_reset = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Update the state after user1 was state reset out of the room
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper room"},
+ tok=user2_tok,
+ )
+
+ # Make another Sliding Sync request (incremental)
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Expect to see room1 because it is `newly_left` thanks to being state reset out
+ # of it since the last time we synced. We need to let the client know that
+ # something happened and that they are no longer in the room.
+ self.assertIncludes(set(response_body["rooms"].keys()), {room_id1}, exact=True)
+ # We set `initial=True` to indicate that the client should reset the state they
+ # have about the room
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+ # They shouldn't see anything past the state reset
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ # We should see all the state events in the room
+ state_map_at_reset.values(),
+ exact=True,
+ )
+ # The position where the state reset happened
+ self.assertEqual(
+ response_body["rooms"][room_id1]["bump_stamp"],
+ join_rule_event_pos.stream,
+ response_body["rooms"][room_id1],
+ )
+
+ # Other non-important things. We just want to check what these are so we know
+ # what happens in a state reset scenario.
+ #
+ # Room name was set at the time of the state reset so we should still be able to
+ # see it.
+ self.assertEqual(response_body["rooms"][room_id1]["name"], "my super room")
+ # Could be set but there is no avatar for this room
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("avatar"),
+ response_body["rooms"][room_id1],
+ )
+ # Could be set but this room isn't marked as a DM
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("is_dm"),
+ response_body["rooms"][room_id1],
+ )
+ # Empty timeline because we are not in the room at all (they are all being
+ # filtered out)
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("timeline"),
+ response_body["rooms"][room_id1],
+ )
+ # `limited` since we're not providing any timeline events but there are some in
+ # the room.
+ self.assertEqual(response_body["rooms"][room_id1]["limited"], True)
+ # User is no longer in the room so they can't see this info
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("joined_count"),
+ response_body["rooms"][room_id1],
+ )
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("invited_count"),
+ response_body["rooms"][room_id1],
+ )
+
+ def test_state_reset_previously_room_comes_down_incremental_sync_with_filters(
+ self,
+ ) -> None:
+ """
+ Test that a room that we were state reset out of should always be sent down
+ regardless of the filters if it has been sent down the connection before.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ "name": "my super space",
+ },
+ )
+
+ # Create an event for us to point back to for the state reset
+ event_response = self.helper.send(space_room_id, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ }
+ }
+ }
+
+ # Make the Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Make sure we see room1
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id}, exact=True
+ )
+ self.assertEqual(response_body["rooms"][space_room_id]["initial"], True)
+
+ # Trigger a state reset
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=space_room_id,
+ room_version=self.get_success(
+ self.store.get_room_version_id(space_room_id)
+ ),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
+
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(space_room_id))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+ state_map_at_reset = self.get_success(
+ self.storage_controllers.state.get_current_state(space_room_id)
+ )
+
+ # Update the state after user1 was state reset out of the room
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper space"},
+ tok=user2_tok,
+ )
+
+ # User2 also leaves the room so the server is no longer participating in the room
+ # and we don't have access to current state
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+ # Make another Sliding Sync request (incremental)
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Expect to see room1 because it is `newly_left` thanks to being state reset out
+ # of it since the last time we synced. We need to let the client know that
+ # something happened and that they are no longer in the room.
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id}, exact=True
+ )
+ # We set `initial=True` to indicate that the client should reset the state they
+ # have about the room
+ self.assertEqual(response_body["rooms"][space_room_id]["initial"], True)
+ # They shouldn't see anything past the state reset
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][space_room_id]["required_state"],
+ # We should see all the state events in the room
+ state_map_at_reset.values(),
+ exact=True,
+ )
+ # The position where the state reset happened
+ self.assertEqual(
+ response_body["rooms"][space_room_id]["bump_stamp"],
+ join_rule_event_pos.stream,
+ response_body["rooms"][space_room_id],
+ )
+
+ # Other non-important things. We just want to check what these are so we know
+ # what happens in a state reset scenario.
+ #
+ # Room name was set at the time of the state reset so we should still be able to
+ # see it.
+ self.assertEqual(
+ response_body["rooms"][space_room_id]["name"], "my super space"
+ )
+ # Could be set but there is no avatar for this room
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("avatar"),
+ response_body["rooms"][space_room_id],
+ )
+ # Could be set but this room isn't marked as a DM
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("is_dm"),
+ response_body["rooms"][space_room_id],
+ )
+ # Empty timeline because we are not in the room at all (they are all being
+ # filtered out)
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("timeline"),
+ response_body["rooms"][space_room_id],
+ )
+ # `limited` since we're not providing any timeline events but there are some in
+ # the room.
+ self.assertEqual(response_body["rooms"][space_room_id]["limited"], True)
+ # User is no longer in the room so they can't see this info
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("joined_count"),
+ response_body["rooms"][space_room_id],
+ )
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("invited_count"),
+ response_body["rooms"][space_room_id],
+ )
+
+ @parameterized.expand(
+ [
+ ("server_leaves_room", True),
+ ("server_participating_in_room", False),
+ ]
+ )
+ def test_state_reset_never_room_incremental_sync_with_filters(
+ self, test_description: str, server_leaves_room: bool
+ ) -> None:
+ """
+ Test that a room that we were state reset out of should be sent down if we can
+ figure out the state or if it was sent down the connection before.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ "name": "my super space",
+ },
+ )
+
+ # Create another space room
+ space_room_id2 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ },
+ )
+
+ # Create an event for us to point back to for the state reset
+ event_response = self.helper.send(space_room_id, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ # User1 joins the rooms
+ #
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+ # Join space_room_id2 so that it is at the top of the list
+ self.helper.join(space_room_id2, user1_id, tok=user1_tok)
+
+ # Make a SS request for only the top room.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 0]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ }
+ }
+ }
+
+ # Make the Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Make sure we only see space_room_id2
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id2}, exact=True
+ )
+ self.assertEqual(response_body["rooms"][space_room_id2]["initial"], True)
+
+ # Just create some activity in space_room_id2 so it appears when we do an incremental sync again
+ self.helper.send(space_room_id2, "test", tok=user2_tok)
+
+ # Trigger a state reset
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=space_room_id,
+ room_version=self.get_success(
+ self.store.get_room_version_id(space_room_id)
+ ),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
+
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(space_room_id))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+ # Update the state after user1 was state reset out of the room.
+ # This will also bump it to the top of the list.
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper space"},
+ tok=user2_tok,
+ )
+
+ if server_leaves_room:
+ # User2 also leaves the room so the server is no longer participating in the room
+ # and we don't have access to current state
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+ # Make another Sliding Sync request (incremental)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ # Expand the range to include all rooms
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ if self.use_new_tables:
+ if server_leaves_room:
+ # We still only expect to see space_room_id2 because even though we were state
+ # reset out of space_room_id, it was never sent down the connection before so we
+ # don't need to bother the client with it.
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id2}, exact=True
+ )
+ else:
+ # Both rooms show up because we can figure out the state for the
+ # `filters.room_types` if someone is still in the room (we look at the
+ # current state because `room_type` never changes).
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {space_room_id, space_room_id2},
+ exact=True,
+ )
+ else:
+ # Both rooms show up because we can actually take the time to figure out the
+ # state for the `filters.room_types` in the fallback path (we look at
+ # historical state for `LEAVE` membership).
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {space_room_id, space_room_id2},
+ exact=True,
+ )
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a85ea994de..33611e8a8c 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -36,7 +36,6 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.appservice import ApplicationService
from synapse.rest import admin
from synapse.rest.client import account, login, register, room
-from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
from synapse.storage._base import db_to_json
from synapse.types import JsonDict, UserID
@@ -47,430 +46,404 @@ from tests.server import FakeSite, make_request
from tests.unittest import override_config
-class PasswordResetTestCase(unittest.HomeserverTestCase):
- servlets = [
- account.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- register.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- # Email config.
- config["email"] = {
- "enable_notifs": False,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "notif_from": "test@example.com",
- }
- config["public_baseurl"] = "https://example.com"
-
- hs = self.setup_test_homeserver(config=config)
-
- async def sendmail(
- reactor: IReactorTCP,
- smtphost: str,
- smtpport: int,
- from_addr: str,
- to_addr: str,
- msg_bytes: bytes,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- self.email_attempts.append(msg_bytes)
-
- self.email_attempts: List[bytes] = []
- hs.get_send_email_handler()._sendmail = sendmail
-
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
-
- def attempt_wrong_password_login(self, username: str, password: str) -> None:
- """Attempts to login as the user with the given password, asserting
- that the attempt *fails*.
- """
- body = {"type": "m.login.password", "user": username, "password": password}
-
- channel = self.make_request("POST", "/_matrix/client/r0/login", body)
- self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
-
- def test_basic_password_reset(self) -> None:
- """Test basic password reset flow"""
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test@example.com"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- self._reset_password(new_password, session_id, client_secret)
-
- # Assert we can log in with the new password
- self.login("kermit", new_password)
-
- # Assert we can't log in with the old password
- self.attempt_wrong_password_login("kermit", old_password)
-
- # Check that the UI Auth information doesn't store the password in the database.
- #
- # Note that we don't have the UI Auth session ID, so just pull out the single
- # row.
- result = self.get_success(
- self.store.db_pool.simple_select_one_onecol(
- "ui_auth_sessions", keyvalues={}, retcol="clientdict"
- )
- )
- client_dict = db_to_json(result)
- self.assertNotIn("new_password", client_dict)
-
- @override_config({"rc_3pid_validation": {"burst_count": 3}})
- def test_ratelimit_by_email(self) -> None:
- """Test that we ratelimit /requestToken for the same email."""
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test1@example.com"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- def reset(ip: str) -> None:
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret, ip)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- self._reset_password(new_password, session_id, client_secret)
-
- self.email_attempts.clear()
-
- # We expect to be able to make three requests before getting rate
- # limited.
- #
- # We change IPs to ensure that we're not being ratelimited due to the
- # same IP
- reset("127.0.0.1")
- reset("127.0.0.2")
- reset("127.0.0.3")
-
- with self.assertRaises(HttpResponseException) as cm:
- reset("127.0.0.4")
-
- self.assertEqual(cm.exception.code, 429)
-
- def test_basic_password_reset_canonicalise_email(self) -> None:
- """Test basic password reset flow
- Request password reset with different spelling
- """
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email_profile = "test@example.com"
- email_passwort_reset = "TEST@EXAMPLE.COM"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email_profile,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = self._request_token(email_passwort_reset, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- self._reset_password(new_password, session_id, client_secret)
-
- # Assert we can log in with the new password
- self.login("kermit", new_password)
-
- # Assert we can't log in with the old password
- self.attempt_wrong_password_login("kermit", old_password)
-
- def test_cant_reset_password_without_clicking_link(self) -> None:
- """Test that we do actually need to click the link in the email"""
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test@example.com"
+# class PasswordResetTestCase(unittest.HomeserverTestCase):
+# servlets = [
+# account.register_servlets,
+# synapse.rest.admin.register_servlets_for_client_rest_resource,
+# register.register_servlets,
+# login.register_servlets,
+# ]
+
+# def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+# config = self.default_config()
+
+# # Email config.
+# config["email"] = {
+# "enable_notifs": False,
+# "template_dir": os.path.abspath(
+# pkg_resources.resource_filename("synapse", "res/templates")
+# ),
+# "smtp_host": "127.0.0.1",
+# "smtp_port": 20,
+# "require_transport_security": False,
+# "smtp_user": None,
+# "smtp_pass": None,
+# "notif_from": "test@example.com",
+# }
+# config["public_baseurl"] = "https://example.com"
+
+# hs = self.setup_test_homeserver(config=config)
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
-
- # Attempt to reset password without clicking the link
- self._reset_password(new_password, session_id, client_secret, expected_code=401)
-
- # Assert we can log in with the old password
- self.login("kermit", old_password)
-
- # Assert we can't log in with the new password
- self.attempt_wrong_password_login("kermit", new_password)
-
- def test_no_valid_token(self) -> None:
- """Test that we do actually need to request a token and can't just
- make a session up.
- """
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test@example.com"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = "weasle"
+# async def sendmail(
+# reactor: IReactorTCP,
+# smtphost: str,
+# smtpport: int,
+# from_addr: str,
+# to_addr: str,
+# msg_bytes: bytes,
+# *args: Any,
+# **kwargs: Any,
+# ) -> None:
+# self.email_attempts.append(msg_bytes)
- # Attempt to reset password without even requesting an email
- self._reset_password(new_password, session_id, client_secret, expected_code=401)
+# self.email_attempts: List[bytes] = []
+
+# return hs
- # Assert we can log in with the old password
- self.login("kermit", old_password)
+# def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+# self.store = hs.get_datastores().main
+
+# def attempt_wrong_password_login(self, username: str, password: str) -> None:
+# """Attempts to login as the user with the given password, asserting
+# that the attempt *fails*.
+# """
+# body = {"type": "m.login.password", "user": username, "password": password}
+
+# channel = self.make_request("POST", "/_matrix/client/r0/login", body)
+# self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
+
+# def test_basic_password_reset(self) -> None:
+# """Test basic password reset flow"""
+# old_password = "monkey"
+# new_password = "kangeroo"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
- # Assert we can't log in with the new password
- self.attempt_wrong_password_login("kermit", new_password)
- @unittest.override_config({"request_token_inhibit_3pid_errors": True})
- def test_password_reset_bad_email_inhibit_error(self) -> None:
- """Test that triggering a password reset with an email address that isn't bound
- to an account doesn't leak the lack of binding for that address if configured
- that way.
- """
- self.register_user("kermit", "monkey")
- self.login("kermit", "monkey")
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret)
- email = "test@example.com"
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret)
+# self._validate_token(link)
- self.assertIsNotNone(session_id)
+# self._reset_password(new_password, session_id, client_secret)
- def test_password_reset_redirection(self) -> None:
- """Test basic password reset flow"""
- old_password = "monkey"
+# # Assert we can log in with the new password
+# self.login("kermit", new_password)
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
+# # Assert we can't log in with the old password
+# self.attempt_wrong_password_login("kermit", old_password)
- email = "test@example.com"
+# # Check that the UI Auth information doesn't store the password in the database.
+# #
+# # Note that we don't have the UI Auth session ID, so just pull out the single
+# # row.
+# result = self.get_success(
+# self.store.db_pool.simple_select_one_onecol(
+# "ui_auth_sessions", keyvalues={}, retcol="clientdict"
+# )
+# )
+# client_dict = db_to_json(result)
+# self.assertNotIn("new_password", client_dict)
+
+# @override_config({"rc_3pid_validation": {"burst_count": 3}})
+# def test_ratelimit_by_email(self) -> None:
+# """Test that we ratelimit /requestToken for the same email."""
+# old_password = "monkey"
+# new_password = "kangeroo"
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+
+# def reset(ip: str) -> None:
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret, ip)
- client_secret = "foobar"
- next_link = "http://example.com"
- self._request_token(email, client_secret, "127.0.0.1", next_link)
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
+
+# self._validate_token(link)
+
+# self._reset_password(new_password, session_id, client_secret)
+
+# self.email_attempts.clear()
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
+# # We expect to be able to make three requests before getting rate
+# # limited.
+# #
+# # We change IPs to ensure that we're not being ratelimited due to the
+# # same IP
+# reset("127.0.0.1")
+# reset("127.0.0.2")
+# reset("127.0.0.3")
- self._validate_token(link, next_link)
+# with self.assertRaises(HttpResponseException) as cm:
+# reset("127.0.0.4")
- def _request_token(
- self,
- email: str,
- client_secret: str,
- ip: str = "127.0.0.1",
- next_link: Optional[str] = None,
- ) -> str:
- body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
- if next_link is not None:
- body["next_link"] = next_link
- channel = self.make_request(
- "POST",
- b"account/password/email/requestToken",
- body,
- client_ip=ip,
- )
+# self.assertEqual(cm.exception.code, 429)
- if channel.code != 200:
- raise HttpResponseException(
- channel.code,
- channel.result["reason"],
- channel.result["body"],
- )
+# def test_basic_password_reset_canonicalise_email(self) -> None:
+# """Test basic password reset flow
+# Request password reset with different spelling
+# """
+# old_password = "monkey"
+# new_password = "kangeroo"
- return channel.json_body["sid"]
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
- def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
- # Remove the host
- path = link.replace("https://example.com", "")
+# email_profile = "test@example.com"
+# email_passwort_reset = "TEST@EXAMPLE.COM"
- # Load the password reset confirmation page
- channel = make_request(
- self.reactor,
- FakeSite(self.submit_token_resource, self.reactor),
- "GET",
- path,
- shorthand=False,
- )
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email_profile,
+# validated_at=0,
+# added_at=0,
+# )
+# )
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+# client_secret = "foobar"
+# session_id = self._request_token(email_passwort_reset, client_secret)
- # Now POST to the same endpoint, mimicking the same behaviour as clicking the
- # password reset confirm button
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
- # Confirm the password reset
- channel = make_request(
- self.reactor,
- FakeSite(self.submit_token_resource, self.reactor),
- "POST",
- path,
- content=b"",
- shorthand=False,
- content_is_form=True,
- )
- self.assertEqual(
- HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
- channel.code,
- channel.result,
- )
-
- def _get_link_from_email(self) -> str:
- assert self.email_attempts, "No emails have been sent"
+# self._validate_token(link)
- raw_msg = self.email_attempts[-1].decode("UTF-8")
- mail = Parser().parsestr(raw_msg)
+# self._reset_password(new_password, session_id, client_secret)
- text = None
- for part in mail.walk():
- if part.get_content_type() == "text/plain":
- text = part.get_payload(decode=True)
- if text is not None:
- # According to the logic table in `get_payload`, we know that
- # the result of `get_payload` will be `bytes`, but mypy doesn't
- # know this and complains. Thus, we assert the type.
- assert isinstance(text, bytes)
- text = text.decode("UTF-8")
+# # Assert we can log in with the new password
+# self.login("kermit", new_password)
- break
+# # Assert we can't log in with the old password
+# self.attempt_wrong_password_login("kermit", old_password)
- if not text:
- self.fail("Could not find text portion of email to parse")
+# def test_cant_reset_password_without_clicking_link(self) -> None:
+# """Test that we do actually need to click the link in the email"""
+# old_password = "monkey"
+# new_password = "kangeroo"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+# email = "test@example.com"
+
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email,
+# validated_at=0,
+# added_at=0,
+# )
+# )
- # `text` must be a `str`, after being decoded and determined just above
- # to not be `None` or an empty `str`.
- assert isinstance(text, str)
-
- match = re.search(r"https://example.com\S+", text)
- assert match, "Could not find link in email"
-
- return match.group(0)
-
- def _reset_password(
- self,
- new_password: str,
- session_id: str,
- client_secret: str,
- expected_code: int = HTTPStatus.OK,
- ) -> None:
- channel = self.make_request(
- "POST",
- b"account/password",
- {
- "new_password": new_password,
- "auth": {
- "type": LoginType.EMAIL_IDENTITY,
- "threepid_creds": {
- "client_secret": client_secret,
- "sid": session_id,
- },
- },
- },
- )
- self.assertEqual(expected_code, channel.code, channel.result)
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret)
+
+# self.assertEqual(len(self.email_attempts), 1)
+
+# # Attempt to reset password without clicking the link
+# self._reset_password(new_password, session_id, client_secret, expected_code=401)
+
+# # Assert we can log in with the old password
+# self.login("kermit", old_password)
+
+# # Assert we can't log in with the new password
+# self.attempt_wrong_password_login("kermit", new_password)
+
+# def test_no_valid_token(self) -> None:
+# """Test that we do actually need to request a token and can't just
+# make a session up.
+# """
+# old_password = "monkey"
+# new_password = "kangeroo"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+# email = "test@example.com"
+
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email,
+# validated_at=0,
+# added_at=0,
+# )
+# )
+
+# client_secret = "foobar"
+# session_id = "weasle"
+
+# # Attempt to reset password without even requesting an email
+# self._reset_password(new_password, session_id, client_secret, expected_code=401)
+
+# # Assert we can log in with the old password
+# self.login("kermit", old_password)
+
+# # Assert we can't log in with the new password
+# self.attempt_wrong_password_login("kermit", new_password)
+
+# @unittest.override_config({"request_token_inhibit_3pid_errors": True})
+# def test_password_reset_bad_email_inhibit_error(self) -> None:
+# """Test that triggering a password reset with an email address that isn't bound
+# to an account doesn't leak the lack of binding for that address if configured
+# that way.
+# """
+# self.register_user("kermit", "monkey")
+# self.login("kermit", "monkey")
+
+# email = "test@example.com"
+
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret)
+
+# self.assertIsNotNone(session_id)
+
+# def test_password_reset_redirection(self) -> None:
+# """Test that a password reset with a next_link redirects to it after confirmation"""
+# old_password = "monkey"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+# email = "test@example.com"
+
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email,
+# validated_at=0,
+# added_at=0,
+# )
+# )
+
+# client_secret = "foobar"
+# next_link = "http://example.com"
+# self._request_token(email, client_secret, "127.0.0.1", next_link)
+
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
+
+# self._validate_token(link, next_link)
+
+# def _request_token(
+# self,
+# email: str,
+# client_secret: str,
+# ip: str = "127.0.0.1",
+# next_link: Optional[str] = None,
+# ) -> str:
+# body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+# if next_link is not None:
+# body["next_link"] = next_link
+# channel = self.make_request(
+# "POST",
+# b"account/password/email/requestToken",
+# body,
+# client_ip=ip,
+# )
+
+# if channel.code != 200:
+# raise HttpResponseException(
+# channel.code,
+# channel.result["reason"],
+# channel.result["body"],
+# )
+
+# return channel.json_body["sid"]
+
+# def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
+# # Remove the host
+# path = link.replace("https://example.com", "")
+
+# # Load the password reset confirmation page
+# channel = make_request(
+# self.reactor,
+# FakeSite(self.submit_token_resource, self.reactor),
+# "GET",
+# path,
+# shorthand=False,
+# )
+
+# self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+# # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+# # password reset confirm button
+
+# # Confirm the password reset
+# channel = make_request(
+# self.reactor,
+# FakeSite(self.submit_token_resource, self.reactor),
+# "POST",
+# path,
+# content=b"",
+# shorthand=False,
+# content_is_form=True,
+# )
+# self.assertEqual(
+# HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
+# channel.code,
+# channel.result,
+# )
+
+# def _get_link_from_email(self) -> str:
+# assert self.email_attempts, "No emails have been sent"
+
+# raw_msg = self.email_attempts[-1].decode("UTF-8")
+# mail = Parser().parsestr(raw_msg)
+
+# text = None
+# for part in mail.walk():
+# if part.get_content_type() == "text/plain":
+# text = part.get_payload(decode=True)
+# if text is not None:
+# # According to the logic table in `get_payload`, we know that
+# # the result of `get_payload` will be `bytes`, but mypy doesn't
+# # know this and complains. Thus, we assert the type.
+# assert isinstance(text, bytes)
+# text = text.decode("UTF-8")
+
+# break
+
+# if not text:
+# self.fail("Could not find text portion of email to parse")
+
+# # `text` must be a `str`, after being decoded and determined just above
+# # to not be `None` or an empty `str`.
+# assert isinstance(text, str)
+
+# match = re.search(r"https://example.com\S+", text)
+# assert match, "Could not find link in email"
+
+# return match.group(0)
+
+# def _reset_password(
+# self,
+# new_password: str,
+# session_id: str,
+# client_secret: str,
+# expected_code: int = HTTPStatus.OK,
+# ) -> None:
+# channel = self.make_request(
+# "POST",
+# b"account/password",
+# {
+# "new_password": new_password,
+# "auth": {
+# "type": LoginType.EMAIL_IDENTITY,
+# "threepid_creds": {
+# "client_secret": client_secret,
+# "sid": session_id,
+# },
+# },
+# },
+# )
+# self.assertEqual(expected_code, channel.code, channel.result)
class DeactivateTestCase(unittest.HomeserverTestCase):
@@ -787,503 +760,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
return channel.json_body
-class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
- servlets = [
- account.register_servlets,
- login.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- # Email config.
- config["email"] = {
- "enable_notifs": False,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "notif_from": "test@example.com",
- }
- config["public_baseurl"] = "https://example.com"
-
- self.hs = self.setup_test_homeserver(config=config)
-
- async def sendmail(
- reactor: IReactorTCP,
- smtphost: str,
- smtpport: int,
- from_addr: str,
- to_addr: str,
- msg_bytes: bytes,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- self.email_attempts.append(msg_bytes)
-
- self.email_attempts: List[bytes] = []
- self.hs.get_send_email_handler()._sendmail = sendmail
-
- return self.hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- self.user_id = self.register_user("kermit", "test")
- self.user_id_tok = self.login("kermit", "test")
- self.email = "test@example.com"
- self.url_3pid = b"account/3pid"
-
- def test_add_valid_email(self) -> None:
- self._add_email(self.email, self.email)
-
- def test_add_valid_email_second_time(self) -> None:
- self._add_email(self.email, self.email)
- self._request_token_invalid_email(
- self.email,
- expected_errcode=Codes.THREEPID_IN_USE,
- expected_error="Email is already in use",
- )
-
- def test_add_valid_email_second_time_canonicalise(self) -> None:
- self._add_email(self.email, self.email)
- self._request_token_invalid_email(
- "TEST@EXAMPLE.COM",
- expected_errcode=Codes.THREEPID_IN_USE,
- expected_error="Email is already in use",
- )
-
- def test_add_email_no_at(self) -> None:
- self._request_token_invalid_email(
- "address-without-at.bar",
- expected_errcode=Codes.BAD_JSON,
- expected_error="Unable to parse email address",
- )
-
- def test_add_email_two_at(self) -> None:
- self._request_token_invalid_email(
- "foo@foo@test.bar",
- expected_errcode=Codes.BAD_JSON,
- expected_error="Unable to parse email address",
- )
-
- def test_add_email_bad_format(self) -> None:
- self._request_token_invalid_email(
- "user@bad.example.net@good.example.com",
- expected_errcode=Codes.BAD_JSON,
- expected_error="Unable to parse email address",
- )
-
- def test_add_email_domain_to_lower(self) -> None:
- self._add_email("foo@TEST.BAR", "foo@test.bar")
-
- def test_add_email_domain_with_umlaut(self) -> None:
- self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
-
- def test_add_email_address_casefold(self) -> None:
- self._add_email("Strauß@Example.com", "strauss@example.com")
-
- def test_address_trim(self) -> None:
- self._add_email(" foo@test.bar ", "foo@test.bar")
-
- @override_config({"rc_3pid_validation": {"burst_count": 3}})
- def test_ratelimit_by_ip(self) -> None:
- """Tests that adding emails is ratelimited by IP"""
-
- # We expect to be able to set three emails before getting ratelimited.
- self._add_email("foo1@test.bar", "foo1@test.bar")
- self._add_email("foo2@test.bar", "foo2@test.bar")
- self._add_email("foo3@test.bar", "foo3@test.bar")
-
- with self.assertRaises(HttpResponseException) as cm:
- self._add_email("foo4@test.bar", "foo4@test.bar")
-
- self.assertEqual(cm.exception.code, 429)
-
- def test_add_email_if_disabled(self) -> None:
- """Test adding email to profile when doing so is disallowed"""
- self.hs.config.registration.enable_3pid_changes = False
-
- client_secret = "foobar"
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/email/requestToken",
- {
- "client_secret": client_secret,
- "email": "test@example.com",
- "send_attempt": 1,
- },
- )
-
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
-
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_delete_email(self) -> None:
- """Test deleting an email from profile"""
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=self.user_id,
- medium="email",
- address=self.email,
- validated_at=0,
- added_at=0,
- )
- )
-
- channel = self.make_request(
- "POST",
- b"account/3pid/delete",
- {"medium": "email", "address": self.email},
- access_token=self.user_id_tok,
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- def test_delete_email_if_disabled(self) -> None:
- """Test deleting an email from profile when disallowed"""
- self.hs.config.registration.enable_3pid_changes = False
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=self.user_id,
- medium="email",
- address=self.email,
- validated_at=0,
- added_at=0,
- )
- )
-
- channel = self.make_request(
- "POST",
- b"account/3pid/delete",
- {"medium": "email", "address": self.email},
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
-
- def test_cant_add_email_without_clicking_link(self) -> None:
- """Test that we do actually need to click the link in the email"""
- client_secret = "foobar"
- session_id = self._request_token(self.email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
-
- # Attempt to add email without clicking the link
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- def test_no_valid_token(self) -> None:
- """Test that we do actually need to request a token and can't just
- make a session up.
- """
- client_secret = "foobar"
- session_id = "weasle"
-
- # Attempt to add email without even requesting an email
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link(self) -> None:
- """Tests a valid next_link parameter value with no whitelist (good case)"""
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/a/good/site",
- expect_code=HTTPStatus.OK,
- )
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link_exotic_protocol(self) -> None:
- """Tests using a esoteric protocol as a next_link parameter value.
- Someone may be hosting a client on IPFS etc.
- """
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
- expect_code=HTTPStatus.OK,
- )
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link_file_uri(self) -> None:
- """Tests next_link parameters cannot be file URI"""
- # Attempt to use a next_link value that points to the local disk
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="file:///host/path",
- expect_code=HTTPStatus.BAD_REQUEST,
- )
-
- @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
- def test_next_link_domain_whitelist(self) -> None:
- """Tests next_link parameters must fit the whitelist if provided"""
-
- # Ensure not providing a next_link parameter still works
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link=None,
- expect_code=HTTPStatus.OK,
- )
-
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/some/good/page",
- expect_code=HTTPStatus.OK,
- )
-
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.org/some/also/good/page",
- expect_code=HTTPStatus.OK,
- )
-
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://bad.example.org/some/bad/page",
- expect_code=HTTPStatus.BAD_REQUEST,
- )
-
- @override_config({"next_link_domain_whitelist": []})
- def test_empty_next_link_domain_whitelist(self) -> None:
- """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
- disallowed
- """
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/a/page",
- expect_code=HTTPStatus.BAD_REQUEST,
- )
-
- def _request_token(
- self,
- email: str,
- client_secret: str,
- next_link: Optional[str] = None,
- expect_code: int = HTTPStatus.OK,
- ) -> Optional[str]:
- """Request a validation token to add an email address to a user's account
-
- Args:
- email: The email address to validate
- client_secret: A secret string
- next_link: A link to redirect the user to after validation
- expect_code: Expected return code of the call
-
- Returns:
- The ID of the new threepid validation session, or None if the response
- did not contain a session ID.
- """
- body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
- if next_link:
- body["next_link"] = next_link
-
- channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- body,
- )
-
- if channel.code != expect_code:
- raise HttpResponseException(
- channel.code,
- channel.result["reason"],
- channel.result["body"],
- )
-
- return channel.json_body.get("sid")
-
- def _request_token_invalid_email(
- self,
- email: str,
- expected_errcode: str,
- expected_error: str,
- client_secret: str = "foobar",
- ) -> None:
- channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
- )
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertIn(expected_error, channel.json_body["error"])
-
- def _validate_token(self, link: str) -> None:
- # Remove the host
- path = link.replace("https://example.com", "")
-
- channel = self.make_request("GET", path, shorthand=False)
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
-
- def _get_link_from_email(self) -> str:
- assert self.email_attempts, "No emails have been sent"
-
- raw_msg = self.email_attempts[-1].decode("UTF-8")
- mail = Parser().parsestr(raw_msg)
-
- text = None
- for part in mail.walk():
- if part.get_content_type() == "text/plain":
- text = part.get_payload(decode=True)
- if text is not None:
- # According to the logic table in `get_payload`, we know that
- # the result of `get_payload` will be `bytes`, but mypy doesn't
- # know this and complains. Thus, we assert the type.
- assert isinstance(text, bytes)
- text = text.decode("UTF-8")
-
- break
-
- if not text:
- self.fail("Could not find text portion of email to parse")
-
- # `text` must be a `str`, after being decoded and determined just above
- # to not be `None` or an empty `str`.
- assert isinstance(text, str)
-
- match = re.search(r"https://example.com\S+", text)
- assert match, "Could not find link in email"
-
- return match.group(0)
-
- def _add_email(self, request_email: str, expected_email: str) -> None:
- """Test adding an email to profile"""
- previous_email_attempts = len(self.email_attempts)
-
- client_secret = "foobar"
- session_id = self._request_token(request_email, client_secret)
-
- self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-
- threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
- self.assertIn(expected_email, threepids)
-
-
class AccountStatusTestCase(unittest.HomeserverTestCase):
servlets = [
account.register_servlets,
diff --git a/tests/rest/client/test_auth_issuer.py b/tests/rest/client/test_auth_issuer.py
deleted file mode 100644
index 964baeec32..0000000000
--- a/tests/rest/client/test_auth_issuer.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2023 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from http import HTTPStatus
-
-from synapse.rest.client import auth_issuer
-
-from tests.unittest import HomeserverTestCase, override_config, skip_unless
-from tests.utils import HAS_AUTHLIB
-
-ISSUER = "https://account.example.com/"
-
-
-class AuthIssuerTestCase(HomeserverTestCase):
- servlets = [
- auth_issuer.register_servlets,
- ]
-
- def test_returns_404_when_msc3861_disabled(self) -> None:
- # Make an unauthenticated request for the discovery info.
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
- )
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
-
- @skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc3861": {
- "enabled": True,
- "issuer": ISSUER,
- "client_id": "David Lister",
- "client_auth_method": "client_secret_post",
- "client_secret": "Who shot Mister Burns?",
- }
- },
- }
- )
- def test_returns_issuer_when_oidc_enabled(self) -> None:
- # Make an unauthenticated request for the discovery info.
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
- )
- self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(channel.json_body, {"issuer": ISSUER})
diff --git a/tests/rest/client/test_auth_metadata.py b/tests/rest/client/test_auth_metadata.py
new file mode 100644
index 0000000000..a935533b09
--- /dev/null
+++ b/tests/rest/client/test_auth_metadata.py
@@ -0,0 +1,140 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2023 The Matrix.org Foundation C.I.C
+# Copyright (C) 2023-2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+from http import HTTPStatus
+from unittest.mock import AsyncMock
+
+from synapse.rest.client import auth_metadata
+
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+from tests.utils import HAS_AUTHLIB
+
+ISSUER = "https://account.example.com/"
+
+
+class AuthIssuerTestCase(HomeserverTestCase):
+ servlets = [
+ auth_metadata.register_servlets,
+ ]
+
+ def test_returns_404_when_msc3861_disabled(self) -> None:
+ # Make an unauthenticated request for the discovery info.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+
+ @skip_unless(HAS_AUTHLIB, "requires authlib")
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": "David Lister",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "Who shot Mister Burns?",
+ }
+ },
+ }
+ )
+ def test_returns_issuer_when_oidc_enabled(self) -> None:
+ # Patch the HTTP client to return the issuer metadata
+ req_mock = AsyncMock(return_value={"issuer": ISSUER})
+ self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.json_body, {"issuer": ISSUER})
+
+ req_mock.assert_called_with(
+ "https://account.example.com/.well-known/openid-configuration"
+ )
+ req_mock.reset_mock()
+
+ # Second call it should use the cached value
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.json_body, {"issuer": ISSUER})
+ req_mock.assert_not_called()
+
+
+class AuthMetadataTestCase(HomeserverTestCase):
+ servlets = [
+ auth_metadata.register_servlets,
+ ]
+
+ def test_returns_404_when_msc3861_disabled(self) -> None:
+ # Make an unauthenticated request for the discovery info.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_metadata",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+
+ @skip_unless(HAS_AUTHLIB, "requires authlib")
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": "David Lister",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "Who shot Mister Burns?",
+ }
+ },
+ }
+ )
+ def test_returns_issuer_when_oidc_enabled(self) -> None:
+ # Patch the HTTP client to return the issuer metadata
+ req_mock = AsyncMock(
+ return_value={
+ "issuer": ISSUER,
+ "authorization_endpoint": "https://example.com/auth",
+ "token_endpoint": "https://example.com/token",
+ }
+ )
+ self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_metadata",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "issuer": ISSUER,
+ "authorization_endpoint": "https://example.com/auth",
+ "token_endpoint": "https://example.com/token",
+ },
+ )
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index bbe8ab1a7c..1cfaf4fbd7 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -118,7 +118,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(capabilities["m.change_password"]["enabled"])
self.assertTrue(capabilities["m.set_displayname"]["enabled"])
self.assertTrue(capabilities["m.set_avatar_url"]["enabled"])
- self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
+ self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
@override_config({"enable_set_displayname": False})
def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
@@ -142,56 +142,49 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
- @override_config({"enable_3pid_changes": False})
- def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
- """Test if change 3pid is disabled that the server responds it."""
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc4133_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
+ self,
+ ) -> None:
+ """Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
-
- @override_config({"experimental_features": {"msc3244_enabled": False}})
- def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None:
- access_token = self.get_success(
- self.auth_handler.create_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertNotIn(
- "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
+ self.assertFalse(capabilities["m.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
+ self.assertEqual(
+ capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
+ ["displayname"],
)
- def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
- access_token = self.get_success(
- self.auth_handler.create_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc4133_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
- self.assertEqual(channel.code, 200)
- for details in capabilities["m.room_versions"][
- "org.matrix.msc3244.room_capabilities"
- ].values():
- if details["preferred"] is not None:
- self.assertTrue(
- details["preferred"] in KNOWN_ROOM_VERSIONS,
- str(details["preferred"]),
- )
-
- self.assertGreater(len(details["support"]), 0)
- for room_version in details["support"]:
- self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
+ self.assertEqual(
+ capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
+ ["avatar_url"],
+ )
def test_get_get_token_login_fields_when_disabled(self) -> None:
"""By default login via an existing session is disabled."""
diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py
new file mode 100644
index 0000000000..9f9d241f12
--- /dev/null
+++ b/tests/rest/client/test_delayed_events.py
@@ -0,0 +1,610 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+"""Tests REST events for /delayed_events paths."""
+
+from http import HTTPStatus
+from typing import List
+
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import delayed_events, login, room, versions
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import HomeserverTestCase
+
+PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
+
+_EVENT_TYPE = "com.example.test"
+
+
+class DelayedEventsUnstableSupportTestCase(HomeserverTestCase):
+ servlets = [versions.register_servlets]
+
+ def test_false_by_default(self) -> None:
+ channel = self.make_request("GET", "/_matrix/client/versions")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc4140"])
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_true_if_enabled(self) -> None:
+ channel = self.make_request("GET", "/_matrix/client/versions")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc4140"])
+
+
+class DelayedEventsTestCase(HomeserverTestCase):
+ """Tests getting and managing delayed events."""
+
+ servlets = [
+ admin.register_servlets,
+ delayed_events.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["max_event_delay_duration"] = "24h"
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user1_user_id = self.register_user("user1", "pass")
+ self.user1_access_token = self.login("user1", "pass")
+ self.user2_user_id = self.register_user("user2", "pass")
+ self.user2_access_token = self.login("user2", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.user1_user_id,
+ tok=self.user1_access_token,
+ extra_content={
+ "preset": "public_chat",
+ "power_level_content_override": {
+ "events": {
+ _EVENT_TYPE: 0,
+ }
+ },
+ },
+ )
+
+ self.helper.join(
+ room=self.room_id, user=self.user2_user_id, tok=self.user2_access_token
+ )
+
+ def test_delayed_events_empty_on_startup(self) -> None:
+ self.assertListEqual([], self._get_delayed_events())
+
+ def test_delayed_state_events_are_sent_on_timeout(self) -> None:
+ state_key = "to_send_on_timeout"
+
+ setter_key = "setter"
+ setter_expected = "on_timeout"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ self.reactor.advance(1)
+ self.assertListEqual([], self._get_delayed_events())
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ @unittest.override_config(
+ {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
+ )
+ def test_get_delayed_events_ratelimit(self) -> None:
+ args = ("GET", PATH_PREFIX, b"", self.user1_access_token)
+
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_update_delayed_event_without_id(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/",
+ access_token=self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
+
+ def test_update_delayed_event_without_body(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ access_token=self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.NOT_JSON,
+ channel.json_body["errcode"],
+ )
+
+ def test_update_delayed_event_without_action(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.MISSING_PARAM,
+ channel.json_body["errcode"],
+ )
+
+ def test_update_delayed_event_with_invalid_action(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ {"action": "oops"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.INVALID_PARAM,
+ channel.json_body["errcode"],
+ )
+
+ @parameterized.expand(["cancel", "restart", "send"])
+ def test_update_delayed_event_without_match(self, action: str) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ {"action": action},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
+
+ def test_cancel_delayed_state_event(self) -> None:
+ state_key = "to_never_send"
+
+ setter_key = "setter"
+ setter_expected = "none"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_id}",
+ {"action": "cancel"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertListEqual([], self._get_delayed_events())
+
+ self.reactor.advance(1)
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ @unittest.override_config(
+ {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
+ )
+ def test_cancel_delayed_event_ratelimit(self) -> None:
+ delay_ids = []
+ for _ in range(2):
+ channel = self.make_request(
+ "POST",
+ _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+ delay_ids.append(delay_id)
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "cancel"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ args = (
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "cancel"},
+ self.user1_access_token,
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_send_delayed_state_event(self) -> None:
+ state_key = "to_send_on_request"
+
+ setter_key = "setter"
+ setter_expected = "on_send"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 100000),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_id}",
+ {"action": "send"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertListEqual([], self._get_delayed_events())
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ @unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
+ def test_send_delayed_event_ratelimit(self) -> None:
+ delay_ids = []
+ for _ in range(2):
+ channel = self.make_request(
+ "POST",
+ _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+ delay_ids.append(delay_id)
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "send"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ args = (
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "send"},
+ self.user1_access_token,
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_restart_delayed_state_event(self) -> None:
+ state_key = "to_send_on_restarted_timeout"
+
+ setter_key = "setter"
+ setter_expected = "on_timeout"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_id}",
+ {"action": "restart"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ self.reactor.advance(1)
+ self.assertListEqual([], self._get_delayed_events())
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ @unittest.override_config(
+ {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
+ )
+ def test_restart_delayed_event_ratelimit(self) -> None:
+ delay_ids = []
+ for _ in range(2):
+ channel = self.make_request(
+ "POST",
+ _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+ delay_ids.append(delay_id)
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "restart"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ args = (
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "restart"},
+ self.user1_access_token,
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
+ self,
+ ) -> None:
+ state_key = "to_not_be_cancelled_by_same_user"
+
+ setter_key = "setter"
+ setter_expected = "on_timeout"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+
+ self.helper.send_state(
+ self.room_id,
+ _EVENT_TYPE,
+ {
+ setter_key: "manual",
+ },
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+
+ self.reactor.advance(1)
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ def test_delayed_state_is_cancelled_by_new_state_from_other_user(
+ self,
+ ) -> None:
+ state_key = "to_be_cancelled_by_other_user"
+
+ setter_key = "setter"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
+ {
+ setter_key: "on_timeout",
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+
+ setter_expected = "other_user"
+ self.helper.send_state(
+ self.room_id,
+ _EVENT_TYPE,
+ {
+ setter_key: setter_expected,
+ },
+ self.user2_access_token,
+ state_key=state_key,
+ )
+ self.assertListEqual([], self._get_delayed_events())
+
+ self.reactor.advance(1)
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ def _get_delayed_events(self) -> List[JsonDict]:
+ channel = self.make_request(
+ "GET",
+ PATH_PREFIX,
+ access_token=self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ key = "delayed_events"
+ self.assertIn(key, channel.json_body)
+
+ events = channel.json_body[key]
+ self.assertIsInstance(events, list)
+
+ return events
+
+ def _get_delayed_event_content(self, event: JsonDict) -> JsonDict:
+ key = "content"
+ self.assertIn(key, event)
+
+ content = event[key]
+ self.assertIsInstance(content, dict)
+
+ return content
+
+
+def _get_path_for_delayed_state(
+ room_id: str, event_type: str, state_key: str, delay_ms: int
+) -> str:
+ return f"rooms/{room_id}/state/{event_type}/{state_key}?org.matrix.msc4140.delay={delay_ms}"
+
+
+def _get_path_for_delayed_send(room_id: str, event_type: str, delay_ms: int) -> str:
+ return f"rooms/{room_id}/send/{event_type}?org.matrix.msc4140.delay={delay_ms}"
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index a3ed12a38f..dd3abdebac 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,6 +24,7 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
+from synapse.appservice import ApplicationService
from synapse.rest import admin, devices, sync
from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
@@ -455,3 +456,183 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
token,
)
self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
+
+
+class MSC4190AppserviceDevicesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ register.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+
+ # This application service uses the new MSC4190 behaviours
+ self.msc4190_service = ApplicationService(
+ id="msc4190",
+ token="some_token",
+ hs_token="some_token",
+ sender="@as:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
+ },
+ msc4190_device_management=True,
+ )
+ # This application service doesn't use the new MSC4190 behaviours
+ self.pre_msc_service = ApplicationService(
+ id="regular",
+ token="other_token",
+ hs_token="other_token",
+ sender="@as2:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
+ },
+ msc4190_device_management=False,
+ )
+ self.hs.get_datastores().main.services_cache.append(self.msc4190_service)
+ self.hs.get_datastores().main.services_cache.append(self.pre_msc_service)
+ return self.hs
+
+ def test_PUT_device(self) -> None:
+ self.register_appservice_user("alice", self.msc4190_service.token)
+ self.register_appservice_user("bob", self.pre_msc_service.token)
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={"display_name": "Alice's device"},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 201, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+ self.assertEqual(channel.json_body["devices"][0]["device_id"], "AABBCCDD")
+
+        # Doing it a second time should return a 200 instead of a 201
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={"display_name": "Alice's device"},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # On the regular service, that API should not allow for the
+ # creation of new devices.
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@bob:test",
+ content={"display_name": "Bob's device"},
+ access_token=self.pre_msc_service.token,
+ )
+ self.assertEqual(channel.code, 404, channel.json_body)
+
+ def test_DELETE_device(self) -> None:
+ self.register_appservice_user("alice", self.msc4190_service.token)
+
+ # There should be no device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ # Create a device
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 201, channel.json_body)
+
+ # There should be one device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ # Delete the device. UIA should not be required.
+ channel = self.make_request(
+ "DELETE",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # There should be no device again
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ def test_POST_delete_devices(self) -> None:
+ self.register_appservice_user("alice", self.msc4190_service.token)
+
+ # There should be no device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ # Create a device
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 201, channel.json_body)
+
+ # There should be one device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ # Delete the device with delete_devices
+ # UIA should not be required.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/delete_devices?user_id=@alice:test",
+ content={"devices": ["AABBCCDD"]},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # There should be no device again
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 06f1c1b234..039144fdbe 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -19,7 +19,7 @@
#
#
-""" Tests REST events for /events paths."""
+"""Tests REST events for /events paths."""
from unittest.mock import Mock
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
deleted file mode 100644
index 63c2c5923e..0000000000
--- a/tests/rest/client/test_identity.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from http import HTTPStatus
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.rest.client import login, room
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests import unittest
-
-
-class IdentityTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
- config["enable_3pid_lookup"] = False
- self.hs = self.setup_test_homeserver(config=config)
-
- return self.hs
-
- def test_3pid_lookup_disabled(self) -> None:
- self.hs.config.registration.enable_3pid_lookup = False
-
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- room_id = channel.json_body["room_id"]
-
- request_data = {
- "id_server": "testis",
- "medium": "email",
- "address": "test@example.com",
- "id_access_token": tok,
- }
- request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
- )
- self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 8bbd109092..d9a210b616 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -315,9 +315,7 @@ class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase):
"master_key": master_key2,
},
)
- self.assertEqual(
- channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body
- )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
# Pretend that MAS did UIA and allowed us to replace the master key.
channel = self.make_request(
@@ -349,9 +347,7 @@ class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase):
"master_key": master_key3,
},
)
- self.assertEqual(
- channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body
- )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
# Pretend that MAS did UIA and allowed us to replace the master key.
channel = self.make_request(
@@ -376,6 +372,4 @@ class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase):
"master_key": master_key3,
},
)
- self.assertEqual(
- channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body
- )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 2b1e44381b..24e2288ee3 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -27,6 +27,7 @@ from typing import (
Collection,
Dict,
List,
+ Literal,
Optional,
Tuple,
Union,
@@ -35,7 +36,6 @@ from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
-from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -43,6 +43,7 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
+from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.appservice import ApplicationService
from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi
@@ -55,7 +56,6 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
@@ -69,6 +69,10 @@ try:
except ImportError:
HAS_JWT = False
+import logging
+
+logger = logging.getLogger(__name__)
+
# synapse server name: used to populate public_baseurl in some tests
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
@@ -77,22 +81,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
# https://....
-BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
-
-# CAS server used in some tests
-CAS_SERVER = "https://fake.test"
-
-# just enough to tell pysaml2 where to redirect to
-SAML_SERVER = "https://test.saml.server/idp/sso"
-TEST_SAML_METADATA = """
-<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
- <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
- <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
- </md:IDPSSODescriptor>
-</md:EntityDescriptor>
-""" % {
- "SAML_SERVER": SAML_SERVER,
-}
+PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
@@ -109,6 +98,23 @@ ADDITIONAL_LOGIN_FLOWS = [
]
+def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str:
+ """
+ Peels off the path and query string from an absolute URI. Useful when interacting
+ with `make_request(...)` util function which expects a relative path instead of a
+ full URI.
+ """
+ parsed_uri = urllib.parse.urlparse(absolute_uri)
+ # Sanity check that we're working with an absolute URI
+ assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https"
+
+ relative_uri = parsed_uri.path
+ if parsed_uri.query:
+ relative_uri += "?" + parsed_uri.query
+
+ return relative_uri
+
+
class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
@@ -172,7 +178,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True
- self.hs.config.registration.registrations_require_3pid = []
self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
@@ -603,7 +608,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
)
-@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+@skip_unless(HAS_OIDC, "Requires OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
"""Tests for homeservers with multiple SSO providers enabled"""
@@ -614,21 +619,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
- config["public_baseurl"] = BASE_URL
-
- config["cas_config"] = {
- "enabled": True,
- "server_url": CAS_SERVER,
- "service_url": "https://matrix.goodserver.com:8448",
- }
-
- config["saml2_config"] = {
- "sp_config": {
- "metadata": {"inline": [TEST_SAML_METADATA]},
- # use the XMLSecurity backend to avoid relying on xmlsec1
- "crypto_backend": "XMLSecurity",
- },
- }
+ config["public_baseurl"] = PUBLIC_BASEURL
# default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG
@@ -653,6 +644,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
]
return config
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
+
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d.update(build_synapse_client_resource_tree(self.hs))
@@ -664,7 +658,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [
- "m.login.cas",
"m.login.sso",
"m.login.token",
"m.login.password",
@@ -678,8 +671,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(
flows["m.login.sso"]["identity_providers"],
[
- {"id": "cas", "name": "CAS"},
- {"id": "saml", "name": "SAML"},
{"id": "oidc-idp1", "name": "IDP1"},
{"id": "oidc", "name": "OIDC"},
],
@@ -713,56 +704,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
returned_idps.append(params["idp"][0])
- self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
-
- def test_multi_sso_redirect_to_cas(self) -> None:
- """If CAS is chosen, should redirect to the CAS server"""
-
- channel = self.make_request(
- "GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=cas",
- shorthand=False,
- )
- self.assertEqual(channel.code, 302, channel.result)
- location_headers = channel.headers.getRawHeaders("Location")
- assert location_headers
- cas_uri = location_headers[0]
- cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
-
- # it should redirect us to the login page of the cas server
- self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
-
- # check that the redirectUrl is correctly encoded in the service param - ie, the
- # place that CAS will redirect to
- cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
- service_uri = cas_uri_params["service"][0]
- _, service_uri_query = service_uri.split("?", 1)
- service_uri_params = urllib.parse.parse_qs(service_uri_query)
- self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
-
- def test_multi_sso_redirect_to_saml(self) -> None:
- """If SAML is chosen, should redirect to the SAML server"""
- channel = self.make_request(
- "GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=saml",
- )
- self.assertEqual(channel.code, 302, channel.result)
- location_headers = channel.headers.getRawHeaders("Location")
- assert location_headers
- saml_uri = location_headers[0]
- saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
-
- # it should redirect us to the login page of the SAML server
- self.assertEqual(saml_uri_path, SAML_SERVER)
-
- # the RelayState is used to carry the client redirect url
- saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
- relay_state_param = saml_uri_params["RelayState"][0]
- self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
+ self.assertCountEqual(returned_idps, ["oidc", "oidc-idp1"])
def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
@@ -773,13 +715,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# pick the default OIDC provider
channel = self.make_request(
"GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=oidc",
+ f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
+ sso_login_redirect_uri = location_headers[0]
+
+ # it should redirect us to the standard login SSO redirect flow
+ self.assertEqual(
+ sso_login_redirect_uri,
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL
+ ),
+ )
+
+ with fake_oidc_server.patch_homeserver(hs=self.hs):
+ # follow the redirect
+ channel = self.make_request(
+ "GET",
+ # We have to make this relative to be compatible with `make_request(...)`
+ get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
+ # We have to set the Host header to match the `public_baseurl` to avoid
+ # the extra redirect in the `SsoRedirectServlet` in order for the
+ # cookies to be visible.
+ custom_headers=[
+ ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
+ ],
+ )
+
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@@ -838,12 +805,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
- """An unknown IdP should cause a 400"""
+ """An unknown IdP should cause a 404"""
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ sso_login_redirect_uri = location_headers[0]
+
+ # it should redirect us to the standard login SSO redirect flow
+ self.assertEqual(
+ sso_login_redirect_uri,
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id="xyz", client_redirect_url="http://x"
+ ),
+ )
+
+ # follow the redirect
+ channel = self.make_request(
+ "GET",
+ # We have to make this relative to be compatible with `make_request(...)`
+ get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
+ # We have to set the Host header to match the `public_baseurl` to avoid
+ # the extra redirect in the `SsoRedirectServlet` in order for the
+ # cookies to be visible.
+ custom_headers=[
+ ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
+ ],
+ )
+
+ self.assertEqual(channel.code, 404, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
@@ -891,162 +884,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
raise ValueError("No %s caveat in macaroon" % (key,))
-class CASTestCase(unittest.HomeserverTestCase):
- servlets = [
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.base_url = "https://matrix.goodserver.com/"
- self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
-
- config = self.default_config()
- config["public_baseurl"] = (
- config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
- )
- config["cas_config"] = {
- "enabled": True,
- "server_url": CAS_SERVER,
- }
-
- cas_user_id = "username"
- self.user_id = "@%s:test" % cas_user_id
-
- async def get_raw(uri: str, args: Any) -> bytes:
- """Return an example response payload from a call to the `/proxyValidate`
- endpoint of a CAS server, copied from
- https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
-
- This needs to be returned by an async function (as opposed to set as the
- mock's return value) because the corresponding Synapse code awaits on it.
- """
- return (
- """
- <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
- <cas:authenticationSuccess>
- <cas:user>%s</cas:user>
- <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
- <cas:proxies>
- <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
- <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
- </cas:proxies>
- </cas:authenticationSuccess>
- </cas:serviceResponse>
- """
- % cas_user_id
- ).encode("utf-8")
-
- mocked_http_client = Mock(spec=["get_raw"])
- mocked_http_client.get_raw.side_effect = get_raw
-
- self.hs = self.setup_test_homeserver(
- config=config,
- proxied_http_client=mocked_http_client,
- )
-
- return self.hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.deactivate_account_handler = hs.get_deactivate_account_handler()
-
- def test_cas_redirect_confirm(self) -> None:
- """Tests that the SSO login flow serves a confirmation page before redirecting a
- user to the redirect URL.
- """
- base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
- redirect_url = "https://dodgy-site.com/"
-
- url_parts = list(urllib.parse.urlparse(base_url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"redirectUrl": redirect_url})
- query.update({"ticket": "ticket"})
- url_parts[4] = urllib.parse.urlencode(query)
- cas_ticket_url = urllib.parse.urlunparse(url_parts)
-
- # Get Synapse to call the fake CAS and serve the template.
- channel = self.make_request("GET", cas_ticket_url)
-
- # Test that the response is HTML.
- self.assertEqual(channel.code, 200, channel.result)
- content_type_header_value = ""
- for header in channel.headers.getRawHeaders("Content-Type", []):
- content_type_header_value = header
-
- self.assertTrue(content_type_header_value.startswith("text/html"))
-
- # Test that the body isn't empty.
- self.assertTrue(len(channel.result["body"]) > 0)
-
- # And that it contains our redirect link
- self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
-
- @override_config(
- {
- "sso": {
- "client_whitelist": [
- "https://legit-site.com/",
- "https://other-site.com/",
- ]
- }
- }
- )
- def test_cas_redirect_whitelisted(self) -> None:
- """Tests that the SSO login flow serves a redirect to a whitelisted url"""
- self._test_redirect("https://legit-site.com/")
-
- @override_config({"public_baseurl": "https://example.com"})
- def test_cas_redirect_login_fallback(self) -> None:
- self._test_redirect("https://example.com/_matrix/static/client/login")
-
- def _test_redirect(self, redirect_url: str) -> None:
- """Tests that the SSO login flow serves a redirect for the given redirect URL."""
- cas_ticket_url = (
- "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
- % (urllib.parse.quote(redirect_url))
- )
-
- # Get Synapse to call the fake CAS and serve the template.
- channel = self.make_request("GET", cas_ticket_url)
-
- self.assertEqual(channel.code, 302)
- location_headers = channel.headers.getRawHeaders("Location")
- assert location_headers
- self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
-
- @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
- def test_deactivated_user(self) -> None:
- """Logging in as a deactivated account should error."""
- redirect_url = "https://legit-site.com/"
-
- # First login (to create the user).
- self._test_redirect(redirect_url)
-
- # Deactivate the account.
- self.get_success(
- self.deactivate_account_handler.deactivate_account(
- self.user_id, False, create_requester(self.user_id)
- )
- )
-
- # Request the CAS ticket.
- cas_ticket_url = (
- "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
- % (urllib.parse.quote(redirect_url))
- )
-
- # Get Synapse to call the fake CAS and serve the template.
- channel = self.make_request("GET", cas_ticket_url)
-
- # Because the user is deactivated they are served an error template.
- self.assertEqual(channel.code, 403)
- self.assertIn(b"SSO account deactivated", channel.result["body"])
-
-
@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ profile.register_servlets,
]
jwt_secret = "secret"
@@ -1133,18 +976,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: invalid_claim: Invalid claim "iss"',
+ r"^JWT validation failed: invalid_claim: Invalid claim [\"']iss[\"']$",
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: missing_claim: Missing "iss" claim',
+ r"^JWT validation failed: missing_claim: Missing [\"']iss[\"'] claim$",
)
def test_login_iss_no_config(self) -> None:
@@ -1165,18 +1008,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: invalid_claim: Invalid claim "aud"',
+ r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$",
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: missing_claim: Missing "aud" claim',
+ r"^JWT validation failed: missing_claim: Missing [\"']aud[\"'] claim$",
)
def test_login_aud_no_config(self) -> None:
@@ -1184,9 +1027,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: invalid_claim: Invalid claim "aud"',
+ r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$",
)
def test_login_default_sub(self) -> None:
@@ -1202,6 +1045,30 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
+ @override_config(
+ {"jwt_config": {**base_config, "display_name_claim": "display_name"}}
+ )
+ def test_login_custom_display_name(self) -> None:
+ """Test setting a custom display name."""
+ localpart = "pinkie"
+ user_id = f"@{localpart}:test"
+ display_name = "Pinkie Pie"
+
+ # Perform the login, specifying a custom display name.
+ channel = self.jwt_login({"sub": localpart, "display_name": display_name})
+ self.assertEqual(channel.code, 200, msg=channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ # Fetch the user's display name and check that it was set correctly.
+ access_token = channel.json_body["access_token"]
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{user_id}/displayname",
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, msg=channel.result)
+ self.assertEqual(channel.json_body["displayname"], display_name)
+
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1448,7 +1315,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
- config["public_baseurl"] = BASE_URL
+ config["public_baseurl"] = PUBLIC_BASEURL
config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG)
@@ -1474,7 +1341,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
self,
fake_oidc_server: FakeOidcServer,
displayname: str,
- email: str,
picture: str,
) -> Tuple[str, str]:
# do the start of the login flow
@@ -1483,8 +1349,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
{
"sub": "tester",
"displayname": displayname,
- "picture": picture,
- "email": email,
+ "picture": picture
},
TEST_CLIENT_REDIRECT_URL,
)
@@ -1513,7 +1378,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
self.assertEqual(session.display_name, displayname)
- self.assertEqual(session.emails, [email])
self.assertEqual(session.avatar_url, picture)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
@@ -1530,11 +1394,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
mxid = "@bobby:test"
displayname = "Jonny"
- email = "bobby@test.com"
picture = "mxc://test/avatar_url"
picker_url, session_id = self.proceed_to_username_picker_page(
- fake_oidc_server, displayname, email, picture
+ fake_oidc_server, displayname, picture
)
# Now, submit a username to the username picker, which should serve a redirect
@@ -1544,8 +1407,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
{
b"username": b"bobby",
b"use_display_name": b"true",
- b"use_avatar": b"true",
- b"use_email": email,
+ b"use_avatar": b"true"
}
).encode("utf8")
chan = self.make_request(
@@ -1606,12 +1468,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertIn("mxc://test", channel.json_body["avatar_url"])
self.assertEqual(displayname, channel.json_body["displayname"])
- # ensure the email from the OIDC response has been configured for the user.
- channel = self.make_request(
- "GET", "/account/3pid", access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertEqual(email, channel.json_body["threepids"][0]["address"])
def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
"""Test the happy path of a username picker flow without using displayname, avatar or email."""
@@ -1620,12 +1476,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
mxid = "@bobby:test"
displayname = "Jonny"
- email = "bobby@test.com"
picture = "mxc://test/avatar_url"
username = "bobby"
picker_url, session_id = self.proceed_to_username_picker_page(
- fake_oidc_server, displayname, email, picture
+ fake_oidc_server, displayname, picture
)
# Now, submit a username to the username picker, which should serve a redirect
@@ -1696,13 +1551,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertNotIn("avatar_url", channel.json_body)
self.assertEqual(username, channel.json_body["displayname"])
- # ensure the email from the OIDC response has not been configured for the user.
- channel = self.make_request(
- "GET", "/account/3pid", access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertListEqual([], channel.json_body["threepids"])
-
async def mock_get_file(
url: str,
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index fbacf9d869..99a0fd4fcd 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -43,7 +43,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True
- self.hs.config.registration.registrations_require_3pid = []
self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 30b6d31d0a..6ee761e44b 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -24,14 +24,13 @@ import json
import os
import re
import shutil
-from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
from urllib.parse import quote, urlencode
from parameterized import parameterized, parameterized_class
from PIL import Image as Image
-from typing_extensions import ClassVar
from twisted.internet import defer
from twisted.internet._resolver import HostResolution
@@ -66,6 +65,7 @@ from tests.media.test_media_storage import (
SVG,
TestImage,
empty_file,
+ small_cmyk_jpeg,
small_lossless_webp,
small_png,
small_png_with_transparency,
@@ -137,6 +137,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
self.register_user("user", "password")
@@ -1005,7 +1006,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
data = base64.b64encode(SMALL_PNG)
end_content = (
- b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+ b'<html><head><img src="data:image/png;base64,%s" /></head></html>'
) % (data,)
channel = self.make_request(
@@ -1617,6 +1618,63 @@ class MediaConfigTest(unittest.HomeserverTestCase):
)
+class MediaConfigModuleCallbackTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ media.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
+ config = self.default_config()
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+
+ hs.get_module_api().register_media_repository_callbacks(
+ get_media_config_for_user=self.get_media_config_for_user,
+ )
+
+ async def get_media_config_for_user(
+ self,
+ user_id: str,
+ ) -> Optional[JsonDict]:
+ # We echo back the user_id and set a custom upload size.
+ return {"m.upload.size": 1024, "user_id": user_id}
+
+ def test_media_config(self) -> None:
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/config",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["m.upload.size"], 1024)
+ self.assertEqual(channel.json_body["user_id"], self.user)
+
+
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
servlets = [
media.register_servlets,
@@ -1916,6 +1974,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
test_images = [
small_png,
small_png_with_transparency,
+ small_cmyk_jpeg,
small_lossless_webp,
empty_file,
SVG,
@@ -1957,7 +2016,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
"""A mock for MatrixFederationHttpClient.federation_get_file."""
def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]],
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
data, response = r
output_stream.write(data)
@@ -1991,7 +2050,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
"""A mock for MatrixFederationHttpClient.get_file."""
def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]],
) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
@@ -2400,7 +2459,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
if expected_body is not None:
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.result["body"], expected_body, channel.result["body"].hex()
)
else:
# ensure that the result is at least some valid image
@@ -2592,6 +2651,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name="remote_test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
@@ -2675,3 +2735,114 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
access_token=self.tok,
)
self.assertEqual(channel10.code, 200)
+
+ def test_authenticated_media_etag(self) -> None:
+ """Test that ETag works correctly with authenticated media over client
+ APIs"""
+
+ # upload some local media with authentication on
+ channel = self.make_request(
+ "POST",
+ "_matrix/media/v3/upload?filename=test_png_upload",
+ SMALL_PNG,
+ self.tok,
+ shorthand=False,
+ content_type=b"image/png",
+ custom_headers=[("Content-Length", str(67))],
+ )
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body.get("content_uri")
+ assert res is not None
+ uri = res.split("mxc://")[1]
+
+ # Check standard media endpoint
+ self._check_caching(f"/download/{uri}")
+
+ # check thumbnails as well
+ params = "?width=32&height=32&method=crop"
+ self._check_caching(f"/thumbnail/{uri}{params}")
+
+ # Inject a piece of remote media.
+ file_id = "abcdefg12345"
+ file_info = FileInfo(server_name="lonelyIsland", file_id=file_id)
+
+ media_storage = self.hs.get_media_repository().media_storage
+
+ ctx = media_storage.store_into_file(file_info)
+ (f, fname) = self.get_success(ctx.__aenter__())
+ f.write(SMALL_PNG)
+ self.get_success(ctx.__aexit__(None, None, None))
+
+ # we write the authenticated status when storing media, so this should pick up
+ # config and authenticate the media
+ self.get_success(
+ self.store.store_cached_remote_media(
+ origin="lonelyIsland",
+ media_id="52",
+ media_type="image/png",
+ media_length=1,
+ time_now_ms=self.clock.time_msec(),
+ upload_name="remote_test.png",
+ filesystem_id=file_id,
+ sha256=file_id,
+ )
+ )
+
+ # ensure we have thumbnails for the non-dynamic code path
+ if self.extra_config == {"dynamic_thumbnails": False}:
+ self.get_success(
+ self.repo._generate_thumbnails(
+ "lonelyIsland", "52", file_id, "image/png"
+ )
+ )
+
+ self._check_caching("/download/lonelyIsland/52")
+
+ params = "?width=32&height=32&method=crop"
+ self._check_caching(f"/thumbnail/lonelyIsland/52{params}")
+
+ def _check_caching(self, path: str) -> None:
+ """
+ Checks that:
+ 1. fetching the path returns an ETag header
+ 2. refetching with the ETag returns a 304 without a body
+        3. refetching with the ETag through the unauthenticated endpoint
+           returns 404 (the media requires authentication)
+ """
+
+ # Request media over authenticated endpoint, should be found
+ channel1 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media{path}",
+ access_token=self.tok,
+ shorthand=False,
+ )
+ self.assertEqual(channel1.code, 200)
+
+ # Should have a single ETag field
+ etags = channel1.headers.getRawHeaders("ETag")
+ self.assertIsNotNone(etags)
+ assert etags is not None # For mypy
+ self.assertEqual(len(etags), 1)
+ etag = etags[0]
+
+ # Refetching with the etag should result in 304 and empty body.
+ channel2 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media{path}",
+ access_token=self.tok,
+ shorthand=False,
+ custom_headers=[("If-None-Match", etag)],
+ )
+ self.assertEqual(channel2.code, 304)
+ self.assertEqual(channel2.is_finished(), True)
+ self.assertNotIn("body", channel2.result)
+
+ # Refetching with the etag but no access token should result in 404.
+ channel3 = self.make_request(
+ "GET",
+ f"/_matrix/media/r0{path}",
+ shorthand=False,
+ custom_headers=[("If-None-Match", etag)],
+ )
+ self.assertEqual(channel3.code, 404)
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
deleted file mode 100644
index f8a56c80ca..0000000000
--- a/tests/rest/client/test_models.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import unittest as stdlib_unittest
-from typing import TYPE_CHECKING
-
-from typing_extensions import Literal
-
-from synapse._pydantic_compat import HAS_PYDANTIC_V2
-from synapse.types.rest.client import EmailRequestTokenBody
-
-if TYPE_CHECKING or HAS_PYDANTIC_V2:
- from pydantic.v1 import BaseModel, ValidationError
-else:
- from pydantic import BaseModel, ValidationError
-
-
-class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
- class Model(BaseModel):
- medium: Literal["email", "msisdn"]
-
- def test_accepts_valid_medium_string(self) -> None:
- """Sanity check that Pydantic behaves sensibly with an enum-of-str
-
- This is arguably more of a test of a class that inherits from str and Enum
- simultaneously.
- """
- model = self.Model.parse_obj({"medium": "email"})
- self.assertEqual(model.medium, "email")
-
- def test_rejects_invalid_medium_value(self) -> None:
- with self.assertRaises(ValidationError):
- self.Model.parse_obj({"medium": "interpretive_dance"})
-
- def test_rejects_invalid_medium_type(self) -> None:
- with self.assertRaises(ValidationError):
- self.Model.parse_obj({"medium": 123})
-
-
-class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
- base_request = {
- "client_secret": "hunter2",
- "email": "alice@wonderland.com",
- "send_attempt": 1,
- }
-
- def test_token_required_if_id_server_provided(self) -> None:
- with self.assertRaises(ValidationError):
- EmailRequestTokenBody.parse_obj(
- {
- **self.base_request,
- "id_server": "identity.wonderland.com",
- }
- )
- with self.assertRaises(ValidationError):
- EmailRequestTokenBody.parse_obj(
- {
- **self.base_request,
- "id_server": "identity.wonderland.com",
- "id_access_token": None,
- }
- )
-
- def test_token_typechecked_when_id_server_provided(self) -> None:
- with self.assertRaises(ValidationError):
- EmailRequestTokenBody.parse_obj(
- {
- **self.base_request,
- "id_server": "identity.wonderland.com",
- "id_access_token": 1337,
- }
- )
diff --git a/tests/rest/client/test_owned_state.py b/tests/rest/client/test_owned_state.py
new file mode 100644
index 0000000000..5fb5767676
--- /dev/null
+++ b/tests/rest/client/test_owned_state.py
@@ -0,0 +1,308 @@
+from http import HTTPStatus
+
+from parameterized import parameterized_class
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import Codes
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+_STATE_EVENT_TEST_TYPE = "com.example.test"
+
+# To stress-test parsing, include separator & sigil characters
+_STATE_KEY_SUFFIX = "_state_key_suffix:!@#$123"
+
+
+class OwnedStateBase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.creator_user_id = self.register_user("creator", "pass")
+ self.creator_access_token = self.login("creator", "pass")
+ self.user1_user_id = self.register_user("user1", "pass")
+ self.user1_access_token = self.login("user1", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.creator_user_id,
+ tok=self.creator_access_token,
+ is_public=True,
+ extra_content={
+ "power_level_content_override": {
+ "events": {
+ _STATE_EVENT_TEST_TYPE: 0,
+ },
+ },
+ },
+ )
+
+ self.helper.join(
+ room=self.room_id, user=self.user1_user_id, tok=self.user1_access_token
+ )
+
+
+class WithoutOwnedStateTestCase(OwnedStateBase):
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["default_room_version"] = RoomVersions.V10.identifier
+ return config
+
+ def test_user_can_set_state_with_own_userid_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_cannot_set_state_with_own_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.creator_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_other_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_other_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_nonmember_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@notinroom:hs2",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_malformed_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@oops",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+
+@parameterized_class(
+ ("room_version",),
+ [(i,) for i, v in KNOWN_ROOM_VERSIONS.items() if v.msc3757_enabled],
+)
+class MSC3757OwnedStateTestCase(OwnedStateBase):
+ room_version: str
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["default_room_version"] = self.room_version
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
+
+ self.user2_user_id = self.register_user("user2", "pass")
+ self.user2_access_token = self.login("user2", "pass")
+
+ self.helper.join(
+ room=self.room_id, user=self.user2_user_id, tok=self.user2_access_token
+ )
+
+ def test_user_can_set_state_with_own_suffixed_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_can_set_state_with_other_userid_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_can_set_state_with_other_suffixed_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_user_cannot_set_state_with_other_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user2_user_id}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_user_cannot_set_state_with_other_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user2_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_user_cannot_set_state_with_unseparated_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX[1:]}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_user_cannot_set_state_with_misplaced_userid_in_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ # Still put @ at start of state key, because without it, there is no write protection at all
+ state_key=f"@prefix_{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_can_set_state_with_nonmember_userid_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@notinroom:hs2",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_cannot_set_state_with_malformed_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@oops",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.BAD_REQUEST,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_improperly_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.creator_user_id}@{_STATE_KEY_SUFFIX[1:]}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.BAD_REQUEST,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 5ced8319e1..6b9c70974a 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -29,6 +29,7 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.unittest import override_config
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -95,3 +96,54 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.presence_handler.set_state.call_count, 0)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_put_presence_over_ratelimit(self) -> None:
+ """
+ Multiple PUTs to the status endpoint without sufficient delay will be rate limited.
+ """
+ self.hs.config.server.presence_enabled = True
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS)
+ self.assertEqual(self.presence_handler.set_state.call_count, 1)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_put_presence_within_ratelimit(self) -> None:
+ """
+ Multiple PUTs to the status endpoint with sufficient delay should all call set_state.
+ """
+ self.hs.config.server.presence_enabled = True
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+
+ # Advance time a sufficient amount to avoid rate limiting.
+ self.reactor.advance(30)
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(self.presence_handler.set_state.call_count, 2)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index f98f3f77aa..708402b792 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -20,20 +20,25 @@
#
"""Tests REST events for /profile paths."""
+
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional
+from canonicaljson import encode_canonical_json
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -479,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The client requested ?propagate=true, so it should have happened.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field_invalid_field_name(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_custom_field_rejects_bad_username(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "test"})
+
+ # Overwriting the field should work.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "new_Value"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
+
+ # Deleting the field should work.
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_non_string(self) -> None:
+ """Non-string fields are supported for custom fields."""
+ fields = {
+ "bool_field": True,
+ "array_field": ["test"],
+ "object_field": {"test": "test"},
+ "numeric_field": 1,
+ "null_field": None,
+ }
+
+ for key, value in fields.items():
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: value},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
+
+ # Check getting individual fields works.
+ for key, value in fields.items():
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {key: value})
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_noauth(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_size(self) -> None:
+ """
+ Attempts to set a custom field with a missing, overlong, or mismatched key should get a 400 error.
+ """
+ # Key is missing.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
+ content={"": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Single key is too large.
+ key = "c" * 500
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ # Key doesn't match body.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"diff_key": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_profile_too_long(self) -> None:
+ """
+ Attempts to set a custom field that would push the overall profile too large.
+ """
+ # Get right to the boundary:
+ # len("displayname") + len("owner") + 5 = 21 for the displayname
+ # 1 + 65498 + 5 for key "a" = 65504
+ # 2 braces, 1 comma
+ # 3 + 21 + 65504 = 65528 < 65536.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "a" * 65498},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Get the entire profile.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ canonical_json = encode_canonical_json(channel.json_body)
+ # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
+ # Be one below that so we can prove we're at the boundary.
+ self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
+
+ # Postgres stores JSONB with whitespace, while SQLite doesn't.
+ if USE_POSTGRES_FOR_TESTS:
+ ADDITIONAL_CHARS = 0
+ else:
+ ADDITIONAL_CHARS = 1
+
+ # The next one should fail, note the value has a (JSON) length of 2.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "1" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Setting an avatar or (longer) display name should not work.
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/displayname",
+ content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://foo/bar"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Removing a single byte should work.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Finally, setting a field that already exists to a value that is <= in length should work.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: ""},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_displayname(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ displayname = self._get_displayname()
+ self.assertEqual(displayname, "test")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_avatar_url(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/good"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ avatar_url = self._get_avatar_url()
+ self.assertEqual(avatar_url, "mxc://test/good")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_other(self) -> None:
+ """Setting someone else's profile field should fail"""
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 694f143eff..d40efdfe1d 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -64,7 +64,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
hs = super().make_homeserver(reactor, clock)
- hs.get_send_email_handler()._sendmail = AsyncMock()
return hs
def test_POST_appservice_registration_valid(self) -> None:
@@ -120,6 +119,34 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 401, msg=channel.result)
+ def test_POST_appservice_msc4190_enabled(self) -> None:
+ # With MSC4190 enabled, the registration should *not* return an access token
+ user_id = "@as_user_kermit:test"
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token,
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ msc4190_device_management=True,
+ )
+
+ self.hs.get_datastores().main.services_cache.append(appservice)
+ request_data = {
+ "username": "as_user_kermit",
+ "type": APP_SERVICE_REGISTRATION_TYPE,
+ }
+
+ channel = self.make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+ )
+
+ self.assertEqual(channel.code, 200, msg=channel.result)
+ det_data = {"user_id": user_id, "home_server": self.hs.hostname}
+ self.assertLessEqual(det_data.items(), channel.json_body.items())
+ self.assertNotIn("access_token", channel.json_body)
+
def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
@@ -593,155 +620,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# with the stock config, we only expect the dummy flow
self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
-
- @unittest.override_config(
- {
- "public_baseurl": "https://test_server",
- "enable_registration_captcha": True,
- "user_consent": {
- "version": "1",
- "template_dir": "/",
- "require_at_registration": True,
- },
- "account_threepid_delegates": {
- "msisdn": "https://id_server",
- },
- "email": {"notif_from": "Synapse <synapse@example.com>"},
- }
- )
- def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
- channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.code, 401, msg=channel.result)
- flows = channel.json_body["flows"]
-
- self.assertCountEqual(
- [
- ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
- ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
- ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
- [
- "m.login.recaptcha",
- "m.login.terms",
- "m.login.msisdn",
- "m.login.email.identity",
- ],
- ],
- (f["stages"] for f in flows),
- )
-
- @unittest.override_config(
- {
- "public_baseurl": "https://test_server",
- "registrations_require_3pid": ["email"],
- "disable_msisdn_registration": True,
- "email": {
- "smtp_host": "mail_server",
- "smtp_port": 2525,
- "notif_from": "sender@host",
- },
- }
- )
- def test_advertised_flows_no_msisdn_email_required(self) -> None:
- channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.code, 401, msg=channel.result)
- flows = channel.json_body["flows"]
-
- # with the stock config, we expect all four combinations of 3pid
- self.assertCountEqual(
- [["m.login.email.identity"]], (f["stages"] for f in flows)
- )
-
- @unittest.override_config(
- {
- "request_token_inhibit_3pid_errors": True,
- "public_baseurl": "https://test_server",
- "email": {
- "smtp_host": "mail_server",
- "smtp_port": 2525,
- "notif_from": "sender@host",
- },
- }
- )
- def test_request_token_existing_email_inhibit_error(self) -> None:
- """Test that requesting a token via this endpoint doesn't leak existing
- associations if configured that way.
- """
- user_id = self.register_user("kermit", "monkey")
- self.login("kermit", "monkey")
-
- email = "test@example.com"
-
- # Add a threepid
- self.get_success(
- self.hs.get_datastores().main.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": email, "send_attempt": 1},
- )
- self.assertEqual(200, channel.code, channel.result)
-
- self.assertIsNotNone(channel.json_body.get("sid"))
-
- @unittest.override_config(
- {
- "public_baseurl": "https://test_server",
- "email": {
- "smtp_host": "mail_server",
- "smtp_port": 2525,
- "notif_from": "sender@host",
- },
- }
- )
- def test_reject_invalid_email(self) -> None:
- """Check that bad emails are rejected"""
-
- # Test for email with multiple @
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
- )
- self.assertEqual(400, channel.code, channel.result)
- # Check error to ensure that we're not erroring due to a bug in the test.
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
- )
-
- # Test for email with no @
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": "email", "send_attempt": 1},
- )
- self.assertEqual(400, channel.code, channel.result)
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
- )
-
- # Test for super long email
- email = "a@" + "a" * 1000
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": email, "send_attempt": 1},
- )
- self.assertEqual(400, channel.code, channel.result)
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
- )
-
+
@override_config(
{
"inhibit_user_in_use_error": True,
@@ -925,224 +804,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result)
-class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
- servlets = [
- register.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- sync.register_servlets,
- account_validity.register_servlets,
- account.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- # Test for account expiring after a week and renewal emails being sent 2
- # days before expiry.
- config["enable_registration"] = True
- config["account_validity"] = {
- "enabled": True,
- "period": 604800000, # Time in ms for 1 week
- "renew_at": 172800000, # Time in ms for 2 days
- "renew_by_email_enabled": True,
- "renew_email_subject": "Renew your account",
- "account_renewed_html_path": "account_renewed.html",
- "invalid_token_html_path": "invalid_token.html",
- }
-
- # Email config.
-
- config["email"] = {
- "enable_notifs": True,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "expiry_template_html": "notice_expiry.html",
- "expiry_template_text": "notice_expiry.txt",
- "notif_template_html": "notif_mail.html",
- "notif_template_text": "notif_mail.txt",
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "notif_from": "test@example.com",
- }
-
- self.hs = self.setup_test_homeserver(config=config)
-
- async def sendmail(*args: Any, **kwargs: Any) -> None:
- self.email_attempts.append((args, kwargs))
-
- self.email_attempts: List[Tuple[Any, Any]] = []
- self.hs.get_send_email_handler()._sendmail = sendmail
-
- self.store = self.hs.get_datastores().main
-
- return self.hs
-
- def test_renewal_email(self) -> None:
- self.email_attempts = []
-
- (user_id, tok) = self.create_user()
-
- # Move 5 days forward. This should trigger a renewal email to be sent.
- self.reactor.advance(datetime.timedelta(days=5).total_seconds())
- self.assertEqual(len(self.email_attempts), 1)
-
- # Retrieving the URL from the email is too much pain for now, so we
- # retrieve the token from the DB.
- renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
- url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- channel = self.make_request(b"GET", url)
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- # Check that we're getting HTML back.
- content_type = channel.headers.getRawHeaders(b"Content-Type")
- self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
-
- # Check that the HTML we're getting is the one we expect on a successful renewal.
- expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
- expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
- expiration_ts=expiration_ts
- )
- self.assertEqual(
- channel.result["body"], expected_html.encode("utf8"), channel.result
- )
-
- # Move 1 day forward. Try to renew with the same token again.
- url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- channel = self.make_request(b"GET", url)
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- # Check that we're getting HTML back.
- content_type = channel.headers.getRawHeaders(b"Content-Type")
- self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
-
- # Check that the HTML we're getting is the one we expect when reusing a
- # token. The account expiration date should not have changed.
- expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
- expiration_ts=expiration_ts
- )
- self.assertEqual(
- channel.result["body"], expected_html.encode("utf8"), channel.result
- )
-
- # Move 3 days forward. If the renewal failed, every authed request with
- # our access token should be denied from now, otherwise they should
- # succeed.
- self.reactor.advance(datetime.timedelta(days=3).total_seconds())
- channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- def test_renewal_invalid_token(self) -> None:
- # Hit the renewal endpoint with an invalid token and check that it behaves as
- # expected, i.e. that it responds with 404 Not Found and the correct HTML.
- url = "/_matrix/client/unstable/account_validity/renew?token=123"
- channel = self.make_request(b"GET", url)
- self.assertEqual(channel.code, 404, msg=channel.result)
-
- # Check that we're getting HTML back.
- content_type = channel.headers.getRawHeaders(b"Content-Type")
- self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
-
- # Check that the HTML we're getting is the one we expect when using an
- # invalid/unknown token.
- expected_html = (
- self.hs.config.account_validity.account_validity_invalid_token_template.render()
- )
- self.assertEqual(
- channel.result["body"], expected_html.encode("utf8"), channel.result
- )
-
- def test_manual_email_send(self) -> None:
- self.email_attempts = []
-
- (user_id, tok) = self.create_user()
- channel = self.make_request(
- b"POST",
- "/_matrix/client/unstable/account_validity/send_mail",
- access_token=tok,
- )
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- self.assertEqual(len(self.email_attempts), 1)
-
- def test_deactivated_user(self) -> None:
- self.email_attempts = []
-
- (user_id, tok) = self.create_user()
-
- request_data = {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "monkey",
- },
- "erase": False,
- }
- channel = self.make_request(
- "POST", "account/deactivate", request_data, access_token=tok
- )
- self.assertEqual(channel.code, 200)
-
- self.reactor.advance(datetime.timedelta(days=8).total_seconds())
-
- self.assertEqual(len(self.email_attempts), 0)
-
- def create_user(self) -> Tuple[str, str]:
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
- # We need to manually add an email address otherwise the handler will do
- # nothing.
- now = self.hs.get_clock().time_msec()
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address="kermit@example.com",
- validated_at=now,
- added_at=now,
- )
- )
- return user_id, tok
-
- def test_manual_email_send_expired_account(self) -> None:
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- # We need to manually add an email address otherwise the handler will do
- # nothing.
- now = self.hs.get_clock().time_msec()
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address="kermit@example.com",
- validated_at=now,
- added_at=now,
- )
- )
-
- # Make the account expire.
- self.reactor.advance(datetime.timedelta(days=8).total_seconds())
-
- # Ignore all emails sent by the automatic background task and only focus on the
- # ones sent manually.
- self.email_attempts = []
-
- # Test that we're still able to manually trigger a mail to be sent.
- channel = self.make_request(
- b"POST",
- "/_matrix/client/unstable/account_validity/send_mail",
- access_token=tok,
- )
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- self.assertEqual(len(self.email_attempts), 1)
-
-
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index 0ab754a11a..83a5cbdc15 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -34,7 +34,6 @@ from tests import unittest
from tests.unittest import override_config
from tests.utils import HAS_AUTHLIB
-msc3886_endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
msc4108_endpoint = "/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
@@ -54,17 +53,9 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
}
def test_disabled(self) -> None:
- channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 404)
channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 404)
- @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
- def test_msc3886_redirect(self) -> None:
- channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 307)
- self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
-
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
@@ -126,10 +117,11 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
headers = dict(channel.headers.getAllRawHeaders())
self.assertIn(b"ETag", headers)
self.assertIn(b"Expires", headers)
+ self.assertIn(b"Content-Length", headers)
self.assertEqual(headers[b"Content-Type"], [b"application/json"])
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
- self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
+ self.assertEqual(headers[b"Cache-Control"], [b"no-store, no-transform"])
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
self.assertIn("url", channel.json_body)
self.assertTrue(channel.json_body["url"].startswith("https://"))
@@ -150,9 +142,10 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(headers[b"ETag"], [etag])
self.assertIn(b"Expires", headers)
self.assertEqual(headers[b"Content-Type"], [b"text/plain"])
+ self.assertEqual(headers[b"Content-Length"], [b"7"])
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
- self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
+ self.assertEqual(headers[b"Cache-Control"], [b"no-store, no-transform"])
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
self.assertEqual(channel.text_body, "foo=bar")
diff --git a/tests/rest/client/test_reporting.py b/tests/rest/client/test_reporting.py
index 009deb9cb0..723553979f 100644
--- a/tests/rest/client/test_reporting.py
+++ b/tests/rest/client/test_reporting.py
@@ -156,58 +156,31 @@ class ReportRoomTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(
self.other_user, tok=self.other_user_tok, is_public=True
)
- self.report_path = (
- f"/_matrix/client/unstable/org.matrix.msc4151/rooms/{self.room_id}/report"
- )
+ self.report_path = f"/_matrix/client/v3/rooms/{self.room_id}/report"
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_reason_str(self) -> None:
data = {"reason": "this makes me sad"}
self._assert_status(200, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_no_reason(self) -> None:
data = {"not_reason": "for typechecking"}
self._assert_status(400, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_reason_nonstring(self) -> None:
data = {"reason": 42}
self._assert_status(400, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_reason_null(self) -> None:
data = {"reason": None}
self._assert_status(400, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_cannot_report_nonexistent_room(self) -> None:
"""
Tests that we don't accept event reports for rooms which do not exist.
"""
channel = self.make_request(
"POST",
- "/_matrix/client/unstable/org.matrix.msc4151/rooms/!bloop:example.org/report",
+ "/_matrix/client/v3/rooms/!bloop:example.org/report",
{"reason": "i am very sad"},
access_token=self.other_user_tok,
shorthand=False,
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c559dfda83..04442febb4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -4,7 +4,7 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2017 Vector Creations Ltd
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright (C) 2023 New Vector, Ltd
+# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@@ -25,12 +25,11 @@
import json
from http import HTTPStatus
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from unittest.mock import AsyncMock, Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
-from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
@@ -68,6 +67,7 @@ from tests.http.server._base import make_request_with_cancellation_test
from tests.storage.test_stream import PaginationTestCase
from tests.test_utils.event_injection import create_event
from tests.unittest import override_config
+from tests.utils import default_config
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -742,7 +742,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(35, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -755,7 +755,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(35, channel.resource_usage.db_txn_count)
+ self.assertEqual(37, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1337,17 +1337,13 @@ class RoomJoinTestCase(RoomBase):
"POST", f"/join/{self.room1}", access_token=self.tok2
)
self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
channel = self.make_request(
"POST", f"/rooms/{self.room1}/join", access_token=self.tok2
)
self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_knock_on_room(self) -> None:
# set the user as suspended
@@ -1361,9 +1357,7 @@ class RoomJoinTestCase(RoomBase):
shorthand=False,
)
self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_invite_to_room(self) -> None:
# set the user as suspended
@@ -1376,9 +1370,24 @@ class RoomJoinTestCase(RoomBase):
access_token=self.tok1,
content={"user_id": self.user2},
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ def test_suspended_user_can_leave_room(self) -> None:
+ channel = self.make_request(
+ "POST", f"/join/{self.room1}", access_token=self.tok1
)
+ self.assertEqual(channel.code, 200)
+
+ # set the user as suspended
+ self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+ # leave room
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room1}/leave",
+ access_token=self.tok1,
+ )
+ self.assertEqual(channel.code, 200)
class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
@@ -2291,6 +2300,141 @@ class RoomMessageFilterTestCase(RoomBase):
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+class RoomDelayedEventTestCase(RoomBase):
+ """Tests delayed events."""
+
+ user_id = "@sid1:red"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_send_delayed_invalid_event(self) -> None:
+ """Test sending a delayed event with invalid content."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertNotIn("org.matrix.msc4140.errcode", channel.json_body)
+
+ def test_delayed_event_unsupported_by_default(self) -> None:
+ """Test that sending a delayed event is unsupported with the default config."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ "M_MAX_DELAY_UNSUPPORTED",
+ channel.json_body.get("org.matrix.msc4140.errcode"),
+ channel.json_body,
+ )
+
+ @unittest.override_config({"max_event_delay_duration": "1000"})
+ def test_delayed_event_exceeds_max_delay(self) -> None:
+ """Test that sending a delayed event fails if its delay is longer than allowed."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ "M_MAX_DELAY_EXCEEDED",
+ channel.json_body.get("org.matrix.msc4140.errcode"),
+ channel.json_body,
+ )
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_delayed_event_with_negative_delay(self) -> None:
+ """Test that sending a delayed event fails if its delay is negative."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=-2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.INVALID_PARAM, channel.json_body["errcode"], channel.json_body
+ )
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_send_delayed_message_event(self) -> None:
+ """Test sending a valid delayed message event."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_send_delayed_state_event(self) -> None:
+ """Test sending a valid delayed state event."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/state/m.room.topic/?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"topic": "This is a topic"},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ @unittest.override_config(
+ {
+ "max_event_delay_duration": "24h",
+ "rc_message": {"per_second": 1, "burst_count": 2},
+ }
+ )
+ def test_add_delayed_event_ratelimit(self) -> None:
+ """Test that requests to schedule new delayed events are ratelimited by a RateLimiter,
+ which ratelimits them correctly, including by not limiting when the requester is
+ exempt from ratelimiting.
+ """
+
+ # Test that new delayed events are correctly ratelimited.
+ args = (
+ "POST",
+ (
+ "rooms/%s/send/m.room.message?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+ # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0)
+ )
+
+ # Test that the new delayed events aren't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -2457,6 +2601,11 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
+ def default_config(self) -> JsonDict:
+ config = default_config("test")
+ config["room_list_publication_rules"] = [{"action": "allow"}]
+ return config
+
def make_public_rooms_request(
self,
room_types: Optional[List[Union[str, None]]],
@@ -2794,6 +2943,68 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.assertEqual(event_content.get("reason"), reason, channel.result)
+class RoomForgottenTestCase(unittest.HomeserverTestCase):
+ """
+ Test forget/forgotten rooms
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ def test_room_not_forgotten_after_unban(self) -> None:
+ """
+ Test what happens when someone is banned from a room, they forget the room, and
+ some time later are unbanned.
+
+ Currently, when they are unbanned, the room isn't forgotten anymore which may or
+ may not be expected.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # User1 is banned and forgets the room
+ self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+ # User1 forgets the room
+ self.get_success(self.store.forget(user1_id, room_id))
+
+ # The room should show up as forgotten
+ forgotten_room_ids = self.get_success(
+ self.store.get_forgotten_rooms_for_user(user1_id)
+ )
+ self.assertIncludes(forgotten_room_ids, {room_id}, exact=True)
+
+ # Unban user1
+ self.helper.change_membership(
+ room=room_id,
+ src=user2_id,
+ targ=user1_id,
+ membership=Membership.LEAVE,
+ tok=user2_tok,
+ )
+
+ # Room is no longer forgotten because it's a new membership
+ #
+ # XXX: Is this how we actually want it to behave? It seems like ideally, the
+ # room forgotten status should only be reset when the user decides to join again
+ # (or is invited/knocks). This way the room remains forgotten for any ban/leave
+ # transitions.
+ forgotten_room_ids = self.get_success(
+ self.store.get_forgotten_rooms_for_user(user1_id)
+ )
+ self.assertIncludes(forgotten_room_ids, set(), exact=True)
+
+
class LabelsTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -3577,191 +3788,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
-class ThreepidInviteTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user_id = self.register_user("thomas", "hackme")
- self.tok = self.login("thomas", "hackme")
-
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- def test_threepid_invite_spamcheck_deprecated(self) -> None:
- """
- Test allowing/blocking threepid invites with a spam-check module.
-
- In this test, we use the deprecated API in which callbacks return a bool.
- """
- # Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
- # can check its call_count later on during the test.
- make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign]
- self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign]
- return_value=None,
- )
-
- # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
- # allow everything for now.
- # `spec` argument is needed for this function mock to have `__qualname__`, which
- # is needed for `Measure` metrics buried in SpamChecker.
- mock = AsyncMock(return_value=True, spec=lambda *x: None)
- self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
- mock
- )
-
- # Send a 3PID invite into the room and check that it succeeded.
- email_to_invite = "teresa@example.com"
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200)
-
- # Check that the callback was called with the right params.
- mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
-
- # Check that the call to send the invite was made.
- make_invite_mock.assert_called_once()
-
- # Now change the return value of the callback to deny any invite and test that
- # we can't send the invite.
- mock.return_value = False
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 403)
-
- # Also check that it stopped before calling _make_and_store_3pid_invite.
- make_invite_mock.assert_called_once()
-
- def test_threepid_invite_spamcheck(self) -> None:
- """
- Test allowing/blocking threepid invites with a spam-check module.
-
- In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
- """
- # Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
- # can check its call_count later on during the test.
- make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign]
- self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign]
- return_value=None,
- )
-
- # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
- # allow everything for now.
- # `spec` argument is needed for this function mock to have `__qualname__`, which
- # is needed for `Measure` metrics buried in SpamChecker.
- mock = AsyncMock(
- return_value=synapse.module_api.NOT_SPAM,
- spec=lambda *x: None,
- )
- self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
- mock
- )
-
- # Send a 3PID invite into the room and check that it succeeded.
- email_to_invite = "teresa@example.com"
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200)
-
- # Check that the callback was called with the right params.
- mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
-
- # Check that the call to send the invite was made.
- make_invite_mock.assert_called_once()
-
- # Now change the return value of the callback to deny any invite and test that
- # we can't send the invite. We pick an arbitrary error code to be able to check
- # that the same code has been returned
- mock.return_value = Codes.CONSENT_NOT_GIVEN
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN)
-
- # Also check that it stopped before calling _make_and_store_3pid_invite.
- make_invite_mock.assert_called_once()
-
- # Run variant with `Tuple[Codes, dict]`.
- mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT)
- self.assertEqual(channel.json_body["field"], "value")
-
- # Also check that it stopped before calling _make_and_store_3pid_invite.
- make_invite_mock.assert_called_once()
-
- def test_400_missing_param_without_id_access_token(self) -> None:
- """
- Test that a 3pid invite request returns 400 M_MISSING_PARAM
- if we do not include id_access_token.
- """
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "medium": "email",
- "address": "teresa@example.com",
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 400)
- self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
-
-
class TimestampLookupTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -3836,10 +3862,25 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
self.user2 = self.register_user("teresa", "hackme")
self.tok2 = self.login("teresa", "hackme")
- self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.admin = self.register_user("admin", "pass", True)
+ self.admin_tok = self.login("admin", "pass")
+
+ self.room1 = self.helper.create_room_as(
+ room_creator=self.user1, tok=self.tok1, room_version="11"
+ )
self.store = hs.get_datastores().main
- def test_suspended_user_cannot_send_message_to_room(self) -> None:
+ self.room2 = self.helper.create_room_as(
+ room_creator=self.user1, is_public=False, tok=self.tok1
+ )
+ self.helper.send_state(
+ self.room2,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=self.tok1,
+ )
+
+ def test_suspended_user_cannot_send_message_to_public_room(self) -> None:
# set the user as suspended
self.get_success(self.store.set_user_suspended_status(self.user1, True))
@@ -3849,9 +3890,25 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
access_token=self.tok1,
content={"body": "hello", "msgtype": "m.text"},
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ def test_suspended_user_cannot_send_message_to_encrypted_room(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v1/suspend/{self.user1}",
+ {"suspend": True},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body, {f"user_{self.user1}_suspended": True})
+
+ channel = self.make_request(
+ "PUT",
+ f"/rooms/{self.room2}/send/m.room.encrypted/1",
+ access_token=self.tok1,
+ content={},
)
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_change_profile_data(self) -> None:
# set the user as suspended
@@ -3864,9 +3921,7 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
content={"avatar_url": "mxc://matrix.org/wefh34uihSDRGhw34"},
shorthand=False,
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
channel2 = self.make_request(
"PUT",
@@ -3875,9 +3930,7 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
content={"displayname": "something offensive"},
shorthand=False,
)
- self.assertEqual(
- channel2.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel2.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_redact_messages_other_than_their_own(self) -> None:
# first user sends message
@@ -3911,9 +3964,7 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
content={"reason": "bogus"},
shorthand=False,
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
# but can redact their own
channel = self.make_request(
@@ -3924,3 +3975,244 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
shorthand=False,
)
self.assertEqual(channel.code, 200)
+
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/v3/rooms/{self.room1}/send/m.room.redaction/3456346",
+ access_token=self.tok1,
+ content={"reason": "bogus", "redacts": event_id},
+ shorthand=False,
+ )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/v3/rooms/{self.room1}/send/m.room.redaction/3456346",
+ access_token=self.tok1,
+ content={"reason": "bogus", "redacts": event_id2},
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+
+ def test_suspended_user_cannot_ban_others(self) -> None:
+ # user to ban joins room user1 created
+ self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok2)
+
+ # suspend user1
+ self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+ # user1 tries to ban other user while suspended
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/ban",
+ access_token=self.tok1,
+ content={"reason": "spite", "user_id": self.user2},
+ shorthand=False,
+ )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ # un-suspend user1
+ self.get_success(self.store.set_user_suspended_status(self.user1, False))
+
+ # ban now goes through
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/ban",
+ access_token=self.tok1,
+ content={"reason": "spite", "user_id": self.user2},
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+
+
+class RoomParticipantTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ room.register_servlets,
+ profile.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user1 = self.register_user("thomas", "hackme")
+ self.tok1 = self.login("thomas", "hackme")
+
+ self.user2 = self.register_user("teresa", "hackme")
+ self.tok2 = self.login("teresa", "hackme")
+
+ self.room1 = self.helper.create_room_as(
+ room_creator=self.user1,
+ tok=self.tok1,
+ # Allow user2 to send state events into the room.
+ extra_content={
+ "power_level_content_override": {
+ "state_default": 0,
+ },
+ },
+ )
+ self.store = hs.get_datastores().main
+
+ @parameterized.expand(
+ [
+ # Should record participation.
+ param(
+ is_state=False,
+ event_type="m.room.message",
+ event_content={
+ "msgtype": "m.text",
+ "body": "I am engaging in this room",
+ },
+ record_participation=True,
+ ),
+ param(
+ is_state=False,
+ event_type="m.room.encrypted",
+ event_content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg...",
+ "device_id": "RJYKSTBOIE",
+ "sender_key": "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA",
+ "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ },
+ record_participation=True,
+ ),
+ # Should not record participation.
+ param(
+ is_state=False,
+ event_type="m.sticker",
+ event_content={
+ "body": "My great sticker",
+ "info": {},
+ "url": "mxc://unused/mxcurl",
+ },
+ record_participation=False,
+ ),
+ # An invalid **state event** with type `m.room.message`
+ param(
+ is_state=True,
+ event_type="m.room.message",
+ event_content={
+ "msgtype": "m.text",
+ "body": "I am engaging in this room",
+ },
+ record_participation=False,
+ ),
+ # An invalid **state event** with type `m.room.encrypted`
+ # Note: this may become valid in the future with encrypted state, though we
+ # still may not want to consider it grounds for marking a user as participating.
+ param(
+ is_state=True,
+ event_type="m.room.encrypted",
+ event_content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg...",
+ "device_id": "RJYKSTBOIE",
+ "sender_key": "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA",
+ "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ },
+ record_participation=False,
+ ),
+ ]
+ )
+ def test_sending_message_records_participation(
+ self,
+ is_state: bool,
+ event_type: str,
+ event_content: JsonDict,
+ record_participation: bool,
+ ) -> None:
+ """
+        Test that sending various events into a room causes the user to be
+        appropriately marked or not marked as a participant in that room.
+ """
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # user has not sent any messages, so should not be a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertFalse(participant)
+
+ # send an event into the room
+ if is_state:
+ # send a state event
+ self.helper.send_state(
+ self.room1,
+ event_type,
+ body=event_content,
+ tok=self.tok2,
+ )
+ else:
+ # send a non-state event
+ self.helper.send_event(
+ self.room1,
+ event_type,
+ content=event_content,
+ tok=self.tok2,
+ )
+
+ # check whether the user has been marked as a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertEqual(participant, record_participation)
+
+ @parameterized.expand(
+ [
+ param(
+ event_type="m.room.message",
+ event_content={
+ "msgtype": "m.text",
+ "body": "I am engaging in this room",
+ },
+ ),
+ param(
+ event_type="m.room.encrypted",
+ event_content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg...",
+ "device_id": "RJYKSTBOIE",
+ "sender_key": "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA",
+ "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ },
+ ),
+ ]
+ )
+ def test_sending_event_and_leaving_does_not_record_participation(
+ self,
+ event_type: str,
+ event_content: JsonDict,
+ ) -> None:
+ """
+ Test that sending an event into a room that should mark a user as a
+        participant, but then leaving the room, results in the user no longer
+        being marked as a participant in that room.
+ """
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # user has not sent any messages, so should not be a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertFalse(participant)
+
+ # sending a message should now mark user as participant
+ self.helper.send_event(
+ self.room1,
+ event_type,
+ content=event_content,
+ tok=self.tok2,
+ )
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertTrue(participant)
+
+ # leave the room
+ self.helper.leave(self.room1, self.user2, tok=self.tok2)
+
+ # user should no longer be considered a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertFalse(participant)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 2287f233b4..b406a578f0 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -88,35 +88,6 @@ class RoomTestCase(_ShadowBannedBase):
)
self.assertEqual(invited_rooms, [])
- def test_invite_3pid(self) -> None:
- """Ensure that a 3PID invite does not attempt to contact the identity server."""
- identity_handler = self.hs.get_identity_handler()
- identity_handler.lookup_3pid = Mock( # type: ignore[method-assign]
- side_effect=AssertionError("This should not get called")
- )
-
- # The create works fine.
- room_id = self.helper.create_room_as(
- self.banned_user_id, tok=self.banned_access_token
- )
-
- # Inviting the user completes successfully.
- channel = self.make_request(
- "POST",
- "/rooms/%s/invite" % (room_id,),
- {
- "id_server": "test",
- "medium": "email",
- "address": "test@test.test",
- "id_access_token": "anytoken",
- },
- access_token=self.banned_access_token,
- )
- self.assertEqual(200, channel.code, channel.result)
-
- # This should have raised an error earlier, but double check this wasn't called.
- identity_handler.lookup_3pid.assert_not_called()
-
def test_create_room(self) -> None:
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 63df31ec75..c52a5b2e79 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -282,22 +282,33 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code)
next_batch = channel.json_body["next_batch"]
- # This should time out! But it does not, because our stream token is
- # ahead, and therefore it's saying the typing (that we've actually
- # already seen) is new, since it's got a token above our new, now-reset
- # stream token.
- channel = self.make_request("GET", sync_url % (access_token, next_batch))
- self.assertEqual(200, channel.code)
- next_batch = channel.json_body["next_batch"]
-
# Clear the typing information, so that it doesn't think everything is
- # in the future.
+ # in the future. This happens automatically when the typing stream
+ # resets.
typing._reset()
- # Now it SHOULD fail as it never completes!
+ # Nothing new, so we time out.
with self.assertRaises(TimedOutException):
self.make_request("GET", sync_url % (access_token, next_batch))
+ # Sync and start typing again.
+ sync_channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch), await_result=False
+ )
+ self.assertFalse(sync_channel.is_finished())
+
+ channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.assertEqual(200, channel.code)
+
+ # Sync should now return.
+ sync_channel.await_result()
+ self.assertEqual(200, sync_channel.code)
+ next_batch = sync_channel.json_body["next_batch"]
+
class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin):
servlets = [
diff --git a/tests/rest/client/test_tags.py b/tests/rest/client/test_tags.py
new file mode 100644
index 0000000000..5d596409e1
--- /dev/null
+++ b/tests/rest/client/test_tags.py
@@ -0,0 +1,95 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+"""Tests REST events for /tags paths."""
+
+from http import HTTPStatus
+
+import synapse.rest.admin
+from synapse.rest.client import login, room, tags
+
+from tests import unittest
+
+
+class RoomTaggingTestCase(unittest.HomeserverTestCase):
+ """Tests /user/$user_id/rooms/$room_id/tags/$tag REST API."""
+
+ servlets = [
+ room.register_servlets,
+ tags.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def test_put_tag_checks_room_membership(self) -> None:
+ """
+ Test that a user can add a tag to a room if they have membership to the room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ tag = "test_tag"
+
+ # Make the request
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user1_id}/rooms/{room_id}/tags/{tag}",
+ content={"order": 0.5},
+ access_token=user1_tok,
+ )
+ # Check that the request was successful
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_put_tag_fails_if_not_in_room(self) -> None:
+ """
+ Test that a user cannot add a tag to a room if they don't have membership to the
+ room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ # Create the room with user2 (user1 has no membership in the room)
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ tag = "test_tag"
+
+ # Make the request
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user1_id}/rooms/{room_id}/tags/{tag}",
+ content={"order": 0.5},
+ access_token=user1_tok,
+ )
+ # Check that the request failed with the correct error
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
+
+ def test_put_tag_fails_if_room_does_not_exist(self) -> None:
+ """
+ Test that a user cannot add a tag to a room if the room doesn't exist (therefore
+ no membership in the room).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ room_id = "!nonexistent:test"
+ tag = "test_tag"
+
+ # Make the request
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user1_id}/rooms/{room_id}/tags/{tag}",
+ content={"order": 0.5},
+ access_token=user1_tok,
+ )
+ # Check that the request failed with the correct error
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index d10df1a90f..f02317533e 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -915,162 +915,3 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right room ID
self.assertEqual(args[1], self.room_id)
-
- def test_on_threepid_bind(self) -> None:
- """Tests that the on_threepid_bind module callback is called correctly after
- associating a 3PID to an account.
- """
- # Register a mocked callback.
- threepid_bind_mock = AsyncMock(return_value=None)
- third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
- third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
-
- # Register an admin user.
- self.register_user("admin", "password", admin=True)
- admin_tok = self.login("admin", "password")
-
- # Also register a normal user we can modify.
- user_id = self.register_user("user", "password")
-
- # Add a 3PID to the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [
- {
- "medium": "email",
- "address": "foo@example.com",
- },
- ],
- },
- access_token=admin_tok,
- )
-
- # Check that the shutdown was blocked
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that the mock was called once.
- threepid_bind_mock.assert_called_once()
- args = threepid_bind_mock.call_args[0]
-
- # Check that the mock was called with the right parameters
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
-
- def test_on_add_and_remove_user_third_party_identifier(self) -> None:
- """Tests that the on_add_user_third_party_identifier and
- on_remove_user_third_party_identifier module callbacks are called
- just before associating and removing a 3PID to/from an account.
- """
- # Pretend to be a Synapse module and register both callbacks as mocks.
- on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
- on_remove_user_third_party_identifier_callback_mock = AsyncMock(
- return_value=None
- )
- self.hs.get_module_api().register_third_party_rules_callbacks(
- on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
- on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
- )
-
- # Register an admin user.
- self.register_user("admin", "password", admin=True)
- admin_tok = self.login("admin", "password")
-
- # Also register a normal user we can modify.
- user_id = self.register_user("user", "password")
-
- # Add a 3PID to the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [
- {
- "medium": "email",
- "address": "foo@example.com",
- },
- ],
- },
- access_token=admin_tok,
- )
-
- # Check that the mocked add callback was called with the appropriate
- # 3PID details.
- self.assertEqual(channel.code, 200, channel.json_body)
- on_add_user_third_party_identifier_callback_mock.assert_called_once()
- args = on_add_user_third_party_identifier_callback_mock.call_args[0]
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
-
- # Now remove the 3PID from the user
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [],
- },
- access_token=admin_tok,
- )
-
- # Check that the mocked remove callback was called with the appropriate
- # 3PID details.
- self.assertEqual(channel.code, 200, channel.json_body)
- on_remove_user_third_party_identifier_callback_mock.assert_called_once()
- args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
-
- def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
- self,
- ) -> None:
- """Tests that the on_remove_user_third_party_identifier module callback is called
- when a user is deactivated and their third-party ID associations are deleted.
- """
- # Pretend to be a Synapse module and register both callbacks as mocks.
- on_remove_user_third_party_identifier_callback_mock = AsyncMock(
- return_value=None
- )
- self.hs.get_module_api().register_third_party_rules_callbacks(
- on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
- )
-
- # Register an admin user.
- self.register_user("admin", "password", admin=True)
- admin_tok = self.login("admin", "password")
-
- # Also register a normal user we can modify.
- user_id = self.register_user("user", "password")
-
- # Add a 3PID to the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [
- {
- "medium": "email",
- "address": "foo@example.com",
- },
- ],
- },
- access_token=admin_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that the mock was not called on the act of adding a third-party ID.
- on_remove_user_third_party_identifier_callback_mock.assert_not_called()
-
- # Now deactivate the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "deactivated": True,
- },
- access_token=admin_tok,
- )
-
- # Check that the mocked remove callback was called with the appropriate
- # 3PID details.
- self.assertEqual(channel.code, 200, channel.json_body)
- on_remove_user_third_party_identifier_callback_mock.assert_called_once()
- args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index e43140720d..280486da08 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,6 +31,7 @@ from typing import (
AnyStr,
Dict,
Iterable,
+ Literal,
Mapping,
MutableMapping,
Optional,
@@ -40,12 +41,11 @@ from typing import (
from urllib.parse import urlencode
import attr
-from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.server import Site
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, ReceiptTypes
from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -330,22 +330,24 @@ class RestHelper:
data,
)
- assert (
- channel.code == expect_code
- ), "Expected: %d, got: %d, PUT %s -> resp: %r" % (
- expect_code,
- channel.code,
- path,
- channel.result["body"],
+ assert channel.code == expect_code, (
+ "Expected: %d, got: %d, PUT %s -> resp: %r"
+ % (
+ expect_code,
+ channel.code,
+ path,
+ channel.result["body"],
+ )
)
if expect_errcode:
- assert (
- str(channel.json_body["errcode"]) == expect_errcode
- ), "Expected: %r, got: %r, resp: %r" % (
- expect_errcode,
- channel.json_body["errcode"],
- channel.result["body"],
+ assert str(channel.json_body["errcode"]) == expect_errcode, (
+ "Expected: %r, got: %r, resp: %r"
+ % (
+ expect_errcode,
+ channel.json_body["errcode"],
+ channel.result["body"],
+ )
)
if expect_additional_fields is not None:
@@ -354,13 +356,14 @@ class RestHelper:
expect_key,
channel.json_body,
)
- assert (
- channel.json_body[expect_key] == expect_value
- ), "Expected: %s at %s, got: %s, resp: %s" % (
- expect_value,
- expect_key,
- channel.json_body[expect_key],
- channel.json_body,
+ assert channel.json_body[expect_key] == expect_value, (
+ "Expected: %s at %s, got: %s, resp: %s"
+ % (
+ expect_value,
+ expect_key,
+ channel.json_body[expect_key],
+ channel.json_body,
+ )
)
self.auth_user_id = temp_id
@@ -545,7 +548,7 @@ class RestHelper:
room_id: str,
event_type: str,
body: Dict[str, Any],
- tok: Optional[str],
+ tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
state_key: str = "",
) -> JsonDict:
@@ -713,9 +716,9 @@ class RestHelper:
"/login",
content={"type": "m.login.token", "token": login_token},
)
- assert (
- channel.code == expected_status
- ), f"unexpected status in response: {channel.code}"
+ assert channel.code == expected_status, (
+ f"unexpected status in response: {channel.code}"
+ )
return channel.json_body
def auth_via_oidc(
@@ -886,7 +889,7 @@ class RestHelper:
"GET",
uri,
)
- assert channel.code == 302
+ assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}"
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
@@ -898,17 +901,18 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
+ next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
channel = make_request(
self.reactor,
self.site,
"GET",
- urllib.parse.urlunsplit(("", "") + parts[2:]),
+ next_uri,
custom_headers=[
("Host", parts[1]),
],
)
- assert channel.code == 302
+ assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}"
channel.extract_cookies(cookies)
return get_location(channel)
@@ -944,3 +948,15 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0]
return oauth_uri
+
+ def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
+ """Send a read receipt into the room at the given event"""
+ channel = make_request(
+ self.reactor,
+ self.site,
+ method="POST",
+ path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
+ content={},
+ access_token=tok,
+ )
+ assert channel.code == HTTPStatus.OK, channel.text_body
diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py
index 72205c6bb3..26453f70dd 100644
--- a/tests/rest/media/test_domain_blocking.py
+++ b/tests/rest/media/test_domain_blocking.py
@@ -61,6 +61,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
@@ -91,7 +92,8 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
{
# Disable downloads from a domain we won't be requesting downloads from.
# This proves we haven't broken anything.
- "prevent_media_downloads_from": ["not-listed.com"]
+ "prevent_media_downloads_from": ["not-listed.com"],
+ "enable_authenticated_media": False,
}
)
def test_remote_media_normally_unblocked(self) -> None:
@@ -132,6 +134,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# This proves we haven't broken anything.
"prevent_media_downloads_from": ["not-listed.com"],
"dynamic_thumbnails": True,
+ "enable_authenticated_media": False,
}
)
def test_remote_media_thumbnail_normally_unblocked(self) -> None:
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index a96f0e7fca..2a7bee19f9 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -42,6 +42,7 @@ from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.server import FakeTransport
from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
try:
import lxml
@@ -877,7 +878,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
data = base64.b64encode(SMALL_PNG)
end_content = (
- b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+ b'<html><head><img src="data:image/png;base64,%s" /></head></html>'
) % (data,)
channel = self.make_request(
@@ -1259,6 +1260,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port)
return host, media_id
+ @override_config({"enable_authenticated_media": False})
def test_storage_providers_exclude_files(self) -> None:
"""Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1301,6 +1303,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider",
)
+ @override_config({"enable_authenticated_media": False})
def test_storage_providers_exclude_thumbnails(self) -> None:
"""Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index e166c13bc1..96a4f5598e 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -17,6 +17,8 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from unittest.mock import AsyncMock
+
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -35,7 +37,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
@unittest.override_config(
{
"public_baseurl": "https://tesths",
- "default_identity_server": "https://testis",
}
)
def test_client_well_known(self) -> None:
@@ -48,7 +49,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel.json_body,
{
"m.homeserver": {"base_url": "https://tesths/"},
- "m.identity_server": {"base_url": "https://testis"},
},
)
@@ -67,7 +67,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
@unittest.override_config(
{
"public_baseurl": "https://tesths",
- "default_identity_server": "https://testis",
"extra_well_known_client_content": {"custom": False},
}
)
@@ -81,7 +80,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel.json_body,
{
"m.homeserver": {"base_url": "https://tesths/"},
- "m.identity_server": {"base_url": "https://testis"},
"custom": False,
},
)
@@ -112,7 +110,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
- "account_management_url": "https://my-account.issuer",
"client_id": "id",
"client_auth_method": "client_secret_post",
"client_secret": "secret",
@@ -122,18 +119,33 @@ class WellKnownTests(unittest.HomeserverTestCase):
}
)
def test_client_well_known_msc3861_oauth_delegation(self) -> None:
- channel = self.make_request(
- "GET", "/.well-known/matrix/client", shorthand=False
+ # Patch the HTTP client to return the issuer metadata
+ req_mock = AsyncMock(
+ return_value={
+ "issuer": "https://issuer",
+ "account_management_uri": "https://my-account.issuer",
+ }
)
+ self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body,
- {
- "m.homeserver": {"base_url": "https://homeserver/"},
- "org.matrix.msc2965.authentication": {
- "issuer": "https://issuer",
- "account": "https://my-account.issuer",
+ for _ in range(2):
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/client", shorthand=False
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.homeserver": {"base_url": "https://homeserver/"},
+ "org.matrix.msc2965.authentication": {
+ "issuer": "https://issuer",
+ "account": "https://my-account.issuer",
+ },
},
- },
+ )
+
+ # It should have been called exactly once, because it gets cached
+ req_mock.assert_called_once_with(
+ "https://issuer/.well-known/openid-configuration"
)
|