diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index f44c91a373..b7fc33dc94 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest.mock import patch
+
import jsonschema
from synapse.api.constants import EventContentFields
@@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
{"presence": {"senders": ["@bar;pik.test.com"]}},
]
for filter in invalid_filters:
- with self.assertRaises(SynapseError) as check_filter_error:
+ with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter)
- self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self):
valid_filters = [
@@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_wildcards(self):
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_unknowns(self):
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
@@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="now.for.something.completely.different",
room_id="!foo:bar",
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_literals(self):
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_wildcards(self):
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_unknowns(self):
definition = {"not_types": ["m.*", "org.*"]}
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_types_takes_priority_over_types(self):
definition = {
@@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"types": ["m.room.message", "m.room.topic"],
}
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_literals(self):
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_unknowns(self):
definition = {"senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_literals(self):
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_unknowns(self):
definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_takes_priority_over_senders(self):
definition = {
@@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_literals(self):
definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_unknowns(self):
definition = {"rooms": ["!secretbase:unknown"]}
@@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message",
room_id="!anothersecretbase:unknown",
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_literals(self):
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
@@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message",
room_id="!anothersecretbase:unknown",
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_unknowns(self):
definition = {"not_rooms": ["!secretbase:unknown"]}
@@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message",
room_id="!anothersecretbase:unknown",
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = {
@@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event(self):
definition = {
@@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup
room_id="!stage:unknown", # yup
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_sender(self):
definition = {
@@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup
room_id="!stage:unknown", # yup
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_room(self):
definition = {
@@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup
room_id="!piggyshouse:muppets", # nope
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_type(self):
definition = {
@@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown", # yup
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_labels(self):
definition = {"org.matrix.labels": ["#fun"]}
@@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#fun"]},
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
event = MockEvent(
sender="@foo:bar",
@@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#notfun"]},
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_not_labels(self):
definition = {"org.matrix.not_labels": ["#fun"]}
@@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#fun"]},
)
- self.assertFalse(Filter(definition).check(event))
+ self.assertFalse(Filter(self.hs, definition)._check(event))
event = MockEvent(
sender="@foo:bar",
@@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#notfun"]},
)
- self.assertTrue(Filter(definition).check(event))
+ self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
@@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_presence(events=events)
+ results = self.get_success(user_filter.filter_presence(events=events))
self.assertEquals(events, results)
def test_filter_presence_no_match(self):
@@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_presence(events=events)
+ results = self.get_success(user_filter.filter_presence(events=events))
self.assertEquals([], results)
def test_filter_room_state_match(self):
@@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_room_state(events=events)
+ results = self.get_success(user_filter.filter_room_state(events=events))
self.assertEquals(events, results)
def test_filter_room_state_no_match(self):
@@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
)
- results = user_filter.filter_room_state(events)
+ results = self.get_success(user_filter.filter_room_state(events))
self.assertEquals([], results)
def test_filter_rooms(self):
@@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"!not_included:example.com", # Disallowed because not in rooms.
]
- filtered_room_ids = list(Filter(definition).filter_rooms(room_ids))
+ filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids))
self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_filter_relations(self):
+ events = [
+ # An event without a relation.
+ MockEvent(
+ event_id="$no_relation",
+ sender="@foo:bar",
+ type="org.matrix.custom.event",
+ room_id="!foo:bar",
+ ),
+ # An event with a relation.
+ MockEvent(
+ event_id="$with_relation",
+ sender="@foo:bar",
+ type="org.matrix.custom.event",
+ room_id="!foo:bar",
+ ),
+ # Non-EventBase objects get passed through.
+ {},
+ ]
+
+ # For the following tests we patch the datastore method (intead of injecting
+ # events). This is a bit cheeky, but tests the logic of _check_event_relations.
+
+ # Filter for a particular sender.
+ definition = {
+ "io.element.relation_senders": ["@foo:bar"],
+ }
+
+ async def events_have_relations(*args, **kwargs):
+ return ["$with_relation"]
+
+ with patch.object(
+ self.datastore, "events_have_relations", new=events_have_relations
+ ):
+ filtered_events = list(
+ self.get_success(
+ Filter(self.hs, definition)._check_event_relations(events)
+ )
+ )
+ self.assertEquals(filtered_events, events[1:])
+
def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 765258c47a..69a4e9413b 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -46,15 +46,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
"was: %r" % (config.key.macaroon_secret_key,)
)
- config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config2 is not None
self.assertTrue(
- hasattr(config.key, "macaroon_secret_key"),
+ hasattr(config2.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
- if len(config.key.macaroon_secret_key) < 5:
+ if len(config2.key.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
- "was: %r" % (config.key.macaroon_secret_key,)
+ "was: %r" % (config2.key.macaroon_secret_key,)
)
def test_load_succeeds_if_macaroon_secret_key_missing(self):
@@ -62,6 +63,9 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config1 is not None
+ assert config2 is not None
+ assert config3 is not None
self.assertEqual(
config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
)
@@ -78,14 +82,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.registration.enable_registration)
- config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
- self.assertFalse(config.registration.enable_registration)
+ config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config2 is not None
+ self.assertFalse(config2.registration.enable_registration)
# Check that either config value is clobbered by the command line.
- config = HomeServerConfig.load_or_generate_config(
+ config3 = HomeServerConfig.load_or_generate_config(
"", ["-c", self.config_file, "--enable-registration"]
)
- self.assertTrue(config.registration.enable_registration)
+ assert config3 is not None
+ self.assertTrue(config3.registration.enable_registration)
def test_stats_enabled(self):
self.generate_config_and_remove_lines_containing("enable_metrics")
@@ -94,3 +100,12 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
# The default Metrics Flags are off by default.
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.metrics.metrics_flags.known_servers)
+
+ def test_depreciated_identity_server_flag_throws_error(self):
+ self.generate_config()
+ # Needed to ensure that actual key/value pair added below don't end up on a line with a comment
+ self.add_lines_to_config([" "])
+ # Check that presence of "trust_identity_server_for_password" throws config error
+ self.add_lines_to_config(["trust_identity_server_for_password_resets: true"])
+ with self.assertRaises(ConfigError):
+ HomeServerConfig.load_config("", ["-c", self.config_file])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index cbecc1c20f..4d1e154578 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,4 +1,4 @@
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2021 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ from synapse.storage.keys import FetchKeyResult
from tests import unittest
from tests.test_utils import make_awaitable
-from tests.unittest import logcontext_clean
+from tests.unittest import logcontext_clean, override_config
class MockPerspectiveServer:
@@ -197,7 +197,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# self.assertFalse(d.called)
self.get_success(d)
- def test_verify_for_server_locally(self):
+ def test_verify_for_local_server(self):
"""Ensure that locally signed JSON can be verified without fetching keys
over federation
"""
@@ -209,6 +209,56 @@ class KeyringTestCase(unittest.HomeserverTestCase):
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
self.get_success(d)
+ OLD_KEY = signedjson.key.generate_signing_key("old")
+
+ @override_config(
+ {
+ "old_signing_keys": {
+ f"{OLD_KEY.alg}:{OLD_KEY.version}": {
+ "key": encode_verify_key_base64(OLD_KEY.verify_key),
+ "expired_ts": 1000,
+ }
+ }
+ }
+ )
+ def test_verify_for_local_server_old_key(self):
+ """Can also use keys in old_signing_keys for verification"""
+ json1 = {}
+ signedjson.sign.sign_json(json1, self.hs.hostname, self.OLD_KEY)
+
+ kr = keyring.Keyring(self.hs)
+ d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+ self.get_success(d)
+
+ def test_verify_for_local_server_unknown_key(self):
+ """Local keys that we no longer have should be fetched via the fetcher"""
+
+ # the key we'll sign things with (nb, not known to the Keyring)
+ key2 = signedjson.key.generate_signing_key("2")
+
+ # set up a mock fetcher which will return the key
+ async def get_keys(
+ server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+ ) -> Dict[str, FetchKeyResult]:
+ self.assertEqual(server_name, self.hs.hostname)
+ self.assertEqual(key_ids, [get_key_id(key2)])
+
+ return {get_key_id(key2): FetchKeyResult(get_verify_key(key2), 1200)}
+
+ mock_fetcher = Mock()
+ mock_fetcher.get_keys = Mock(side_effect=get_keys)
+ kr = keyring.Keyring(
+ self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+ )
+
+ # sign the json
+ json1 = {}
+ signedjson.sign.sign_json(json1, self.hs.hostname, key2)
+
+ # ... and check we can verify it.
+ d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+ self.get_success(d)
+
def test_verify_json_for_server_with_null_valid_until_ms(self):
"""Tests that we correctly handle key requests for keys we've stored
with a null `ts_valid_until_ms`
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1f6a924452..d6f14e2dba 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
make_awaitable(([event], None))
)
- self.handler.notify_interested_services_ephemeral("receipt_key", 580)
+ self.handler.notify_interested_services_ephemeral(
+ "receipt_key", 580, ["@fakerecipient:example.com"]
+ )
self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
interested_service, [event]
)
@@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
make_awaitable(([event], None))
)
- self.handler.notify_interested_services_ephemeral("receipt_key", 579)
+ self.handler.notify_interested_services_ephemeral(
+ "receipt_key", 580, ["@fakerecipient:example.com"]
+ )
self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
def _mkservice(self, is_interested, protocols=None):
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 12857053e7..72e176da75 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
)
)
@@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_failure(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
),
ResourceLimitError,
@@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# If not in monthly active cohort
self.get_failure(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
),
ResourceLimitError,
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.clock.time_msec())
)
self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
)
)
@@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
# Ensure does not raise exception
self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user1, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index be008227df..0ea4e753e2 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2021 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from unittest.mock import Mock
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
-from synapse.config.room_directory import RoomDirectoryConfig
from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias, create_requester
@@ -394,22 +393,15 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(self, reactor, clock, hs):
- # We cheekily override the config to add custom alias creation rules
- config = {}
+ def default_config(self):
+ config = super().default_config()
+
+ # Add custom alias creation rules to the config.
config["alias_creation_rules"] = [
{"user_id": "*", "alias": "#unofficial_*", "action": "allow"}
]
- config["room_list_publication_rules"] = []
- rd_config = RoomDirectoryConfig()
- rd_config.read_config(config)
-
- self.hs.config.roomdirectory.is_alias_creation_allowed = (
- rd_config.is_alias_creation_allowed
- )
-
- return hs
+ return config
def test_denied(self):
room_id = self.helper.create_room_as(self.user_id)
@@ -417,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
- ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
+ {"room_id": room_id},
)
self.assertEquals(403, channel.code, channel.result)
@@ -427,14 +419,35 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
- ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
+ {"room_id": room_id},
)
self.assertEquals(200, channel.code, channel.result)
+ def test_denied_during_creation(self):
+ """A room alias that is not allowed should be rejected during creation."""
+ # Invalid room alias.
+ self.helper.create_room_as(
+ self.user_id,
+ expect_code=403,
+ extra_content={"room_alias_name": "foo"},
+ )
-class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
- data = {"room_alias_name": "unofficial_test"}
+ def test_allowed_during_creation(self):
+ """A valid room alias should be allowed during creation."""
+ room_id = self.helper.create_room_as(
+ self.user_id,
+ extra_content={"room_alias_name": "unofficial_test"},
+ )
+ channel = self.make_request(
+ "GET",
+ b"directory/room/%23unofficial_test%3Atest",
+ )
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(channel.json_body["room_id"], room_id)
+
+
+class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -443,27 +456,30 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
]
hijack_auth = False
- def prepare(self, reactor, clock, hs):
- self.allowed_user_id = self.register_user("allowed", "pass")
- self.allowed_access_token = self.login("allowed", "pass")
+ data = {"room_alias_name": "unofficial_test"}
+ allowed_localpart = "allowed"
- self.denied_user_id = self.register_user("denied", "pass")
- self.denied_access_token = self.login("denied", "pass")
+ def default_config(self):
+ config = super().default_config()
- # This time we add custom room list publication rules
- config = {}
- config["alias_creation_rules"] = []
+ # Add custom room list publication rules to the config.
config["room_list_publication_rules"] = [
+ {
+ "user_id": "@" + self.allowed_localpart + "*",
+ "alias": "#unofficial_*",
+ "action": "allow",
+ },
{"user_id": "*", "alias": "*", "action": "deny"},
- {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"},
]
- rd_config = RoomDirectoryConfig()
- rd_config.read_config(config)
+ return config
- self.hs.config.roomdirectory.is_publishing_room_allowed = (
- rd_config.is_publishing_room_allowed
- )
+ def prepare(self, reactor, clock, hs):
+ self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
+ self.allowed_access_token = self.login(self.allowed_localpart, "pass")
+
+ self.denied_user_id = self.register_user("denied", "pass")
+ self.denied_access_token = self.login("denied", "pass")
return hs
@@ -505,10 +521,23 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.allowed_user_id,
tok=self.allowed_access_token,
extra_content=self.data,
- is_public=False,
+ is_public=True,
expect_code=200,
)
+ def test_denied_publication_with_invalid_alias(self):
+ """
+ Try to create a room, register an alias for it, and publish it,
+ as a user WITH permission to publish rooms.
+ """
+ self.helper.create_room_as(
+ self.allowed_user_id,
+ tok=self.allowed_access_token,
+ extra_content={"room_alias_name": "foo"},
+ is_public=True,
+ expect_code=403,
+ )
+
def test_can_create_as_private_room_after_rejection(self):
"""
After failing to publish a room with an alias as a user without publish permission,
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 0c3b86fda9..f0723892e4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -162,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
+ fallback_key2 = {"alg1:k2": "key2"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
@@ -213,6 +214,35 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
+ # re-uploading the same fallback key should still result in no unused fallback
+ # keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key},
+ )
+ )
+
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ # uploading a new fallback key should result in an unused fallback key
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key2},
+ )
+ )
+
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, ["alg1"])
+
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
self.get_success(
@@ -238,7 +268,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
res,
- {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
def test_replace_master_key(self):
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 7dd4a5a367..08e9730d4d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -31,7 +31,10 @@ from tests.unittest import override_config
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+ {"type": "m.login.application_service"},
+ {"type": "uk.half-shot.msc2778.login.application_service"},
+]
# a mock instance which the dummy auth providers delegate to, so we can see what's going
# on
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index db691c4c1c..cd6f2c77ae 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
- self.store.count_monthly_users = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
- self.store.get_monthly_active_count = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
@@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- self.store.get_monthly_active_count = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index d3d0bf1ac5..7b95844b55 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -14,6 +14,8 @@
from typing import Any, Iterable, List, Optional, Tuple
from unittest import mock
+from twisted.internet.defer import ensureDeferred
+
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -316,6 +318,59 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
AuthError,
)
+ def test_room_hierarchy_cache(self) -> None:
+ """In-flight room hierarchy requests are deduplicated."""
+ # Run two `get_room_hierarchy` calls up until they block.
+ deferred1 = ensureDeferred(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ deferred2 = ensureDeferred(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+
+ # Complete the two calls.
+ result1 = self.get_success(deferred1)
+ result2 = self.get_success(deferred2)
+
+ # Both `get_room_hierarchy` calls should return the same result.
+ expected = [(self.space, [self.room]), (self.room, ())]
+ self._assert_hierarchy(result1, expected)
+ self._assert_hierarchy(result2, expected)
+ self.assertIs(result1, result2)
+
+ # A subsequent `get_room_hierarchy` call should not reuse the result.
+ result3 = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result3, expected)
+ self.assertIsNot(result1, result3)
+
+ def test_room_hierarchy_cache_sharing(self) -> None:
+ """Room hierarchy responses for different users are not shared."""
+ user2 = self.register_user("user2", "pass")
+
+ # Make the room within the space invite-only.
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.INVITE},
+ tok=self.token,
+ )
+
+ # Run two `get_room_hierarchy` calls for different users up until they block.
+ deferred1 = ensureDeferred(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
+
+ # Complete the two calls.
+ result1 = self.get_success(deferred1)
+ result2 = self.get_success(deferred2)
+
+ # The `get_room_hierarchy` calls should return different results.
+ self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())])
+ self._assert_hierarchy(result2, [(self.space, [self.room])])
+
def _create_room_with_join_rule(
self, join_rule: str, room_version: Optional[str] = None, **extra_content
) -> str:
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 339c039914..638186f173 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -13,10 +13,11 @@
# limitations under the License.
from typing import Optional
+from unittest.mock import Mock
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
-from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions
from synapse.handlers.sync import SyncConfig
from synapse.rest import admin
@@ -197,7 +198,7 @@ def generate_sync_config(
_request_key += 1
return SyncConfig(
user=UserID.from_string(user_id),
- filter_collection=DEFAULT_FILTER_COLLECTION,
+ filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
is_guest=False,
request_key=("request_key", _request_key),
device_id=device_id,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eac4664b41..cb02eddf07 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
@@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
- servlets: List[Callable[[HomeServer, JsonResource], None]] = []
-
def setUp(self):
super().setUp()
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 192073c520..af849bd471 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
% server_and_media_id_2
),
)
+
+
+class PurgeHistoryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
+ self.url_status = "/_synapse/admin/v1/purge_history_status/"
+
+ def test_purge_history(self):
+ """
+ Simple test of purge history API.
+ Test only that is is possible to call, get status 200 and purge_id.
+ """
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ content={"delete_local_events": True, "purge_up_to_ts": 0},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("purge_id", channel.json_body)
+ purge_id = channel.json_body["purge_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ self.url_status + purge_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 78c48db552..cd5c60b65c 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,10 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+from typing import Collection
+
+from parameterized import parameterized
import synapse.rest.admin
+from synapse.api.errors import Codes
from synapse.rest.client import login
from synapse.server import HomeServer
+from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
@@ -30,6 +36,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ @parameterized.expand(
+ [
+ ("GET", "/_synapse/admin/v1/background_updates/enabled"),
+ ("POST", "/_synapse/admin/v1/background_updates/enabled"),
+ ("GET", "/_synapse/admin/v1/background_updates/status"),
+ ("POST", "/_synapse/admin/v1/background_updates/start_job"),
+ ]
+ )
+ def test_requester_is_no_admin(self, method: str, url: str):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ self.register_user("user", "pass", admin=False)
+ other_user_tok = self.login("user", "pass")
+
+ channel = self.make_request(
+ method,
+ url,
+ content={},
+ access_token=other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ url = "/_synapse/admin/v1/background_updates/start_job"
+
+ # empty content
+ channel = self.make_request(
+ "POST",
+ url,
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # job_name invalid
+ channel = self.make_request(
+ "POST",
+ url,
+ content={"job_name": "unknown"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
def _register_bg_update(self):
"Adds a bg update but doesn't start it"
@@ -60,7 +120,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running.
self.assertDictEqual(
@@ -82,7 +142,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running.
self.assertDictEqual(
@@ -91,9 +151,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": True,
@@ -114,7 +176,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates
@@ -124,7 +186,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in
@@ -137,16 +199,18 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(
channel.json_body,
{
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": False,
@@ -162,7 +226,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response.
self.assertDictEqual(
@@ -171,9 +235,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": False,
@@ -188,7 +254,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
@@ -199,7 +265,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress.
self.assertDictEqual(
@@ -208,11 +274,92 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 2000.0,
- "total_item_count": 200,
+ "total_item_count": (
+ 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": True,
},
)
+
+ @parameterized.expand(
+ [
+ ("populate_stats_process_rooms", ["populate_stats_process_rooms"]),
+ (
+ "regenerate_directory",
+ [
+ "populate_user_directory_createtables",
+ "populate_user_directory_process_rooms",
+ "populate_user_directory_process_users",
+ "populate_user_directory_cleanup",
+ ],
+ ),
+ ]
+ )
+ def test_start_backround_job(self, job_name: str, updates: Collection[str]):
+ """
+ Test that background updates add to database and be processed.
+
+ Args:
+ job_name: name of the job to call with API
+ updates: collection of background updates to be started
+ """
+
+ # no background update is waiting
+ self.assertTrue(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/start_job",
+ content={"job_name": job_name},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # test that each background update is waiting now
+ for update in updates:
+ self.assertFalse(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_update(update)
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ # background updates are done
+ self.assertTrue(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ )
+ )
+
+ def test_start_backround_job_twice(self):
+ """Test that add a background update twice return an error."""
+
+ # add job to database
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/start_job",
+ content={"job_name": "populate_stats_process_rooms"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46116644ce..07077aff78 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -14,12 +14,16 @@
import json
import urllib.parse
+from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
+from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room
from tests import unittest
@@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- json.dumps({}),
+ {},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self):
@@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
url,
- json.dumps({}),
+ {},
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_room_is_not_valid(self):
@@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
url,
- json.dumps({}),
+ {},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self):
@@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self):
@@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
+ @parameterized.expand([(True,), (False,)])
+ def test_block_unknown_room(self, purge: bool) -> None:
+ """
+ We can block an unknown room. In this case, the `purge` argument
+ should be ignored.
+ """
+ room_id = "!unknown:test"
+
+ # The room isn't already in the blocked rooms table
+ self._is_blocked(room_id, expect=False)
+
+ # Request the room be blocked.
+ channel = self.make_request(
+ "DELETE",
+ f"/_synapse/admin/v1/rooms/{room_id}",
+ {"block": True, "purge": purge},
+ access_token=self.admin_user_tok,
+ )
+
+ # The room is now blocked.
+ self.assertEqual(
+ HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
+ )
+ self._is_blocked(room_id)
+
def test_shutdown_room_consent(self):
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
@@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -418,17 +447,616 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+ url = "events?timeout=0&room_id=" + room_id
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+
+class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.consent.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = f"/_synapse/admin/v2/rooms/{self.room_id}"
+ self.url_status_by_room_id = (
+ f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status"
+ )
+ self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_requester_is_no_admin(self, method: str, url: str):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request(
+ method,
+ url % self.room_id,
+ content={},
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_room_does_not_exist(self, method: str, url: str):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+
+ channel = self.make_request(
+ method,
+ url % "!unknown:test",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ]
+ )
+ def test_room_is_not_valid(self, method: str, url: str):
+ """
+ Check that invalid room names, return an error 400.
+ """
+
+ channel = self.make_request(
+ method,
+ url % "invalidroom",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the user ID must be from local server but it does not have to exist.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": "@unknown:test"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only local users can create new room to move members.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": "@not:exist.bla"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ "User must be our own: @not:exist.bla",
+ channel.json_body["error"],
)
+ def test_block_is_not_bool(self):
+ """
+ If parameter `block` is not boolean, return an error
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"block": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_is_not_bool(self):
+ """
+ If parameter `purge` is not boolean, return an error
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"purge": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_delete_expired_status(self):
+ """Test that the task status is removed after expiration."""
+
+ # first task, do not purge, that we can create a second task
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"purge": False},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id1 = channel.json_body["delete_id"]
+
+ # go ahead
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ # second task
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id2 = channel.json_body["delete_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(2, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual("complete", channel.json_body["results"][1]["status"])
+ self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
+ self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+
+ # get status after more than clearing time for first task
+ # second task is not cleared
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+
+ # get status after more than clearing time for all tasks
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_delete_same_room_twice(self):
+ """Test that the call for delete a room at second time gives an exception."""
+
+ body = {"new_room_user_id": self.admin_user}
+
+ # first call to delete room
+ # and do not wait for finish the task
+ first_channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content=body,
+ access_token=self.admin_user_tok,
+ await_result=False,
+ )
+
+ # second call to delete room
+ second_channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content=body,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
+ )
+ self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
+ self.assertEqual(
+ f"History purge already in progress for {self.room_id}",
+ second_channel.json_body["error"],
+ )
+
+ # get result of first call
+ first_channel.await_result()
+ self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+ self.assertIn("delete_id", first_channel.json_body)
+
+ # check status after finish the task
+ self._test_result(
+ first_channel.json_body["delete_id"],
+ self.other_user,
+ expect_new_room=True,
+ )
+
+ def test_purge_room_and_block(self):
+ """Test to purge a room and block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": True, "purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test to purge a room and do not block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": False, "purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_block_room_and_not_purge(self):
+ """Test to block a room without purging it.
+ Members will not be moved to a new room and will not receive a message.
+ The room will not be purged.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": True, "purge": False},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+ user_id=self.other_user,
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ content={"history_visibility": "world_readable"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+ user_id=self.other_user,
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+ """Assert that the room is blocked or not"""
+ d = self.store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _has_no_members(self, room_id: str) -> None:
+ """Assert there is now no longer anyone in the room"""
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id: str, user_id: str) -> None:
+ """Test that user is member of the room"""
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id: str) -> None:
+ """Test that the following tables have been purged of all rows related to the room."""
+ for table in PURGE_TABLES:
+ count = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
+
+ def _assert_peek(self, room_id: str, expect_code: int) -> None:
+ """Assert that the admin user can (or cannot) peek into the room."""
+
+ url = f"rooms/{room_id}/initialSync"
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
url = "events?timeout=0&room_id=" + room_id
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+ def _test_result(
+ self,
+ delete_id: str,
+ kicked_user: str,
+ expect_new_room: bool = False,
+ ) -> None:
+ """
+ Test that the result is the expected.
+ Uses both APIs (status by room_id and delete_id)
+
+ Args:
+ delete_id: id of this purge
+ kicked_user: a user_id which is kicked from the room
+ expect_new_room: if we expect that a new room was created
+ """
+
+ # get information by room_id
+ channel_room_id = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
)
+ self.assertEqual(1, len(channel_room_id.json_body["results"]))
+ self.assertEqual(
+ delete_id, channel_room_id.json_body["results"][0]["delete_id"]
+ )
+
+ # get information by delete_id
+ channel_delete_id = self.make_request(
+ "GET",
+ self.url_status_by_delete_id + delete_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(
+ HTTPStatus.OK,
+ channel_delete_id.code,
+ msg=channel_delete_id.json_body,
+ )
+
+ # test values that are the same in both responses
+ for content in [
+ channel_room_id.json_body["results"][0],
+ channel_delete_id.json_body,
+ ]:
+ self.assertEqual("complete", content["status"])
+ self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", content["shutdown_room"])
+ self.assertIn("local_aliases", content["shutdown_room"])
+ self.assertNotIn("error", content)
+
+ if expect_new_room:
+ self.assertIsNotNone(content["shutdown_room"]["new_room_id"])
+ else:
+ self.assertIsNone(content["shutdown_room"]["new_room_id"])
class RoomTestCase(unittest.HomeserverTestCase):
@@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self):
"""Test the correct attributes for a room are returned"""
@@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.second_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self):
@@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self):
@@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"])
def test_room_is_not_valid(self):
@@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self):
@@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self):
@@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self):
@@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self):
@@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEquals(
- 403, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEquals(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self):
@@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEquals(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -1595,13 +2219,241 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
)
+class BlockRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self._store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/block"
+
+ @parameterized.expand([("PUT",), ("GET",)])
+ def test_requester_is_no_admin(self, method: str):
+ """If the user is not a server admin, an error 403 is returned."""
+
+ channel = self.make_request(
+ method,
+ self.url % self.room_id,
+ content={},
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @parameterized.expand([("PUT",), ("GET",)])
+ def test_room_is_not_valid(self, method: str):
+ """Check that invalid room names, return an error 400."""
+
+ channel = self.make_request(
+ method,
+ self.url % "invalidroom",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
+ )
+
+ def test_block_is_not_valid(self):
+ """If parameter `block` is not valid, return an error."""
+
+ # `block` is not valid
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ # `block` is not set
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no content is send
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ def test_block_room(self):
+ """Test that block a room is successful."""
+
+ def _request_and_test_block_room(room_id: str) -> None:
+ self._is_blocked(room_id, expect=False)
+ channel = self.make_request(
+ "PUT",
+ self.url % room_id,
+ content={"block": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self._is_blocked(room_id, expect=True)
+
+ # known internal room
+ _request_and_test_block_room(self.room_id)
+
+ # unknown internal room
+ _request_and_test_block_room("!unknown:test")
+
+ # unknown remote room
+ _request_and_test_block_room("!unknown:remote")
+
+ def test_block_room_twice(self):
+ """Test that block a room that is already blocked is successful."""
+
+ self._is_blocked(self.room_id, expect=False)
+ for _ in range(2):
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self._is_blocked(self.room_id, expect=True)
+
+ def test_unblock_room(self):
+ """Test that unblock a room is successful."""
+
+ def _request_and_test_unblock_room(room_id: str) -> None:
+ self._block_room(room_id)
+
+ channel = self.make_request(
+ "PUT",
+ self.url % room_id,
+ content={"block": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self._is_blocked(room_id, expect=False)
+
+ # known internal room
+ _request_and_test_unblock_room(self.room_id)
+
+ # unknown internal room
+ _request_and_test_unblock_room("!unknown:test")
+
+ # unknown remote room
+ _request_and_test_unblock_room("!unknown:remote")
+
+ def test_unblock_room_twice(self):
+ """Test that unblock a room that is not blocked is successful."""
+
+ self._block_room(self.room_id)
+ for _ in range(2):
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self._is_blocked(self.room_id, expect=False)
+
+ def test_get_blocked_room(self):
+ """Test get status of a blocked room"""
+
+ def _request_blocked_room(room_id: str) -> None:
+ self._block_room(room_id)
+
+ channel = self.make_request(
+ "GET",
+ self.url % room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+
+ # known internal room
+ _request_blocked_room(self.room_id)
+
+ # unknown internal room
+ _request_blocked_room("!unknown:test")
+
+ # unknown remote room
+ _request_blocked_room("!unknown:remote")
+
+ def test_get_unblocked_room(self):
+ """Test get status of a unblocked room"""
+
+ def _request_unblocked_room(room_id: str) -> None:
+ self._is_blocked(room_id, expect=False)
+
+ channel = self.make_request(
+ "GET",
+ self.url % room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self.assertNotIn("user_id", channel.json_body)
+
+ # known internal room
+ _request_unblocked_room(self.room_id)
+
+ # unknown internal room
+ _request_unblocked_room("!unknown:test")
+
+ # unknown remote room
+ _request_unblocked_room("!unknown:remote")
+
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+ """Assert that the room is blocked or not"""
+ d = self._store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _block_room(self, room_id: str) -> None:
+ """Block a room in database"""
+ self.get_success(self._store.block_room(room_id, self.other_user))
+ self._is_blocked(room_id, expect=True)
+
+
PURGE_TABLES = [
"current_state_events",
"event_backward_extremities",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 25e8d6cf27..5011e54563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.admin_user, device_id=None, valid_until_ms=None
)
)
self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.other_user, device_id=None, valid_until_ms=None
)
)
@@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get information of an user without authentication.
"""
- channel = self.make_request("POST", self.url)
+ channel = self.make_request(method, self.url)
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_not_admin(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_requester_is_not_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("POST", self.url, access_token=other_user_token)
+ channel = self.make_request(method, self.url, access_token=other_user_token)
self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_user_is_not_local(self, method: str):
"""
Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
- channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ channel = self.make_request(method, url, access_token=self.admin_user_tok)
self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self):
@@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertTrue(result.shadow_banned)
+ # Un-shadow-ban the user.
+ channel = self.make_request(
+ "DELETE", self.url, access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is no longer shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
class RateLimitTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index e2fcbdc63a..8552671431 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -598,7 +598,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response.json_body["refresh_token"],
)
- @override_config({"access_token_lifetime": "1m"})
+ @override_config({"refreshable_access_token_lifetime": "1m"})
def test_refresh_token_expiration(self):
"""
The access token should have some time as specified in the config.
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index b9e3602552..249808b031 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def test_get_does_include_msc3244_fields_when_enabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..aca03afd0e 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import directory, login, room
+from synapse.server import HomeServer
from synapse.types import RoomAlias
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_membership_for_aliases"] = True
@@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return self.hs
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ """Create two local users and access tokens for them.
+ One of them creates a room."""
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test")
- def test_state_event_not_in_room(self):
+ def test_state_event_not_in_room(self) -> None:
self.ensure_user_left_room()
- self.set_alias_via_state_event(403)
+ self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
- def test_directory_endpoint_not_in_room(self):
+ def test_directory_endpoint_not_in_room(self) -> None:
self.ensure_user_left_room()
- self.set_alias_via_directory(403)
+ self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
- def test_state_event_in_room_too_long(self):
+ def test_state_event_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_state_event(400, alias_length=256)
+ self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
- def test_directory_in_room_too_long(self):
+ def test_directory_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_directory(400, alias_length=256)
+ self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
@override_config({"default_room_version": 5})
- def test_state_event_user_in_v5_room(self):
+ def test_state_event_user_in_v5_room(self) -> None:
"""Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
- self.set_alias_via_state_event(200)
+ self.set_alias_via_state_event(HTTPStatus.OK)
@override_config({"default_room_version": 6})
- def test_state_event_v6_room(self):
+ def test_state_event_v6_room(self) -> None:
"""Test that a regular user can *not* add alias events from room v6"""
self.ensure_user_joined_room()
- self.set_alias_via_state_event(403)
+ self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
- def test_directory_in_room(self):
+ def test_directory_in_room(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_directory(200)
+ self.set_alias_via_directory(HTTPStatus.OK)
- def test_room_creation_too_long(self):
+ def test_room_creation_too_long(self) -> None:
url = "/_matrix/client/r0/createRoom"
# We use deliberately a localpart under the length threshold so
@@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
- def test_room_creation(self):
+ def test_room_creation(self) -> None:
url = "/_matrix/client/r0/createRoom"
# Check with an alias of allowed length. There should already be
@@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_deleting_alias_via_directory(self) -> None:
+ # Add an alias for the room. We must be joined to do so.
+ self.ensure_user_joined_room()
+ alias = self.set_alias_via_directory(HTTPStatus.OK)
+
+ # Then try to remove the alias
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_deleting_nonexistant_alias(self) -> None:
+ # Check that no alias exists
+ alias = "#potato:test"
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
+
+ # Then try to remove the alias
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
- def set_alias_via_state_event(self, expected_code, alias_length=5):
+ def set_alias_via_state_event(
+ self, expected_code: HTTPStatus, alias_length: int = 5
+ ) -> None:
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
self.room_id,
self.hs.hostname,
@@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, expected_code, channel.result)
- def set_alias_via_directory(self, expected_code, alias_length=5):
- url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+ def set_alias_via_directory(
+ self, expected_code: HTTPStatus, alias_length: int = 5
+ ) -> str:
+ alias = self.random_alias(alias_length)
+ url = "/_matrix/client/r0/directory/room/%s" % alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
+ return alias
- def random_alias(self, length):
+ def random_alias(self, length: int) -> str:
return RoomAlias(random_string(length), self.hs.hostname).to_string()
- def ensure_user_left_room(self):
+ def ensure_user_left_room(self) -> None:
self.ensure_membership("leave")
- def ensure_user_joined_room(self):
+ def ensure_user_joined_room(self) -> None:
self.ensure_membership("join")
- def ensure_membership(self, membership):
+ def ensure_membership(self, membership: str) -> None:
try:
if membership == "leave":
self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a63f04bd41..19f5e46537 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+ {"type": "m.login.application_service"},
+ {"type": "uk.half-shot.msc2778.login.application_service"},
+]
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -812,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
jwt_secret = "secret"
jwt_algorithm = "HS256"
+ base_config = {
+ "enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ }
- def make_homeserver(self, reactor, clock):
- self.hs = self.setup_test_homeserver()
- self.hs.config.jwt.jwt_enabled = True
- self.hs.config.jwt.jwt_secret = self.jwt_secret
- self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+
+ # If jwt_config has been defined (eg via @override_config), don't replace it.
+ if config.get("jwt_config") is None:
+ config["jwt_config"] = self.base_config
+
+ return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -876,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
- @override_config(
- {
- "jwt_config": {
- "jwt_enabled": True,
- "secret": jwt_secret,
- "algorithm": jwt_algorithm,
- "issuer": "test-issuer",
- }
- }
- )
+ @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self):
"""Test validating the issuer claim."""
# A valid issuer.
@@ -916,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
- @override_config(
- {
- "jwt_config": {
- "jwt_enabled": True,
- "secret": jwt_secret,
- "algorithm": jwt_algorithm,
- "audiences": ["test-audience"],
- }
- }
- )
+ @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self):
"""Test validating the audience claim."""
# A valid audience.
@@ -959,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
+ def test_login_default_sub(self):
+ """Test reading user ID from the default subject claim."""
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
+ def test_login_custom_sub(self):
+ """Test reading user ID from a custom subject claim."""
+ channel = self.jwt_login({"username": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
def test_login_no_token(self):
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1021,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
]
)
- def make_homeserver(self, reactor, clock):
- self.hs = self.setup_test_homeserver()
- self.hs.config.jwt.jwt_enabled = True
- self.hs.config.jwt.jwt_secret = self.jwt_pubkey
- self.hs.config.jwt.jwt_algorithm = "RS256"
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+ config["jwt_config"] = {
+ "enabled": True,
+ "secret": self.jwt_pubkey,
+ "algorithm": "RS256",
+ }
+ return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 78c2fb86b9..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
# Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob")
@@ -91,6 +94,49 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_invalid_event(self):
+ """Test that we deny relations on non-existant events"""
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ EventTypes.Message,
+ parent_id="foo",
+ content={"body": "foo", "msgtype": "m.text"},
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ # Unless that event is referenced from another event!
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert(
+ table="event_relations",
+ values={
+ "event_id": "bar",
+ "relates_to_id": "foo",
+ "relation_type": RelationTypes.THREAD,
+ },
+ desc="test_deny_invalid_event",
+ )
+ )
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ parent_id="foo",
+ content={"body": "foo", "msgtype": "m.text"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ def test_deny_invalid_room(self):
+ """Test that we deny relations on non-existant events"""
+ # Create another room and send a message in it.
+ room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+ parent_id = res["event_id"]
+
+ # Attempt to send an annotation to that event.
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_deny_double_react(self):
"""Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
@@ -99,6 +145,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_forked_thread(self):
+ """It is invalid to start a thread off a thread."""
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo"},
+ parent_id=self.parent_id,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ parent_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo"},
+ parent_id=parent_id,
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -703,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def test_unknown_relations(self):
+ """Unknown relations should be accepted."""
+ channel = self._send_relation("m.relation.test", "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEquals(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # When bundling the unknown relation is not included.
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+ # But unknown relations can be directly queried.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
def _send_relation(
self,
relation_type: str,
@@ -749,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token = self.login(localpart, "abc123")
return user_id, access_token
+
+ def test_background_update(self):
+ """Test the event_arbitrary_relations background update."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_good = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_bad = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_event_id = channel.json_body["event_id"]
+
+ # Clean-up the table as if the inserts did not happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="event_relations",
+ column="event_id",
+ iterable=(annotation_event_id_bad, thread_event_id),
+ keyvalues={},
+ desc="RelationsTestCase.test_background_update",
+ )
+ )
+
+ # Only the "good" annotation should be found.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good],
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # The "good" annotation and the thread should be found, but not the "bad"
+ # annotation.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertCountEqual(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good, thread_event_id],
+ )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 376853fd65..10a4a4dc5e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -25,7 +25,12 @@ from urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
@@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id
+class RelationsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["experimental_features"] = {"msc3440_enabled": True}
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+ self.helper.join(
+ room=self.room_id, user=self.second_user_id, tok=self.second_tok
+ )
+
+ self.third_user_id = self.register_user("third", "test")
+ self.third_tok = self.login("third", "test")
+ self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+ # An initial event with a relation from second user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 1"},
+ tok=self.tok,
+ )
+ self.event_id_1 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "👍",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ # Another event with a relation from third user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 2"},
+ tok=self.tok,
+ )
+ self.event_id_2 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": self.event_id_2,
+ }
+ },
+ tok=self.third_tok,
+ )
+
+ # An event with no relations.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "No relations"},
+ tok=self.tok,
+ )
+
+ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+ """Make a request to /messages with a filter, returns the chunk of events."""
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["chunk"]
+
+ def test_filter_relation_senders(self):
+ # Messages which second user reacted to.
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+ # Messages which third user reacted to.
+ filter = {"io.element.relation_senders": [self.third_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+ # Messages which either user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_type(self):
+ # Messages which have annotations.
+ filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+ # Messages which have references.
+ filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+ # Messages which have either annotations or references.
+ filter = {
+ "io.element.relation_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_senders_and_type(self):
+ # Messages which second user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id],
+ "io.element.relation_types": [RelationTypes.ANNOTATION],
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+
class ContextTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,10 +19,21 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ AnyStr,
+ Dict,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Tuple,
+ overload,
+)
from unittest.mock import patch
import attr
+from typing_extensions import Literal
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -45,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site)
auth_user_id = attr.ib()
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: Literal[200] = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> str:
+ ...
+
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: int = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> Optional[str]:
+ ...
+
def create_room_as(
self,
room_creator: Optional[str] = None,
@@ -53,10 +90,8 @@ class RestHelper:
tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
- ) -> str:
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ ) -> Optional[str]:
"""
Create a room.
@@ -99,6 +134,8 @@ class RestHelper:
if expect_code == 200:
return channel.json_body["room_id"]
+ else:
+ return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -168,7 +205,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
- expect_errcode: str = None,
+ expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -227,9 +264,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if body is None:
body = "body_text_here"
@@ -254,9 +289,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -418,7 +451,7 @@ class RestHelper:
path,
content=image_data,
access_token=tok,
- custom_headers=[(b"Content-Length", str(image_length))],
+ custom_headers=[("Content-Length", str(image_length))],
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
@@ -503,7 +536,7 @@ class RestHelper:
went.
"""
- cookies = {}
+ cookies: Dict[str, str] = {}
# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
@@ -625,7 +658,13 @@ class RestHelper:
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
- location = channel.headers.getRawHeaders("Location")[0]
+ def get_location(channel: FakeChannel) -> str:
+ location_values = channel.headers.getRawHeaders("Location")
+ # Keep mypy happy by asserting that location_values is nonempty
+ assert location_values
+ return location_values[0]
+
+ location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
@@ -639,7 +678,7 @@ class RestHelper:
assert channel.code == 302
channel.extract_cookies(cookies)
- return channel.headers.getRawHeaders("Location")[0]
+ return get_location(channel)
def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
diff --git a/tests/server.py b/tests/server.py
index 103351b487..40cf5b12c3 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,7 +16,17 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import (
+ AnyStr,
+ Callable,
+ Dict,
+ Iterable,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
import attr
from typing_extensions import Deque
@@ -217,14 +227,12 @@ def make_request(
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
- request: Request = SynapseRequest,
+ request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 4b67bd15b7..36c933b9e9 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -66,7 +66,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "remove_deleted_devices_from_device_inbox",
+ "update_name": "remove_dead_devices_from_device_inbox",
"progress_json": "{}",
},
)
@@ -140,7 +140,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "remove_hidden_devices_from_device_inbox",
+ "update_name": "remove_dead_devices_from_device_inbox",
"progress_json": "{}",
},
)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 0da42b5ac5..a5f5ebad41 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -19,11 +19,11 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
def test_do_background_update(self):
- # the time we claim each update takes
- duration_ms = 42
+ # the time we claim it takes to update one item when running the update
+ duration_ms = 4200
# the target runtime for each bg update
- target_background_update_duration_ms = 50000
+ target_background_update_duration_ms = 5000000
store = self.hs.get_datastore()
self.get_success(
@@ -57,7 +57,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a1ba99ff14..d37736edf8 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
class ProfileStoreTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- def test_displayname(self):
+ def test_displayname(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))
self.get_success(
@@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
)
- def test_avatar_url(self):
+ def test_avatar_url(self) -> None:
self.get_success(self.store.create_profile(self.u_frank.localpart))
self.get_success(
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 0ce0892165..cfc8098af6 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -33,7 +33,7 @@ def fake_listdir(filepath: str) -> List[str]:
A list of files and folders in the directory.
"""
if filepath.endswith("full_schemas"):
- return [SCHEMA_VERSION]
+ return [str(SCHEMA_VERSION)]
return ["99_add_unicorn_to_database.sql"]
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 2873e22ccf..fccab733c0 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -161,6 +161,54 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+ def test__null_byte_in_display_name_properly_handled(self):
+ room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ )
+ # Check that we only got one result back
+ self.assertEqual(len(res), 1)
+
+ # Check that alice's display name is "alice"
+ self.assertEqual(res[0]["display_name"], "alice")
+
+ # Grab the event_id to use later
+ event_id = res[0]["event_id"]
+
+ # Create a profile with the offending null byte in the display name
+ new_profile = {"displayname": "ali\u0000ce"}
+
+ # Ensure that the change goes smoothly and does not fail due to the null byte
+ self.helper.change_membership(
+ room,
+ self.u_alice,
+ self.u_alice,
+ "join",
+ extra_data=new_profile,
+ tok=self.t_alice,
+ )
+
+ res2 = self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ )
+ # Check that we only have two results
+ self.assertEqual(len(res2), 2)
+
+ # Filter out the previous event using the event_id we grabbed above
+ row = [row for row in res2 if row["event_id"] != event_id]
+
+ # Check that alice's display name is now None
+ self.assertEqual(row[0]["display_name"], None)
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
new file mode 100644
index 0000000000..ce782c7e1d
--- /dev/null
+++ b/tests/storage/test_stream.py
@@ -0,0 +1,207 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.types import JsonDict
+
+from tests.unittest import HomeserverTestCase
+
+
+class PaginationTestCase(HomeserverTestCase):
+ """
+ Test the pre-filtering done in the pagination code.
+
+ This is similar to some of the tests in tests.rest.client.test_rooms but here
+ we ensure that the filtering done in the database is applied successfully.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["experimental_features"] = {"msc3440_enabled": True}
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+ self.helper.join(
+ room=self.room_id, user=self.second_user_id, tok=self.second_tok
+ )
+
+ self.third_user_id = self.register_user("third", "test")
+ self.third_tok = self.login("third", "test")
+ self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+ # An initial event with a relation from second user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 1"},
+ tok=self.tok,
+ )
+ self.event_id_1 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "👍",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ # Another event with a relation from third user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 2"},
+ tok=self.tok,
+ )
+ self.event_id_2 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": self.event_id_2,
+ }
+ },
+ tok=self.third_tok,
+ )
+
+ # An event with no relations.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "No relations"},
+ tok=self.tok,
+ )
+
+ def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
+ """Make a request to /messages with a filter, returns the chunk of events."""
+
+ from_token = self.get_success(
+ self.hs.get_event_sources().get_current_token_for_pagination()
+ )
+
+ events, next_key = self.get_success(
+ self.hs.get_datastore().paginate_room_events(
+ room_id=self.room_id,
+ from_key=from_token.room_key,
+ to_key=None,
+ direction="b",
+ limit=10,
+ event_filter=Filter(self.hs, filter),
+ )
+ )
+
+ return events
+
+ def test_filter_relation_senders(self):
+ # Messages which second user reacted to.
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+ # Messages which third user reacted to.
+ filter = {"io.element.relation_senders": [self.third_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_2)
+
+ # Messages which either user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_type(self):
+ # Messages which have annotations.
+ filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+ # Messages which have references.
+ filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_2)
+
+ # Messages which have either annotations or references.
+ filter = {
+ "io.element.relation_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_senders_and_type(self):
+ # Messages which second user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id],
+ "io.element.relation_types": [RelationTypes.ANNOTATION],
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
+
+ def test_duplicate_relation(self):
+ """An event should only be returned once if there are multiple relations to it."""
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "A",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0].event_id, self.event_id_1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 24fc77d7a7..3eef1c4c05 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -81,8 +81,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
origin,
event,
context,
- state=None,
- backfilled=False,
):
return context
diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
import logging
import secrets
import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ ClassVar,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
from unittest.mock import Mock, patch
from canonicaljson import json
@@ -31,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
@@ -45,6 +59,7 @@ from synapse.logging.context import (
current_context,
set_current_context,
)
+from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -81,16 +96,13 @@ def around(target):
return _around
-T = TypeVar("T")
-
-
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
- def __init__(self, methodName, *args, **kwargs):
- super().__init__(methodName, *args, **kwargs)
+ def __init__(self, methodName: str):
+ super().__init__(methodName)
method = getattr(self, methodName)
@@ -204,18 +216,18 @@ class HomeserverTestCase(TestCase):
config dict.
Attributes:
- servlets (list[function]): List of servlet registration function.
+ servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
- hijack_auth (bool): Whether to hijack auth to return the user specified
+ hijack_auth: Whether to hijack auth to return the user specified
in user_id.
"""
- servlets = []
- hijack_auth = True
- needs_threadpool = False
+ hijack_auth: ClassVar[bool] = True
+ needs_threadpool: ClassVar[bool] = False
+ servlets: ClassVar[List[RegisterServletsFunc]] = []
- def __init__(self, methodName, *args, **kwargs):
- super().__init__(methodName, *args, **kwargs)
+ def __init__(self, methodName: str):
+ super().__init__(methodName)
# see if we have any additional config for this test
method = getattr(self, methodName)
@@ -287,9 +299,10 @@ class HomeserverTestCase(TestCase):
None,
)
- self.hs.get_auth().get_user_by_req = get_user_by_req
- self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
- self.hs.get_auth().get_access_token_from_request = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
+ self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
+ self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)
@@ -318,7 +331,12 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """Block until all background database updates have completed."""
+ """
+ Block until all background database updates have completed.
+
+ Note that callers must ensure that's a store property created on the
+ testcase.
+ """
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
@@ -403,14 +421,12 @@ class HomeserverTestCase(TestCase):
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
- request: Type[T] = SynapseRequest,
+ request: Type[Request] = SynapseRequest,
shorthand: bool = True,
- federation_auth_origin: str = None,
+ federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@@ -425,7 +441,7 @@ class HomeserverTestCase(TestCase):
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
- federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+ federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -584,7 +600,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00notadmin"
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
- want_mac = want_mac.hexdigest()
+ want_mac_digest = want_mac.hexdigest()
body = json.dumps(
{
@@ -593,7 +609,7 @@ class HomeserverTestCase(TestCase):
"displayname": displayname,
"password": password,
"admin": admin,
- "mac": want_mac,
+ "mac": want_mac_digest,
"inhibit_login": True,
}
)
@@ -639,9 +655,7 @@ class HomeserverTestCase(TestCase):
username,
password,
device_id=None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be
diff --git a/tests/utils.py b/tests/utils.py
index cf8ba5c5db..983859120f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -119,7 +119,6 @@ def default_config(name, parse=False):
"enable_registration": True,
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
- "trusted_third_party_id_servers": [],
"password_providers": [],
"worker_replication_url": "",
"worker_app": None,
|