diff options
Diffstat (limited to 'tests')
72 files changed, 5736 insertions, 1137 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6121efcfa9..0bfb86bf1f 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.auth._auth_blocking + self.test_user = "@foo:bar" self.test_token = b"_test_token_" @@ -68,7 +72,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -105,7 +109,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) @defer.inlineCallbacks @@ -125,7 +129,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -188,7 +192,7 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -225,7 +229,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(macaroon.serialize()) + ) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -250,7 +256,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_by_access_token(serialized) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(serialized) + ) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) @@ -260,10 +268,13 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock() + self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) + self.store.get_device = Mock(return_value=defer.succeed(None)) - token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( - USER_ID, "DEVICE", valid_until_ms=None + token = yield defer.ensureDeferred( + self.hs.handlers.auth_handler.get_access_token_for_user_id( + USER_ID, "DEVICE", valid_until_ms=None + ) ) self.store.add_access_token_to_user.assert_called_with( USER_ID, token, "DEVICE", None @@ -286,7 +297,9 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertFalse(requester.is_guest) @@ -301,7 +314,9 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(InvalidClientCredentialsError) as cm: - yield self.auth.get_user_by_req(request, allow_guest=True) + yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(401, cm.exception.code) self.assertEqual("Guest access token used for regular user", cm.exception.msg) @@ -310,22 +325,22 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_blocking_mau(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 50 lots_of_users = 100 small_number_of_users = 1 # Ensure no error thrown - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) ) with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -334,49 +349,54 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): - self.hs.config.max_mau_value = 50 - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed - yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Bots not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.BOT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Real users not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_reserved_threepid(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} - self.hs.config.mau_limits_reserved_threepids = [threepid] + self.auth_blocking._mau_limits_reserved_threepids = [threepid] - yield self.store.register_user(user_id="user1", password_hash=None) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(threepid=unknown_threepid) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(threepid=unknown_threepid) + ) - yield self.auth.check_auth_blocking(threepid=threepid) + yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) @defer.inlineCallbacks def test_hs_disabled(self): - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -388,20 +408,20 @@ class AuthTestCase(unittest.TestCase): """ # this should be the default, but we had a bug where the test was doing the wrong # thing, so let's make it explicit - self.hs.config.server_notices_mxid = None + self.auth_blocking._server_notices_mxid = None - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True user = "@user:server" - self.hs.config.server_notices_mxid = user - self.hs.config.hs_disabled_message = "Reason for being disabled" - yield self.auth.check_auth_blocking(user) + self.auth_blocking._server_notices_mxid = user + self.auth_blocking._hs_disabled_message = "Reason for being disabled" + yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index d3feafa1b7..be20a89682 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -27,8 +27,8 @@ class FrontendProxyTests(HomeserverTestCase): return hs - def default_config(self, name="test"): - c = super().default_config(name) + def default_config(self): + c = super().default_config() c["worker_app"] = "synapse.app.frontend_proxy" return c diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 89fcc3889a..7364f9f1ec 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -29,8 +29,8 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): ) return hs - def default_config(self, name="test"): - conf = super().default_config(name) + def default_config(self): + conf = super().default_config() # we're using FederationReaderServer, which uses a SlavedStore, so we # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..2920279125 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. + """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. + """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) diff --git a/tests/config/test_database.py b/tests/config/test_database.py index 151d3006ac..f675bde68e 100644 --- a/tests/config/test_database.py +++ b/tests/config/test_database.py @@ -21,9 +21,9 @@ from tests import unittest class DatabaseConfigTestCase(unittest.TestCase): - def test_database_configured_correctly_no_database_conf_param(self): + def test_database_configured_correctly(self): conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", None) + DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") ) expected_database_conf = { @@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase): } self.assertEqual(conf["database"], expected_database_conf) - - def test_database_configured_correctly_database_conf_param(self): - - database_conf = { - "name": "my super fast datastore", - "args": { - "user": "matrix", - "password": "synapse_database_password", - "host": "synapse_database_host", - "database": "matrix", - }, - } - - conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", database_conf) - ) - - self.assertEqual(conf["database"], database_conf) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index b3e557bd6a..734a9983e8 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase): with open(self.file, "r") as f: contents = f.readlines() - contents = [l for l in contents if needle not in l] + contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: f.write("".join(contents)) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 34d5895f18..70c8e72303 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -34,6 +34,7 @@ from synapse.crypto.keyring import ( from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.storage.keys import FetchKeyResult @@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) def check_context(self, _, expected): - self.assertEquals( - getattr(LoggingContext.current_context(), "request", None), expected - ) + self.assertEquals(getattr(current_context(), "request", None), expected) def test_verify_json_objects_for_server_awaits_previous_requests(self): key1 = signedjson.key.generate_signing_key(1) @@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): - self.assertEquals(LoggingContext.current_context().request, "11") + self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): yield persp_deferred return persp_resp diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py new file mode 100644 index 0000000000..640f5f3bce --- /dev/null +++ b/tests/events/test_snapshot.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.events.snapshot import EventContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest +from tests.test_utils.event_injection import create_event + + +class TestEventContext(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.storage = hs.get_storage() + + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + self.room_id = self.helper.create_room_as(tok=self.user_tok) + + def test_serialize_deserialize_msg(self): + """Test that an EventContext for a message event is the same after + serialize/deserialize. + """ + + event, context = create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) + + self._check_serialize_deserialize(event, context) + + def test_serialize_deserialize_state_no_prev(self): + """Test that an EventContext for a state event (with not previous entry) + is the same after serialize/deserialize. + """ + event, context = create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) + + self._check_serialize_deserialize(event, context) + + def test_serialize_deserialize_state_prev(self): + """Test that an EventContext for a state event (which replaces a + previous entry) is the same after serialize/deserialize. + """ + event, context = create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) + + self._check_serialize_deserialize(event, context) + + def _check_serialize_deserialize(self, event, context): + serialized = self.get_success(context.serialize(event, self.store)) + + d_context = EventContext.deserialize(self.storage, serialized) + + self.assertEqual(context.state_group, d_context.state_group) + self.assertEqual(context.rejected, d_context.rejected) + self.assertEqual( + context.state_group_before_event, d_context.state_group_before_event + ) + self.assertEqual(context.prev_group, d_context.prev_group) + self.assertEqual(context.delta_ids, d_context.delta_ids) + self.assertEqual(context.app_service, d_context.app_service) + + self.assertEqual( + self.get_success(context.get_current_state_ids()), + self.get_success(d_context.get_current_state_ids()), + ) + self.assertEqual( + self.get_success(context.get_prev_state_ids()), + self.get_success(d_context.get_prev_state_ids()), + ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 24fa8dbb45..94980733c4 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def default_config(self, name="test"): - config = super().default_config(name=name) + def default_config(self): + config = super().default_config() config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d456267b87..33105576af 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,19 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from mock import Mock +from signedjson import key, sign +from signedjson.types import BaseKey, SigningKey + from twisted.internet import defer -from synapse.types import ReadReceipt +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.types import JsonDict, ReadReceipt from tests.unittest import HomeserverTestCase, override_config -class FederationSenderTestCases(HomeserverTestCase): +class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - return super(FederationSenderTestCases, self).setup_test_homeserver( + return self.setup_test_homeserver( state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -147,3 +153,392 @@ class FederationSenderTestCases(HomeserverTestCase): } ], ) + + +class FederationSenderDevicesTestCases(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + state_handler=Mock(spec=["get_current_hosts_in_room"]), + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def default_config(self): + c = super().default_config() + c["send_federation"] = True + return c + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + mock_state_handler = hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + # stub out get_users_who_share_room_with_user so that it claims that + # `@user2:host2` is in the room + def get_users_who_share_room_with_user(user_id): + return defer.succeed({"@user2:host2"}) + + hs.get_datastore().get_users_who_share_room_with_user = ( + get_users_who_share_room_with_user + ) + + # whenever send_transaction is called, record the edu data + self.edus = [] + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + def record_transaction(self, txn, json_cb): + data = json_cb() + self.edus.extend(data["edus"]) + return defer.succeed({}) + + def test_send_device_updates(self): + """Basic case: each device update should result in an EDU""" + # create a device + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + + # expect one edu + self.assertEqual(len(self.edus), 1) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # a second call should produce no new device EDUs + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + self.assertEqual(self.edus, []) + + # a second device + self.login("user", "pass", device_id="D2") + + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + def test_upload_signatures(self): + """Uploading signatures on some devices should produce updates for that user""" + + e2e_handler = self.hs.get_e2e_keys_handler() + + # register two devices + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + self.login(u1, "pass", device_id="D2") + + # expect two edus + self.assertEqual(len(self.edus), 2) + stream_id = None + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload signing keys for each device + device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") + device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + + # expect two more edus + self.assertEqual(len(self.edus), 2) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload master key and self-signing key + master_signing_key = generate_self_id_key() + master_key = { + "user_id": u1, + "usage": ["master"], + "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)}, + } + + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_signing_key = generate_self_id_key() + selfsigning_key = { + "user_id": u1, + "usage": ["self_signing"], + "keys": { + key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key) + }, + } + sign.sign_json(selfsigning_key, u1, master_signing_key) + + cross_signing_keys = { + "master_key": master_key, + "self_signing_key": selfsigning_key, + } + + self.get_success( + e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) + ) + + # expect signing key update edu + self.assertEqual(len(self.edus), 1) + self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + + # sign the devices + d1_json = build_device_dict(u1, "D1", device1_signing_key) + sign.sign_json(d1_json, u1, selfsigning_signing_key) + d2_json = build_device_dict(u1, "D2", device2_signing_key) + sign.sign_json(d2_json, u1, selfsigning_signing_key) + + ret = self.get_success( + e2e_handler.upload_signatures_for_device_keys( + u1, {u1: {"D1": d1_json, "D2": d2_json}}, + ) + ) + self.assertEqual(ret["failures"], {}) + + # expect two edus, in one or two transactions. We don't know what order the + # devices will be updated. + self.assertEqual(len(self.edus), 2) + stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + if stream_id is not None: + self.assertEqual(c["prev_id"], [stream_id]) + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2"}, devices) + + def test_delete_devices(self): + """If devices are deleted, that should result in EDUs too""" + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # expect three edus + self.assertEqual(len(self.edus), 3) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + # expect three edus, in an unknown order + self.assertEqual(len(self.edus), 3) + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertGreaterEqual( + c.items(), + {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), + ) + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_unreachable_server(self): + """If the destination server is unreachable, all the updates should get sent on + recovery + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # for each device, there should be a single update + self.assertEqual(len(self.edus), 3) + stream_id = None + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + if stream_id is not None: + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_prune_outbound_device_pokes1(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server has never been reachable. + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # there should be a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def test_prune_outbound_device_pokes2(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server was reachable, but then goes + offline. + """ + + # create first device + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + + # expect the update EDU + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # now the server goes offline + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 3) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # ... and we should get a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def check_device_update_edu( + self, + edu: JsonDict, + user_id: str, + device_id: str, + prev_stream_id: Optional[int], + ) -> int: + """Check that the given EDU is an update for the given device + Returns the stream_id. + """ + self.assertEqual(edu["edu_type"], "m.device_list_update") + content = edu["content"] + + expected = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_stream_id] if prev_stream_id is not None else [], + } + + self.assertLessEqual(expected.items(), content.items()) + if prev_stream_id is not None: + self.assertGreaterEqual(content["stream_id"], prev_stream_id) + return content["stream_id"] + + def check_signing_key_update_txn(self, txn: JsonDict,) -> None: + """Check that the txn has an EDU with a signing key update. + """ + edus = txn["edus"] + self.assertEqual(len(edus), 1) + + def generate_and_upload_device_signing_key( + self, user_id: str, device_id: str + ) -> SigningKey: + """Generate a signing keypair for the given device, and upload it""" + sk = key.generate_signing_key(device_id) + + device_dict = build_device_dict(user_id, device_id, sk) + + self.get_success( + self.hs.get_e2e_keys_handler().upload_keys_for_user( + user_id, device_id, {"device_keys": device_dict}, + ) + ) + return sk + + +def generate_self_id_key() -> SigningKey: + """generate a signing key whose version is its public key + + ... as used by the cross-signing-keys. + """ + k = key.generate_signing_key("x") + k.version = encode_pubkey(k) + return k + + +def key_id(k: BaseKey) -> str: + return "%s:%s" % (k.alg, k.version) + + +def encode_pubkey(sk: SigningKey) -> str: + """Encode the public key corresponding to the given signing key as base64""" + return key.encode_verify_key_base64(key.get_verify_key(sk)) + + +def build_device_dict(user_id: str, device_id: str, sk: SigningKey): + """Build a dict representing the given device""" + return { + "user_id": user_id, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:" + device_id: "curve25519+key", + key_id(sk): encode_pubkey(sk), + }, + } diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b03103d96f..c01b04e1dc 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests - self.hs.config.max_mau_value = 50 + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking._max_mau_value = 50 + self.small_number_of_users = 1 self.large_number_of_users = 100 @@ -82,16 +87,16 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @defer.inlineCallbacks @@ -99,8 +104,10 @@ class AuthTestCase(unittest.TestCase): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) self.assertEqual("a_user", user_id) @@ -109,99 +116,121 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_parity(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) def _get_macaroon(self): diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5e40adba52..00bb776271 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -102,6 +102,68 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class TestCreateAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_handlers().directory_handler + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def test_create_alias_joined_room(self): + """A user can create an alias for a room they're in.""" + self.get_success( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, self.room_id, + ) + ) + + def test_create_alias_other_room(self): + """A user cannot create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.get_failure( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, other_room_id, + ), + synapse.api.errors.SynapseError, + ) + + def test_create_alias_admin(self): + """An admin can create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.test_user, tok=self.test_user_tok + ) + + self.get_success( + self.handler.create_association( + create_requester(self.admin_user), self.room_alias, other_room_id, + ) + ) + + class TestDeleteAlias(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py new file mode 100644 index 0000000000..61963aa90d --- /dev/null +++ b/tests/handlers/test_oidc.py @@ -0,0 +1,565 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from urllib.parse import parse_qs, urlparse + +from mock import Mock, patch + +import attr +import pymacaroons + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone + +from synapse.handlers.oidc_handler import ( + MappingException, + OidcError, + OidcHandler, + OidcMappingProvider, +) +from synapse.types import UserID + +from tests.unittest import HomeserverTestCase, override_config + + +@attr.s +class FakeResponse: + code = attr.ib() + body = attr.ib() + phrase = attr.ib() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) + + +# These are a few constants that are used as config parameters in the tests. +ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" +CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +SCOPES = ["openid"] + +AUTHORIZATION_ENDPOINT = ISSUER + "authorize" +TOKEN_ENDPOINT = ISSUER + "token" +USERINFO_ENDPOINT = ISSUER + "userinfo" +WELL_KNOWN = ISSUER + ".well-known/openid-configuration" +JWKS_URI = ISSUER + ".well-known/jwks.json" + +# config for common cases +COMMON_CONFIG = { + "discover": False, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, +} + + +# The cookie name and path don't really matter, just that it has to be coherent +# between the callback & redirect handlers. +COOKIE_NAME = b"oidc_session" +COOKIE_PATH = "/_synapse/oidc" + +MockedMappingProvider = Mock(OidcMappingProvider) + + +def simple_async_mock(return_value=None, raises=None): + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + +async def get_json(url): + # Mock get_json calls to handle jwks & oidc discovery endpoints + if url == WELL_KNOWN: + # Minimal discovery document, as defined in OpenID.Discovery + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + return { + "issuer": ISSUER, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, + "userinfo_endpoint": USERINFO_ENDPOINT, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + elif url == JWKS_URI: + return {"keys": []} + + +class OidcHandlerTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + config = self.default_config() + config["public_baseurl"] = BASE_URL + oidc_config = config.get("oidc_config", {}) + oidc_config["enabled"] = True + oidc_config["client_id"] = CLIENT_ID + oidc_config["client_secret"] = CLIENT_SECRET + oidc_config["issuer"] = ISSUER + oidc_config["scopes"] = SCOPES + oidc_config["user_mapping_provider"] = { + "module": __name__ + ".MockedMappingProvider" + } + config["oidc_config"] = oidc_config + + hs = self.setup_test_homeserver( + http_client=self.http_client, + proxied_http_client=self.http_client, + config=config, + ) + + self.handler = OidcHandler(hs) + + return hs + + def metadata_edit(self, values): + return patch.dict(self.handler._provider_metadata, values) + + def assertRenderedError(self, error, error_description=None): + args = self.handler._render_error.call_args[0] + self.assertEqual(args[1], error) + if error_description is not None: + self.assertEqual(args[2], error_description) + # Reset the render_error mock + self.handler._render_error.reset_mock() + + def test_config(self): + """Basic config correctly sets up the callback URL and client auth correctly.""" + self.assertEqual(self.handler._callback_url, CALLBACK_URL) + self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + + @override_config({"oidc_config": {"discover": True}}) + @defer.inlineCallbacks + def test_discovery(self): + """The handler should discover the endpoints from OIDC discovery document.""" + # This would throw if some metadata were invalid + metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + + self.assertEqual(metadata.issuer, ISSUER) + self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) + self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) + self.assertEqual(metadata.jwks_uri, JWKS_URI) + # FIXME: it seems like authlib does not have that defined in its metadata models + # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + + # subsequent calls should be cached + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_no_discovery(self): + """When discovery is disabled, it should not try to load from discovery document.""" + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_load_jwks(self): + """JWKS loading is done once (then cached) if used.""" + jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.assertEqual(jwks, {"keys": []}) + + # subsequent calls should be cached… + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_not_called() + + # …unless forced + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + + # Throw if the JWKS uri is missing + with self.metadata_edit({"jwks_uri": None}): + with self.assertRaises(RuntimeError): + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + + # Return empty key set if JWKS are not used + self.handler._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + + @override_config({"oidc_config": COMMON_CONFIG}) + def test_validate_config(self): + """Provider metadatas are extensively validated.""" + h = self.handler + + # Default test config does not throw + h._validate_metadata() + + with self.metadata_edit({"issuer": None}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "http://insecure/"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "https://invalid/?because=query"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"authorization_endpoint": None}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"token_endpoint": None}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"token_endpoint": "http://insecure/token"}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": None}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"response_types_supported": ["id_token"]}): + self.assertRaisesRegex( + ValueError, "response_types_supported", h._validate_metadata + ) + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + ): + # should not throw, as client_secret_basic is the default auth method + h._validate_metadata() + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_post"]} + ): + self.assertRaisesRegex( + ValueError, + "token_endpoint_auth_methods_supported", + h._validate_metadata, + ) + + # Tests for configs that the userinfo endpoint + self.assertFalse(h._uses_userinfo) + h._scopes = [] # do not request the openid scope + self.assertTrue(h._uses_userinfo) + self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + + with self.metadata_edit( + {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} + ): + # Shouldn't raise with a valid userinfo, even without + h._validate_metadata() + + @override_config({"oidc_config": {"skip_verification": True}}) + def test_skip_verification(self): + """Provider metadata validation can be disabled by config.""" + with self.metadata_edit({"issuer": "http://insecure"}): + # This should not throw + self.handler._validate_metadata() + + @defer.inlineCallbacks + def test_redirect_request(self): + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["addCookie", "redirect", "finish"]) + yield defer.ensureDeferred( + self.handler.handle_redirect_request(req, b"http://client/redirect") + ) + url = req.redirect.call_args[0][0] + url = urlparse(url) + auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + + self.assertEqual(url.scheme, auth_endpoint.scheme) + self.assertEqual(url.netloc, auth_endpoint.netloc) + self.assertEqual(url.path, auth_endpoint.path) + + params = parse_qs(url.query) + self.assertEqual(params["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(params["response_type"], ["code"]) + self.assertEqual(params["scope"], [" ".join(SCOPES)]) + self.assertEqual(params["client_id"], [CLIENT_ID]) + self.assertEqual(len(params["state"]), 1) + self.assertEqual(len(params["nonce"]), 1) + + # Check what is in the cookie + # note: python3.5 mock does not have the .called_once() method + calls = req.addCookie.call_args_list + self.assertEqual(len(calls), 1) # called once + # For some reason, call.args does not work with python3.5 + args = calls[0][0] + kwargs = calls[0][1] + self.assertEqual(args[0], COOKIE_NAME) + self.assertEqual(kwargs["path"], COOKIE_PATH) + cookie = args[1] + + macaroon = pymacaroons.Macaroon.deserialize(cookie) + state = self.handler._get_value_from_macaroon(macaroon, "state") + nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") + redirect = self.handler._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + self.assertEqual(params["state"], [state]) + self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(redirect, "http://client/redirect") + + @defer.inlineCallbacks + def test_callback_error(self): + """Errors from the provider returned in the callback are displayed.""" + self.handler._render_error = Mock() + request = Mock(args={}) + request.args[b"error"] = [b"invalid_client"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "") + + request.args[b"error_description"] = [b"some description"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "some description") + + @defer.inlineCallbacks + def test_callback(self): + """Code callback works and display errors if something went wrong. + + A lot of scenarios are tested here: + - when the callback works, with userinfo from ID token + - when the user mapping fails + - when ID token verification fails + - when the callback works, with userinfo fetched from the userinfo endpoint + - when the userinfo fetching fails + - when the code exchange fails + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "preferred_username": "bar", + } + user_id = UserID("foo", "domain.org") + self.handler._render_error = Mock(return_value=None) + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock(spec=["args", "getCookie", "addCookie"]) + + code = "code" + state = "state" + nonce = "nonce" + client_redirect_url = "http://client/redirect" + session = self.handler._generate_oidc_session_token( + state=state, nonce=nonce, client_redirect_url=client_redirect_url, + ) + request.getCookie.return_value = session + + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_not_called() + self.handler._render_error.assert_not_called() + + # Handle mapping errors + self.handler._map_userinfo_to_user = simple_async_mock( + raises=MappingException() + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + + # Handle ID token errors + self.handler._parse_id_token = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_token") + + self.handler._auth_handler.complete_sso_login.reset_mock() + self.handler._exchange_code.reset_mock() + self.handler._parse_id_token.reset_mock() + self.handler._map_userinfo_to_user.reset_mock() + self.handler._fetch_userinfo.reset_mock() + + # With userinfo fetching + self.handler._scopes = [] # do not ask the "openid" scope + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_not_called() + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_called_once_with(token) + self.handler._render_error.assert_not_called() + + # Handle userinfo fetching error + self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("fetch_error") + + # Handle code exchange failure + self.handler._exchange_code = simple_async_mock( + raises=OidcError("invalid_request") + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @defer.inlineCallbacks + def test_callback_session(self): + """The callback verifies the session presence and validity""" + self.handler._render_error = Mock(return_value=None) + request = Mock(spec=["args", "getCookie", "addCookie"]) + + # Missing cookie + request.args = {} + request.getCookie.return_value = None + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("missing_session", "No session cookie found") + + # Missing session parameter + request.args = {} + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request", "State parameter is missing") + + # Invalid cookie + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_session") + + # Mismatching session + session = self.handler._generate_oidc_session_token( + state="state", nonce="nonce", client_redirect_url="http://client/redirect", + ) + request.args = {} + request.args[b"state"] = [b"mismatching state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mismatching_session") + + # Valid session + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @defer.inlineCallbacks + def test_exchange_code(self): + """Code exchange behaves correctly and handles various error scenarios.""" + token = {"type": "bearer"} + token_json = json.dumps(token).encode("utf-8") + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + ) + code = "code" + ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + kwargs = self.http_client.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + # Test error handling + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=400, + phrase=b"Bad Request", + body=b'{"error": "foo", "error_description": "bar"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "foo") + self.assertEqual(exc.exception.error_description, "bar") + + # Internal server error with no JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, phrase=b"Internal Server Error", body=b"Not JSON", + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # Internal server error with JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, + phrase=b"Internal Server Error", + body=b'{"error": "internal_server_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "internal_server_error") + + # 4xx error without "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # 2xx error with "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=b'{"error": "some_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "some_error") diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..8aa56f1496 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -81,19 +82,58 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), "Frank Jr.", ) + # Set displayname again + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) + ) + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.hs.config.enable_set_displayname = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + # Setting displayname a second time is forbidden + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + ) ) yield self.assertFailure(d, AuthError) @@ -137,13 +177,54 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) self.assertEquals( (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + # Set avatar again + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.hs.config.enable_set_avatar_url = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + # Set avatar a second time is forbidden + d = defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + ) + + yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e2915eb7b1..1b7935cef2 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() # some of the tests rely on us having a user consent version hs_config["user_consent"] = { @@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=False) + self.store.is_real_user = Mock(return_value=defer.succeed(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=1) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(1)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=2) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(2)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + async def get_or_create_user( + self, requester, localpart, displayname, password_hash=None + ): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth().check_auth_blocking() need_register = True try: - yield self.handler.check_username(localpart) + await self.handler.check_username(localpart) except SynapseError as e: if e.errcode == Codes.USER_IN_USE: need_register = False @@ -288,21 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): token = self.macaroon_generator.generate_access_token(user_id) if need_register: - yield self.handler.register_with_store( + await self.handler.register_with_store( user_id=user_id, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: - yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.hs.get_profile_handler().set_displayname( + await self.hs.get_profile_handler().set_displayname( user, requester, displayname, by_admin=True ) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 4cbe9784ed..e178d7765b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - def test_wait_for_sync_for_user_auth_blocking(self): + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) self.reactor.advance(100) # So we get not 0 time - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 # Check that the happy case does not throw errors self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True e = self.get_failure( self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.hs.config.hs_disabled = False + self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 7b92bdbc47..572df8d80b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -185,7 +185,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Allow all users. return False - spam_checker.spam_checker = AllowAll() + spam_checker.spam_checkers = [AllowAll()] # The results do not change: # We get one search result when searching for user2 by user1. @@ -198,7 +198,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # All users are spammy. return True - spam_checker.spam_checker = BlockAll() + spam_checker.spam_checkers = [BlockAll()] # User1 now gets no search results for any of the other users. s = self.get_success(self.handler.search_users(u1, "user2", 10)) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index fdc1d918ff..562397cdda 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.util.caches.ttlcache import TTLCache from tests import unittest @@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - _check_logcontext(LoggingContext.sentinel) + _check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d @@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): def _check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index df034ab237..babc201643 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # should have reset to the sentinel context - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) result = yield resolve_d # should have restored our context - self.assertIs(LoggingContext.current_context(), ctx) + self.assertIs(current_context(), ctx) return result diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 2b01f40a42..fff4f0cbf4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests.server import FakeTransport from tests.unittest import HomeserverTestCase def check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) @@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - check_logcontext(LoggingContext.sentinel) + check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d diff --git a/tests/replication/_base.py b/tests/replication/_base.py new file mode 100644 index 0000000000..9d4f0bbe44 --- /dev/null +++ b/tests/replication/_base.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, List, Optional, Tuple + +import attr + +from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.task import LoopingCall +from twisted.web.http import HTTPChannel + +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) +from synapse.http.site import SynapseRequest +from synapse.replication.http import streams +from synapse.replication.tcp.handler import ReplicationCommandHandler +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.server import FakeTransport + +logger = logging.getLogger(__name__) + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + + servlets = [ + streams.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # build a replication server + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() + self.server = server_factory.buildProtocol(None) + + # Make a new HomeServer object for the worker + self.reactor.lookups["testserv"] = "1.2.3.4" + self.worker_hs = self.setup_test_homeserver( + http_client=None, + homeserverToUse=GenericWorkerServer, + config=self._get_worker_hs_config(), + reactor=self.reactor, + ) + + # Since we use sqlite in memory databases we need to make sure the + # databases objects are the same. + self.worker_hs.get_datastore().db = hs.get_datastore().db + + self.test_handler = self._build_replication_data_handler() + self.worker_hs.replication_data_handler = self.test_handler + + repl_handler = ReplicationCommandHandler(self.worker_hs) + self.client = ClientReplicationStreamProtocol( + self.worker_hs, "client", "test", clock, repl_handler, + ) + + self._client_transport = None + self._server_transport = None + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.generic_worker" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs) + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump(0.1) + + def handle_http_replication_attempt(self) -> SynapseRequest: + """Asserts that a connection attempt was made to the master HS on the + HTTP replication port, then proxies it to the master HS object to be + handled. + + Returns: + The request object received by master HS. + """ + + # We should have an outbound connection attempt. + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # The request will now be processed by `self.site` and the response + # streamed back. + self.reactor.advance(0) + + # We tear down the connection so it doesn't get reused without our + # knowledge. + server_to_client_transport.loseConnection() + client_to_server_transport.loseConnection() + + return request_factory.request + + def assert_request_is_get_repl_stream_updates( + self, request: SynapseRequest, stream_name: str + ): + """Asserts that the given request is a HTTP replication request for + fetching updates for given stream. + """ + + self.assertRegex( + request.path, + br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" + % (stream_name.encode("ascii"),), + ) + + self.assertEqual(request.method, b"GET") + + +class TestReplicationDataHandler(GenericWorkerReplicationHandler): + """Drop-in for ReplicationDataHandler which just collects RDATA rows""" + + def __init__(self, hs: HomeServer): + super().__init__(hs) + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] + + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + for r in rows: + self.received_rdata_rows.append((stream_name, token, r)) + + +@attr.s() +class OneShotRequestFactory: + """A simple request factory that generates a single `SynapseRequest` and + stores it for future use. Can only be used once. + """ + + request = attr.ib(default=None) + + def __call__(self, *args, **kwargs): + assert self.request is None + + self.request = SynapseRequest(*args, **kwargs) + return self.request + + +class _PushHTTPChannel(HTTPChannel): + """A HTTPChannel that wraps pull producers to push producers. + + This is a hack to get around the fact that HTTPChannel transparently wraps a + pull producer (which is what Synapse uses to reply to requests) with + `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush` + uses the standard reactor rather than letting us use our test reactor, which + makes it very hard to test. + """ + + def __init__(self, reactor: IReactorTime): + super().__init__() + self.reactor = reactor + + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] + + def registerProducer(self, producer, streaming): + # Convert pull producers to push producer. + if not streaming: + self._pull_to_push_producer = _PullToPushProducer( + self.reactor, producer, self + ) + producer = self._pull_to_push_producer + + super().registerProducer(producer, True) + + def unregisterProducer(self): + if self._pull_to_push_producer: + # We need to manually stop the _PullToPushProducer. + self._pull_to_push_producer.stop() + + +class _PullToPushProducer: + """A push producer that wraps a pull producer. + """ + + def __init__( + self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer + ): + self._clock = Clock(reactor) + self._producer = producer + self._consumer = consumer + + # While running we use a looping call with a zero delay to call + # resumeProducing on given producer. + self._looping_call = None # type: Optional[LoopingCall] + + # We start writing next reactor tick. + self._start_loop() + + def _start_loop(self): + """Start the looping call to + """ + + if not self._looping_call: + # Start a looping call which runs every tick. + self._looping_call = self._clock.looping_call(self._run_once, 0) + + def stop(self): + """Stops calling resumeProducing. + """ + if self._looping_call: + self._looping_call.stop() + self._looping_call = None + + def pauseProducing(self): + """Implements IPushProducer + """ + self.stop() + + def resumeProducing(self): + """Implements IPushProducer + """ + self._start_loop() + + def stopProducing(self): + """Implements IPushProducer + """ + self.stop() + self._producer.stopProducing() + + def _run_once(self): + """Calls resumeProducing on producer once. + """ + + try: + self._producer.resumeProducing() + except Exception: + logger.exception("Failed to call resumeProducing") + try: + self._consumer.unregisterProducer() + except Exception: + pass + + self.stopProducing() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..32cb04645f 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,22 +15,13 @@ from mock import Mock, NonCallableMock -from synapse.replication.tcp.client import ( - ReplicationClientFactory, - ReplicationClientHandler, -) -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import make_conn +from tests.replication._base import BaseStreamTestCase -from tests import unittest -from tests.server import FakeTransport - -class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): +class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "blue", federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) @@ -40,34 +31,13 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) - db_config = hs.config.database.get_single_database() - self.master_store = self.hs.get_datastore() - self.storage = hs.get_storage() - database = hs.get_datastores().databases[0] - self.slaved_store = self.STORE_TYPE( - database, make_conn(db_config, database.engine), self.hs - ) - self.event_id = 0 - - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer - - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory + self.reconnect() - client_factory = ReplicationClientFactory( - self.hs, "client_name", self.replication_handler - ) - - server = server_factory.buildProtocol(None) - client = client_factory.buildProtocol(None) - - client.makeConnection(FakeTransport(server, reactor)) - - self.server_to_client_transport = FakeTransport(client, reactor) - server.makeConnection(self.server_to_client_transport) + self.master_store = hs.get_datastore() + self.slaved_store = self.worker_hs.get_datastore() + self.storage = hs.get_storage() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index f0561b30e3..0fee8a71c4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -24,10 +24,10 @@ from synapse.storage.roommember import RoomsForUser from ._base import BaseSlavedStoreTestCase -USER_ID = "@feeling:blue" -USER_ID_2 = "@bright:blue" +USER_ID = "@feeling:test" +USER_ID_2 = "@bright:test" OUTLIER = {"outlier": True} -ROOM_ID = "!room:blue" +ROOM_ID = "!room:test" logger = logging.getLogger(__name__) @@ -239,7 +239,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) # limit the replication rate - repl_transport = self.server_to_client_transport + repl_transport = self._server_transport repl_transport.autoflush = False # build the join and message events and persist them in the same batch. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py deleted file mode 100644 index e96ad4ca4e..0000000000 --- a/tests/replication/tcp/streams/_base.py +++ /dev/null @@ -1,78 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from mock import Mock - -from synapse.replication.tcp.commands import ReplicateCommand -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory - -from tests import unittest -from tests.server import FakeTransport - - -class BaseStreamTestCase(unittest.HomeserverTestCase): - """Base class for tests of the replication streams""" - - def prepare(self, reactor, clock, hs): - # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) - - # build a replication client, with a dummy handler - handler_factory = Mock() - self.test_handler = TestReplicationClientHandler() - self.test_handler.factory = handler_factory - self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler - ) - - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) - - def replicate(self): - """Tell the master side of replication that something has happened, and then - wait for the replication to occur. - """ - self.streamer.on_notifier_poke() - self.pump(0.1) - - def replicate_stream(self, stream, token="NOW"): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) - - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" - - def __init__(self): - self.received_rdata_rows = [] - - def get_streams_to_replicate(self): - return {} - - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - async def on_rdata(self, stream_name, token, rows): - for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..51bf0ef4e9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). + events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. + + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. + # - two rows for state2 + + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. + + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py new file mode 100644 index 0000000000..2babea4e3e --- /dev/null +++ b/tests/replication/tcp/streams/test_federation.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.federation.send_queue import EduRow +from synapse.replication.tcp.streams.federation import FederationStream + +from tests.replication._base import BaseStreamTestCase + + +class FederationStreamTestCase(BaseStreamTestCase): + def _get_worker_hs_config(self) -> dict: + # enable federation sending on the worker + config = super()._get_worker_hs_config() + # TODO: make it so we don't need both of these + config["send_federation"] = True + config["worker_app"] = "synapse.app.federation_sender" + return config + + def test_catchup(self): + """Basic test of catchup on reconnect + + Makes sure that updates sent while we are offline are received later. + """ + fed_sender = self.hs.get_federation_sender() + received_rows = self.test_handler.received_rdata_rows + + fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"}) + + self.reconnect() + self.reactor.advance(0) + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual(received_rows, []) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "federation") + + # we should have received an update row + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test_edu") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"a": "b"}) + + self.assertEqual(received_rows, []) + + # additional updates should be transferred without an HTTP hit + fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"}) + self.reactor.advance(0) + # there should be no http hit + self.assertEqual(len(self.reactor.tcpClients), 0) + # ... but we should have a row + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test1") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"c": "d"}) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index d5a99f6caa..56b062ecc1 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,35 +12,73 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.tcp.streams._base import ReceiptsStreamRow -from tests.replication.tcp.streams._base import BaseStreamTestCase +# type: ignore + +from mock import Mock + +from synapse.replication.tcp.streams._base import ReceiptsStream + +from tests.replication._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): - # make the client subscribe to the receipts stream - self.replicate_stream("receipts", "NOW") + self.reconnect() # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStreamRow - self.assertEqual(ROOM_ID, row.room_id) + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) - self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual("$event:blue", row.event_id) self.assertEqual({"a": 1}, row.data) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + self.get_success( + self.hs.get_datastore().insert_receipt( + "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + ) + ) + self.replicate() + + # Nothing should have happened as we are disconnected + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now have caught up and get the missing data + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") + self.assertEqual(token, 3) + self.assertEqual(1, len(rdata_rows)) + + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room2:blue", row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual("$event2:foo", row.event_id) + self.assertEqual({"a": 2}, row.data) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py new file mode 100644 index 0000000000..fd62b26356 --- /dev/null +++ b/tests/replication/tcp/streams/test_typing.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.handlers.typing import RoomMember +from synapse.replication.tcp.streams import TypingStream + +from tests.replication._base import BaseStreamTestCase + +USER_ID = "@feeling:blue" + + +class TypingStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + + def test_typing(self): + typing = self.hs.get_typing_handler() + + room_id = "!bar:blue" + + self.reconnect() + + typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) + + self.reactor.advance(0) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] # type: TypingStream.TypingStreamRow + self.assertEqual(room_id, row.room_id) + self.assertEqual([USER_ID], row.user_ids) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + typing._push_update(member=RoomMember(room_id, USER_ID), typing=False) + + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + # The from token should be the token from the last RDATA we got. + self.assertEqual(int(request.args[b"from_token"][0]), token) + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] + self.assertEqual(room_id, row.room_id) + self.assertEqual([], row.user_ids) diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py new file mode 100644 index 0000000000..7ddfd0a733 --- /dev/null +++ b/tests/replication/tcp/test_commands.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.commands import ( + RdataCommand, + ReplicateCommand, + parse_command_from_line, +) + +from tests.unittest import TestCase + + +class ParseCommandTestCase(TestCase): + def test_parse_one_word_command(self): + line = "REPLICATE" + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, ReplicateCommand) + + def test_parse_rdata(self): + line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.instance_name, "master") + self.assertEqual(cmd.token, 6287863) + + def test_parse_rdata_batch(self): + line = 'RDATA presence master batch ["@foo:example.com", "online"]' + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "presence") + self.assertEqual(cmd.instance_name, "master") + self.assertIsNone(cmd.token) diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py new file mode 100644 index 0000000000..d1c15caeb0 --- /dev/null +++ b/tests/replication/tcp/test_remote_server_up.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from twisted.internet.interfaces import IProtocol +from twisted.test.proto_helpers import StringTransport + +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests.unittest import HomeserverTestCase + + +class RemoteServerUpTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.factory = ReplicationStreamProtocolFactory(hs) + + def _make_client(self) -> Tuple[IProtocol, StringTransport]: + """Create a new direct TCP replication connection + """ + + proto = self.factory.buildProtocol(("127.0.0.1", 0)) + transport = StringTransport() + proto.makeConnection(transport) + + # We can safely ignore the commands received during connection. + self.pump() + transport.clear() + + return proto, transport + + def test_relay(self): + """Test that Synapse will relay REMOTE_SERVER_UP commands to all + other connections, but not the one that sent it. + """ + + proto1, transport1 = self._make_client() + + # We shouldn't receive an echo. + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + self.pump() + self.assertEqual(transport1.value(), b"") + + # But we should see an echo if we connect another client + proto2, transport2 = self._make_client() + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + + self.pump() + self.assertEqual(transport1.value(), b"") + self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n") diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0342aed416..977615ebef 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -17,7 +17,6 @@ import json import os import urllib.parse from binascii import unhexlify -from typing import List, Optional from mock import Mock @@ -27,7 +26,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import directory, events, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import groups from tests import unittest @@ -51,129 +50,6 @@ class VersionTestCase(unittest.HomeserverTestCase): ) -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - class DeleteGroupTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -273,86 +149,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): return channel.json_body["groups"] -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Test /quarantine_media admin API. """ @@ -691,389 +487,3 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) - - -class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - directory.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - # Create user - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_list_rooms(self): - """Test that we can list rooms""" - # Create 3 test rooms - total_rooms = 3 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - - # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) - - # Check that response json body contains a "rooms" key - self.assertTrue( - "rooms" in channel.json_body, - msg="Response body does not " "contain a 'rooms' key", - ) - - # Check that 3 rooms were returned - self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) - - # Check their room_ids match - returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] - self.assertEqual(room_ids, returned_room_ids) - - # Check that all fields are available - for r in channel.json_body["rooms"]: - self.assertIn("name", r) - self.assertIn("canonical_alias", r) - self.assertIn("joined_members", r) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # Should be 0 as we aren't paginating - self.assertEqual(channel.json_body["offset"], 0) - - # Check that the prev_batch parameter is not present - self.assertNotIn("prev_batch", channel.json_body) - - # We shouldn't receive a next token here as there's no further rooms to show - self.assertNotIn("next_batch", channel.json_body) - - def test_list_rooms_pagination(self): - """Test that we can get a full list of rooms through pagination""" - # Create 5 test rooms - total_rooms = 5 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Set the name of the rooms so we get a consistent returned ordering - for idx, room_id in enumerate(room_ids): - self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - returned_room_ids = [] - start = 0 - limit = 2 - - run_count = 0 - should_repeat = True - while should_repeat: - run_count += 1 - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( - start, - limit, - "alphabetical", - ) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - self.assertTrue("rooms" in channel.json_body) - for r in channel.json_body["rooms"]: - returned_room_ids.append(r["room_id"]) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # We're only getting 2 rooms each page, so should be 2 * last run_count - self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) - - if run_count > 1: - # Check the value of prev_batch is correct - self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) - - if "next_batch" not in channel.json_body: - # We have reached the end of the list - should_repeat = False - else: - # Make another query with an updated start value - start = channel.json_body["next_batch"] - - # We should've queried the endpoint 3 times - self.assertEqual( - run_count, - 3, - msg="Should've queried 3 times for 5 rooms with limit 2 per query", - ) - - # Check that we received all of the room ids - self.assertEqual(room_ids, returned_room_ids) - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - def test_correct_room_attributes(self): - """Test the correct attributes for a room are returned""" - # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - test_alias = "#test:test" - test_room_name = "something" - - # Have another user join the room - user_2 = self.register_user("user4", "pass") - user_tok_2 = self.login("user4", "pass") - self.helper.join(room_id, user_2, tok=user_tok_2) - - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=self.admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=self.admin_user_tok, - ) - - # Set a name for the room - self.helper.send_state( - room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that only one room was returned - self.assertEqual(len(rooms), 1) - - # And that the value of the total_rooms key was correct - self.assertEqual(channel.json_body["total_rooms"], 1) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that all provided attributes are set - r = rooms[0] - self.assertEqual(room_id, r["room_id"]) - self.assertEqual(test_room_name, r["name"]) - self.assertEqual(test_alias, r["canonical_alias"]) - - def test_room_list_sort_order(self): - """Test room list sort ordering. alphabetical versus number of members, - reversing the order, etc. - """ - # Create 3 test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C - self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, - ) - - # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 - user_1 = self.register_user("bob1", "pass") - user_1_tok = self.login("bob1", "pass") - self.helper.join(room_id_2, user_1, tok=user_1_tok) - - user_2 = self.register_user("bob2", "pass") - user_2_tok = self.login("bob2", "pass") - self.helper.join(room_id_3, user_2, tok=user_2_tok) - - user_3 = self.register_user("bob3", "pass") - user_3_tok = self.login("bob3", "pass") - self.helper.join(room_id_3, user_3, tok=user_3_tok) - - def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, - ): - """Request the list of rooms in a certain order. Assert that order is what - we expect - - Args: - order_type: The type of ordering to give the server - expected_room_list: The list of room_ids in the order we expect to get - back from the server - """ - # Request the list of rooms in the given order - url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) - if reverse: - url += "&dir=b" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check for the correct total_rooms value - self.assertEqual(channel.json_body["total_rooms"], 3) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that rooms were returned in alphabetical order - returned_order = [r["room_id"] for r in rooms] - self.assertListEqual(expected_room_list, returned_order) # order is checked - - # Test different sort orders, with forward and reverse directions - _order_test("alphabetical", [room_id_1, room_id_2, room_id_3]) - _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("size", [room_id_3, room_id_2, room_id_1]) - _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True) - - def test_search_term(self): - """Test that searching for a room works correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - def _search_test( - expected_room_id: Optional[str], - search_term: str, - expected_http_code: int = 200, - ): - """Search for a room and check that the returned room's id is a match - - Args: - expected_room_id: The room_id expected to be returned by the API. Set - to None to expect zero results for the search - search_term: The term to search for room names with - expected_http_code: The expected http code for the request - """ - url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - - if expected_http_code != 200: - return - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that the expected number of rooms were returned - expected_room_count = 1 if expected_room_id else 0 - self.assertEqual(len(rooms), expected_room_count) - self.assertEqual(channel.json_body["total_rooms"], expected_room_count) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - if expected_room_id: - # Check that the first returned room id is correct - r = rooms[0] - self.assertEqual(expected_room_id, r["room_id"]) - - # Perform search tests - _search_test(room_id_1, "something") - _search_test(room_id_1, "thing") - - _search_test(room_id_2, "else") - _search_test(room_id_2, "se") - - _search_test(None, "foo") - _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py new file mode 100644 index 0000000000..54cd24bf64 --- /dev/null +++ b/tests/rest/admin/test_room.py @@ -0,0 +1,1007 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import urllib.parse +from typing import List, Optional + +from mock import Mock + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import directory, events, login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, body="foo", tok=self.other_user_token, expect_code=403 + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + self.assertIn("joined_local_members", r) + self.assertIn("version", r) + self.assertIn("creator", r) + self.assertIn("encryption", r) + self.assertIn("federatable", r) + self.assertIn("public", r) + self.assertIn("join_rules", r) + self.assertIn("guest_access", r) + self.assertIn("history_visibility", r) + self.assertIn("state_events", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + limit, + "name", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical name versus number of members, + reversing the order, etc. + """ + + def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % ( + urllib.parse.quote(test_alias), + ) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=admin_user_tok, + ) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room canonical room aliases + _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + # Test different sort orders, with forward and reverse directions + _order_test("name", [room_id_1, room_id_2, room_id_3]) + _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) + _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) + _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) + _order_test( + "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("version", [room_id_1, room_id_2, room_id_3]) + _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("creator", [room_id_1, room_id_2, room_id_3]) + _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("encryption", [room_id_1, room_id_2, room_id_3]) + _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("federatable", [room_id_1, room_id_2, room_id_3]) + _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("public", [room_id_1, room_id_2, room_id_3]) + # Different sort order of SQlite and PostreSQL + # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) + _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) + _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) + _order_test( + "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("state_events", [room_id_3, room_id_2, room_id_1]) + _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) + + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, channel.json_body["room_id"]) + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. + + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 6416fb5d2a..6c88ab06e2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -360,6 +360,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(3, len(channel.json_body["users"])) + self.assertEqual(3, channel.json_body["total"]) class UserRestTestCase(unittest.HomeserverTestCase): @@ -434,6 +435,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": True, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": None, } ) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py new file mode 100644 index 0000000000..913ea3c98e --- /dev/null +++ b/tests/rest/client/test_power_levels.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync + +from tests.unittest import HomeserverTestCase + + +class PowerLevelsTestCase(HomeserverTestCase): + """Tests that power levels are enforced in various situations""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + # register a room admin, moderator and regular user + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + self.mod_user_id = self.register_user("mod", "pass") + self.mod_access_token = self.login("mod", "pass") + self.user_user_id = self.register_user("user", "pass") + self.user_access_token = self.login("user", "pass") + + # Create a room + self.room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # Invite the other users + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.mod_user_id, + ) + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.user_user_id, + ) + + # Make the other users join the room + self.helper.join( + room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token + ) + self.helper.join( + room=self.room_id, user=self.user_user_id, tok=self.user_access_token + ) + + # Mod the mod + room_power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.admin_access_token, + ) + + # Update existing power levels with mod at PL50 + room_power_levels["users"].update({self.mod_user_id: 50}) + + self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + ) + + def test_non_admins_cannot_enable_room_encryption(self): + # have the mod try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_send_server_acl(self): + # have the mod try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the mod try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_admins_can_enable_room_encryption(self): + # have the admin try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_send_server_acl(self): + # have the admin try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the admin try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a3d7e3c046..171632e195 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,7 +2,7 @@ from mock import Mock, call from twisted.internet import defer, reactor -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock @@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test(): with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) - self.assertIs(LoggingContext.current_context(), c1) + self.assertIs(current_context(), c1) self.assertEqual(res, "yay") # run the test twice in parallel d = defer.gatherResults([test(), test()]) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) yield d - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks def test_does_not_cache_exceptions(self): @@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_does_not_cache_failures(self): @@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self): diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index ffb2de1505..b54b06482b 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -50,7 +50,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs - def prepare(self, hs, reactor, clock): + def prepare(self, reactor, clock, hs): # register an account self.user_id = self.register_user("sid1", "pass") diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index da2c9bfa1e..eb8f6264fd 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -4,7 +4,7 @@ import urllib.parse from mock import Mock import synapse.rest.admin -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha.account import WhoamiRestServlet @@ -20,6 +20,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, + logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] @@ -256,8 +257,74 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.code, 200, channel.result) + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + request, channel = self.make_request( + b"POST", "/logout/all", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) -class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): + +class CASTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -274,6 +341,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): "service_url": "https://matrix.goodserver.com:8448", } + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + async def get_raw(uri, args): """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from @@ -282,10 +352,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): This needs to be returned by an async function (as opposed to set as the mock's return value) because the corresponding Synapse code awaits on it. """ - return """ + return ( + """ <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'> <cas:authenticationSuccess> - <cas:user>username</cas:user> + <cas:user>%s</cas:user> <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket> <cas:proxies> <cas:proxy>https://proxy2/pgtUrl</cas:proxy> @@ -294,6 +365,8 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): </cas:authenticationSuccess> </cas:serviceResponse> """ + % cas_user_id + ) mocked_http_client = Mock(spec=["get_raw"]) mocked_http_client.get_raw.side_effect = get_raw @@ -304,6 +377,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): return self.hs + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + def test_cas_redirect_confirm(self): """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. @@ -350,7 +426,14 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): def test_cas_redirect_whitelisted(self): """Tests that the SSO login flow serves a redirect to a whitelisted url """ - redirect_url = "https://legit-site.com/" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" % (urllib.parse.quote(redirect_url)) @@ -363,3 +446,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account(self.user_id, False) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 873d5ef99c..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -18,6 +18,7 @@ import json import time +from typing import Any, Dict, Optional import attr @@ -38,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -142,7 +143,34 @@ class RestHelper(object): return channel.json_body - def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""): + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( room_id, event_type, @@ -151,9 +179,13 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8") - ) + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + request, channel = make_request(self.hs.get_reactor(), method, path, content) + render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( @@ -163,6 +195,62 @@ class RestHelper(object): return channel.json_body + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + def upload_media( self, resource: Resource, diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb..0d6936fd36 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -178,6 +179,22 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. + """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + def _request_token(self, email, client_secret): request, channel = self.make_request( "POST", @@ -325,3 +342,304 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed + """ + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed + """ + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b6df1396ad..293ccfba2b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,16 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import List, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v2_alpha import auth, register +from synapse.http.site import SynapseRequest +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.types import JsonDict from tests import unittest +from tests.server import FakeChannel class DummyRecaptchaChecker(UserInteractiveAuthChecker): @@ -34,11 +38,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) +class DummyPasswordChecker(UserInteractiveAuthChecker): + def check_auth(self, authdict, clientip): + return succeed(authdict["identifier"]["user"]) + + class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False @@ -59,59 +67,250 @@ class FallbackAuthTests(unittest.HomeserverTestCase): auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - @unittest.INFO - def test_fallback_captcha(self): - + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, - ) + "POST", "register", body + ) # type: SynapseRequest, FakeChannel self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) + self.assertEqual(request.code, expected_response) + return channel + + def recaptcha( + self, session: str, expected_post_response: int, post_session: str = None + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session request, channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) + ) # type: SynapseRequest, FakeChannel self.render(request) self.assertEqual(request.code, 200) request, channel = self.make_request( "POST", "auth/m.login.recaptcha/fallback/web?session=" - + session + + post_session + "&g-recaptcha-response=a", ) self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(request.code, expected_post_response) # The recaptcha handler is called with the response given attempts = self.recaptcha_checker.recaptcha_attempts self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - # also complete the dummy auth - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + @unittest.INFO + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, ) - self.render(request) - # Now we should have fufilled a complete auth flow, including + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + + # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session}} - ) - self.render(request) - self.assertEqual(channel.code, 200) + channel = self.register(200, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + # Make the initial request to register. (Later on a different password + # will be used.) + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) + + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.user_tok = self.login("test", self.user_pass) + + def get_device_ids(self) -> List[str]: + # Get the list of devices so one can be deleted. + request, channel = self.make_request( + "GET", "devices", access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel + self.render(request) + + # Get the ID of the device. + self.assertEqual(request.code, 200) + return [d["device_id"] for d in channel.json_body["devices"]] + + def delete_device( + self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + ) -> FakeChannel: + """Delete an individual device.""" + request, channel = self.make_request( + "DELETE", "devices/" + device, body, access_token=self.user_tok + ) # type: SynapseRequest, FakeChannel + self.render(request) + + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. + request, channel = self.make_request( + "POST", "delete_devices", body, access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel + self.render(request) + + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + device_id = self.get_device_ids()[0] + + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_can_change_body(self): + """ + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. + self.delete_devices( + 200, + { + "devices": [device_ids[1]], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(device_ids[0], 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_device( + device_ids[1], + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644 index 0000000000..c57072f50c --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d0c997e385..5637ce2090 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -25,7 +25,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest @@ -33,11 +33,15 @@ from tests import unittest class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [register.register_servlets] + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] url = b"/_matrix/client/r0/register" - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config["allow_guest_access"] = True return config @@ -260,6 +264,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): [["m.login.email.identity"]], (f["stages"] for f in flows) ) + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. + """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + class AccountValidityTestCase(unittest.HomeserverTestCase): @@ -268,6 +313,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, + logout.register_servlets, account_validity.register_servlets, ] @@ -360,6 +406,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + request, channel = self.make_request(b"POST", "/logout", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 6776a56cad..99eb477149 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. """ - def default_config(self, *args, **kwargs): - config = super().default_config(*args, **kwargs) + def default_config(self): + config = super().default_config() # replace the signing key with our own self.hs_signing_key = signedjson.key.generate_signing_key("kssk") diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 852b8ab11c..2826211f32 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -74,6 +74,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) config["url_preview_url_blacklist"] = [] + config["url_preview_accept_language"] = [ + "en-UK", + "en-US;q=0.9", + "fr;q=0.8", + "*;q=0.7", + ] self.storage_path = self.mktemp() self.media_store_path = self.mktemp() @@ -507,3 +513,52 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) + + def test_accept_language_config_option(self): + """ + Accept-Language header is sent to the remote server + """ + self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + + # Build and make a request to the server + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Extract Synapse's tcp client + client = self.reactor.tcpClients[0][2].buildProtocol(None) + + # Build a fake remote server to reply with + server = AccumulatingProtocol() + + # Connect the two together + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + + # Tell Synapse that it has received some data from the remote server + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + + # Move the reactor along until we get a response on our original channel + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + # Check that the server received the Accept-Language header as part + # of the request from Synapse + self.assertIn( + ( + b"Accept-Language: en-UK\r\n" + b"Accept-Language: en-US;q=0.9\r\n" + b"Accept-Language: fr;q=0.8\r\n" + b"Accept-Language: *;q=0.7" + ), + server.data, + ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index eb540e34f6..406f29a7c0 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,6 +19,9 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) @@ -28,7 +31,7 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() hs_config["server_notices"] = { "system_mxid_localpart": "server", "system_mxid_display_name": "test display name", @@ -52,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) ) - self._send_notice = self._rlsn._server_notices_manager.send_notice - self._rlsn._server_notices_manager.send_notice = Mock() - self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) - + self._rlsn._server_notices_manager.send_notice = Mock( + return_value=defer.succeed(Mock()) + ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" - # self.server_notices_mxid = "@server:test" - # self.server_notices_mxid_display_name = None - # self.server_notices_mxid_avatar_url = None - # self.server_notices_room_name = "Server Notices" - - self._rlsn._server_notices_manager.get_notice_room_for_user = Mock( - returnValue="" + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( + return_value=defer.succeed("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock() - self._rlsn._store.get_tags_for_room = Mock(return_value={}) + self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) self.hs.config.admin_contact = "mailto:user@test.com" def test_maybe_send_server_notice_to_user_flag_off(self): @@ -92,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() @@ -109,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -118,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -126,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -139,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -150,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) @@ -164,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self.assertTrue(self._send_notice.call_count == 0) + self.assertEqual(self._send_notice.call_count, 0) def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -195,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self.hs.config.mau_limit_alerting = False self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( return_value=defer.succeed((True, [])) ) @@ -215,6 +216,26 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def default_config(self): + c = super().default_config() + c["server_notices"] = { + "system_mxid_localpart": "server", + "system_mxid_display_name": None, + "system_mxid_avatar_url": None, + "room_name": "Test Server Notice Room", + } + c["limit_usage_by_mau"] = True + c["max_mau_value"] = 5 + c["admin_contact"] = "mailto:user@test.com" + return c + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() @@ -228,22 +249,14 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): if not isinstance(self._rlsn, ResourceLimitsServerNotices): raise Exception("Failed to find reference to ResourceLimitsServerNotices") - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 5 - self.hs.config.server_notices_mxid = "@server:test" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - self.user_id = "@user_id:test" - self.hs.config.admin_contact = "mailto:user@test.com" - def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock(return_value=1000) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -253,7 +266,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Now lets get the last load of messages in the service notice room and # check that there is only one server notice room_id = self.get_success( - self.server_notices_manager.get_notice_room_for_user(self.user_id) + self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) token = self.get_success(self.event_source.get_current_token()) @@ -273,3 +286,86 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): count += 1 self.assertEqual(count, 1) + + def test_no_invite_without_notice(self): + """Tests that a user doesn't get invited to a server notices room without a + server notice being sent. + + The scenario for this test is a single user on a server where the MAU limit + hasn't been reached (since it's the only user and the limit is 5), so users + shouldn't receive a server notice. + """ + self.register_user("user", "password") + tok = self.login("user", "password") + + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + invites = channel.json_body["rooms"]["invite"] + self.assertEqual(len(invites), 0, invites) + + def test_invite_with_notice(self): + """Tests that, if the MAU limit is hit, the server notices user invites each user + to a room in which it has sent a notice. + """ + user_id, tok, room_id = self._trigger_notice_and_join() + + # Sync again to retrieve the events in the room, so we can check whether this + # room has a notice in it. + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + # Scan the events in the room to search for a message from the server notices + # user. + events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + notice_in_room = False + for event in events: + if ( + event["type"] == EventTypes.Message + and event["sender"] == self.hs.config.server_notices_mxid + ): + notice_in_room = True + + self.assertTrue(notice_in_room, "No server notice in room") + + def _trigger_notice_and_join(self): + """Creates enough active users to hit the MAU limit and trigger a system notice + about it, then joins the system notices room with one of the users created. + + Returns: + user_id (str): The ID of the user that joined the room. + tok (str): The access token of the user that joined the room. + room_id (str): The ID of the room that's been joined. + """ + user_id = None + tok = None + invites = [] + + # Register as many users as the MAU limit allows. + for i in range(self.hs.config.max_mau_value): + localpart = "user%d" % i + user_id = self.register_user(localpart, "password") + tok = self.login(localpart, "password") + + # Sync with the user's token to mark the user as active. + request, channel = self.make_request( + "GET", "/sync?timeout=0", access_token=tok, + ) + self.render(request) + + # Also retrieves the list of invites for this user. We don't care about that + # one except if we're processing the last user, which should have received an + # invite to a room with a server notice about the MAU limit being reached. + # We could also pick another user and sync with it, which would return an + # invite to a system notices room, but it doesn't matter which user we're + # using so we use the last one because it saves us an extra sync. + invites = channel.json_body["rooms"]["invite"] + + # Make sure we have an invite to process. + self.assertEqual(len(invites), 1, invites) + + # Join the room. + room_id = list(invites.keys())[0] + self.helper.join(room=room_id, user=user_id, tok=tok) + + return user_id, tok, room_id diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index ae14fb407d..940b166129 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater # the base test class should have run the real bg updates for us - self.assertTrue(self.updates.has_completed_background_updates()) + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() self.updates.register_background_update_handler( @@ -25,12 +27,20 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 50000 + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.hs.get_datastore().db.runInteraction( + yield store.db.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", @@ -39,10 +49,6 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): return count self.update_handler.side_effect = update - - self.get_success( - self.updates.start_background_update("test_update", {"my_key": 1}) - ) self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update( @@ -50,7 +56,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ), by=0.1, ) - self.assertIsNotNone(res) + self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( @@ -73,7 +79,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNotNone(result) + self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more @@ -81,5 +87,5 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNone(result) + self.assertTrue(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644 index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6f8d990959..c2539b353a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - @defer.inlineCallbacks - def test_get_device_updates_by_remote_limited(self): - # Test breaking the update limit in 1, 101, and 1 device_id segments - - # first add one device - device_ids1 = ["device_id0"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"] - ) - - # then add 101 - device_ids2 = ["device_id" + str(i + 1) for i in range(101)] - yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"] - ) - - # then one more - device_ids3 = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"] - ) - - # - # now read them back. - # - - # first we should get a single update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", -1, limit=100 - ) - self._check_devices_in_updates(device_ids1, device_updates) - - # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self.assertEqual(len(device_updates), 0) - - # The 101 devices should've been cleared, so we should now just get one device - # update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self._check_devices_in_updates(device_ids3, device_updates) - def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index d4bcf1821e..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.persist_events_store = hs.get_datastores().persist_events @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) yield self.store.db.runInteraction( "", - self.store._set_push_actions_for_event_and_users_txn, + self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], [(event, None)], ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py new file mode 100644 index 0000000000..55e9ecf264 --- /dev/null +++ b/tests/storage/test_id_generators.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import Database +from synapse.storage.util.id_generators import MultiWriterIdGenerator + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class MultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db = self.store.db # type: Database + + self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + ) + + return self.get_success(self.db.runWithConnection(_create)) + + def _insert_rows(self, instance_name: str, number: int): + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", + (instance_name,), + ) + + self.get_success(self.db.runInteraction("test_single_instance", _insert)) + + def test_empty(self): + """Test an ID generator against an empty database gives sensible + current positions. + """ + + id_gen = self._create_id_generator() + + # The table is empty so we expect an empty map for positions + self.assertEqual(id_gen.get_positions(), {}) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(_get_next_async()) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) + + def test_multi_instance(self): + """Test that reads and writes from multiple processes are handled + correctly. + """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator("first") + second_id_gen = self._create_id_generator("second") + + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token("first"), 3) + self.assertEqual(first_id_gen.get_current_token("second"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) + + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + + # ... but calling `get_next` on the second instance should give a unique + # stream ID + + async def _get_next_async(): + with await second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) + + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + + def test_get_next_txn(self): + """Test that the `get_next_txn` function works correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + def _get_next_txn(txn): + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(self.db.runInteraction("test", _get_next_txn)) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py new file mode 100644 index 0000000000..ab0df5ea93 --- /dev/null +++ b/tests/storage/test_main.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Awesome Technologies Innovationslabor GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.types import UserID + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + self.user = UserID.from_string("@abcde:test") + self.displayname = "Frank" + + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield self.store.register_user(self.user.to_string(), "pass") + yield self.store.create_profile(self.user.localpart) + yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + + users, total = yield self.store.get_users_paginate( + 0, 10, name="bc", guests=False + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 086adeb8fd..3b78d48896 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats(self): + self.assertDictContainsSubset( + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "public": True, + }, + (yield self.store.get_room_with_stats(self.room.to_string())), + ) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 00df0ea68e..5dd46005e6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -22,6 +22,8 @@ from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID from tests import unittest +from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -38,7 +40,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event # storage logic @@ -114,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) + def test_get_joined_users_from_context(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + bob_event = event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) + + # first, create a regular event + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) + + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + + # Regression test for #7376: create a state event whose key matches bob's + # user_id, but which is *not* a membership event, and persist that; then check + # that `get_joined_users_from_context` returns the correct users for the next event. + non_member_event = event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 6c2351cf55..f2def601fb 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -165,6 +165,39 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ) + def test_msc2209(self): + """ + Notifications power levels get checked due to MSC2209. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event( + creator, {"state_default": "30", "users": {pleb: "30"}} + ), + ("m.room.member", pleb): _join_event(pleb), + } + + # pleb should be able to modify the notifications power level. + event_auth.check( + RoomVersions.V1, + _power_levels_event(pleb, {"notifications": {"room": 100}}), + auth_events, + do_sig_check=False, + ) + + # But an MSC2209 room rejects this change. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC2209_DEV, + _power_levels_event(pleb, {"notifications": {"room": 100}}), + auth_events, + do_sig_check=False, + ) + # helpers for making events diff --git a/tests/test_federation.py b/tests/test_federation.py index 9b5cf562f3..f297de95f1 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = room_creator.create_room( - our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + room = ensureDeferred( + room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) ) self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] diff --git a/tests/test_mau.py b/tests/test_mau.py index 1fbe0d51ff..eb159e3ba5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -19,6 +19,7 @@ import json from mock import Mock +from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.rest.client.v2_alpha import register, sync @@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.hs_disabled = False self.hs.config.max_mau_value = 2 - self.hs.config.mau_trial_days = 0 self.hs.config.server_notices_mxid = "@server:red" self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_room_name = "Test Server Notice Room" + self.hs.config.mau_trial_days = 0 + + # AuthBlocking reads config options during hs creation. Recreate the + # hs' copy of AuthBlocking after we've updated config values above + self.auth_blocking = AuthBlocking(self.hs) + self.hs.get_auth()._auth_blocking = self.auth_blocking + return self.hs def test_simple_deny_mau(self): @@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_trial_users_cant_come_back(self): + self.auth_blocking._mau_trial_days = 1 self.hs.config.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_tracked_but_not_limited(self): - self.hs.config.max_mau_value = 1 # should not matter - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 1 # should not matter + self.auth_blocking._limit_usage_by_mau = False self.hs.config.mau_stats_only = True # Simply being able to create 2 users indicates that the diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5ec5d2b358..5c2817cf28 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -28,8 +28,8 @@ from tests import unittest class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config.update( { "public_baseurl": "https://example.org/", @@ -53,7 +53,8 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request - request, channel = self.make_request(b"POST", self.url, b"{}") + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..431e9f8e5e --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. +""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + event, context = create_event(hs, room_version, prev_event_ids, **kwargs) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event + + +def create_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> Tuple[EventBase, EventContext]: + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + return event, context diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,13 +32,17 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.logging.context import LoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + current_context, + set_current_context, +) from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter @@ -50,6 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -97,10 +102,10 @@ class TestCase(unittest.TestCase): def setUp(orig): # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): self.fail( "Test starting with non-sentinel logging context %s" - % (LoggingContext.current_context(),) + % (current_context(),) ) old_level = logging.getLogger().level @@ -122,7 +127,7 @@ class TestCase(unittest.TestCase): # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) gc.collect() - LoggingContext.set_current_context(LoggingContext.sentinel) + set_current_context(SENTINEL_CONTEXT) return ret @@ -311,14 +316,11 @@ class HomeserverTestCase(TestCase): return resource - def default_config(self, name="test"): + def default_config(self): """ Get a default HomeServer config dict. - - Args: - name (str): The homeserver name/domain. """ - config = default_config(name) + config = default_config("test") # apply any additional config which was specified via the override_config # decorator. @@ -418,15 +420,17 @@ class HomeserverTestCase(TestCase): config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj + async def run_bg_updates(): + with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": - while not self.get_success( - stor.db.updates.has_completed_background_updates() - ): - self.get_success(stor.db.updates.do_next_background_update(1)) + self.get_success(run_bg_updates()) return hs @@ -493,6 +497,7 @@ class HomeserverTestCase(TestCase): "password": password, "admin": admin, "mac": want_mac, + "inhibit_login": True, } ) request, channel = self.make_request( @@ -591,36 +596,14 @@ class HomeserverTestCase(TestCase): """ Inject a membership event into a room. + Deprecated: use event_injection.inject_room_member directly + Args: room: Room ID to inject the event into. user: MXID of the user to inject the membership for. membership: The membership type. """ - event_builder_factory = self.hs.get_event_builder_factory() - event_creation_handler = self.hs.get_event_creation_handler() - - room_version = self.get_success( - self.hs.get_datastore().get_room_version_id(room) - ) - - builder = event_builder_factory.for_room_version( - KNOWN_ROOM_VERSIONS[room_version], - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - event_creation_handler.create_new_client_event(builder) - ) - - self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) - ) + event_injection.inject_member_event(self.hs, room, user, membership) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 39e360fe24..4d2b9e0d64 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -22,8 +22,10 @@ from twisted.internet import defer, reactor from synapse.api.errors import SynapseError from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.util.caches import descriptors @@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase): with LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) return r def check_result(r): @@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d2.addCallback(check_result) # let the lookup complete @@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase): try: d = obj.fn(1) self.assertEqual( - LoggingContext.current_context(), LoggingContext.sentinel + current_context(), SENTINEL_CONTEXT, ) yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) # the cache should now be empty self.assertEqual(len(obj.fn.cache.cache), 0) @@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) return d1 @@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" return self.mock(args1, arg2) with LoggingContext() as c1: @@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index f60918069a..17fd86d02d 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -16,7 +16,12 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred from twisted.internet.task import Clock -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + PreserveLoggingContext, + current_context, +) from synapse.util.async_helpers import timeout_deferred from tests.unittest import TestCase @@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), + current_context(), context_one, "errback %s run in unexpected logcontext %s" - % (deferred_name, LoggingContext.current_context()), + % (deferred_name, current_context()), ) return res @@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase): original_deferred.addErrback(errback, "orig") timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) timing_out_d.addErrback(errback, "timingout") self.clock.pump((1.0,)) @@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase): blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, defer.TimeoutError) - self.assertIs(LoggingContext.current_context(), context_one) + self.assertIs(current_context(), context_one) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 0ec8ef90ce..852ef23185 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -19,7 +19,7 @@ from six.moves import range from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 281b32c4b8..95301c013c 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -2,8 +2,10 @@ import twisted.python.failure from twisted.internet import defer, reactor from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, nested_logging_context, run_in_background, @@ -15,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(LoggingContext.current_context().request, value) + self.assertEquals(current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def _test_run_in_background(self, function): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() callback_completed = [False] @@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase): async def testfunc(): self._check_test_key("one") d = Clock(reactor).sleep(0) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one") @@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) return d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_chained_deferreds(self): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) await d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 72a9de5370..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. """ @@ -28,26 +26,33 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 6) cache.entity_has_changed("bar@baz.net", 7) + # also test multiple things changing on the same stream ID + cache.entity_has_changed("user2@foo.com", 8) + cache.entity_has_changed("bar2@baz.net", 8) + # If it's been changed after that stream position, return True self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("user2@foo.com", 4)) # If it's been changed at that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 8)) # If there's no changes after that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 9)) # If the entity does not exist, return False. - self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + self.assertFalse(cache.has_entity_changed("not@here.website", 9)) # If we request before the stream cache's earliest known position, # return True, whether it's a known entity or not. self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) - def test_has_entity_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -64,11 +69,20 @@ class StreamChangeCacheTests(unittest.TestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) + # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) + self.assertEqual( + cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) def test_get_all_entities_changed(self): """ @@ -80,18 +94,52 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 2) cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - self.assertEqual( - cache.get_all_entities_changed(1), - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], - ) - self.assertEqual( - cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] - ) + r = cache.get_all_entities_changed(1) + + # either of these are valid + ok1 = [ + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + "user@elsewhere.org", + ] + ok2 = [ + "user@foo.com", + "anotheruser@foo.com", + "bar@baz.net", + "user@elsewhere.org", + ] + self.assertTrue(r == ok1 or r == ok2) + + r = cache.get_all_entities_changed(2) + self.assertTrue(r == ok1[1:] or r == ok2[1:]) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) self.assertEqual(cache.get_all_entities_changed(0), None) + # ... later, things gest more updates + cache.entity_has_changed("user@foo.com", 5) + cache.entity_has_changed("bar@baz.net", 5) + cache.entity_has_changed("anotheruser@foo.com", 6) + + ok1 = [ + "user@elsewhere.org", + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + ] + ok2 = [ + "user@elsewhere.org", + "bar@baz.net", + "user@foo.com", + "anotheruser@foo.com", + ] + r = cache.get_all_entities_changed(3) + self.assertTrue(r == ok1 or r == ok2) + def test_has_any_entity_changed(self): """ StreamChangeCache.has_any_entity_changed will return True if any diff --git a/tests/utils.py b/tests/utils.py index 513f358f4f..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -74,7 +74,10 @@ def setupdb(): db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) - cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.execute( + "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' " + "template=template0;" % (POSTGRES_BASE_DB,) + ) cur.close() db_conn.close() @@ -164,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: @@ -332,10 +336,15 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h - ) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash fed = kargs.get("resource_for_federation", None) if fed: @@ -493,10 +502,10 @@ class MockClock(object): return self.time() * 1000 def call_later(self, delay, callback, *args, **kwargs): - current_context = LoggingContext.current_context() + ctx = current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context + set_current_context(ctx) callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] @@ -504,8 +513,8 @@ class MockClock(object): return t - def looping_call(self, function, interval): - self.loopers.append([function, interval / 1000.0, self.now]) + def looping_call(self, function, interval, *args, **kwargs): + self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -535,9 +544,9 @@ class MockClock(object): self.timers.append(t) for looped in self.loopers: - func, interval, last = looped + func, interval, last, args, kwargs = looped if last + interval < self.now: - func() + func(*args, **kwargs) looped[2] = self.now def advance_time_msec(self, ms): |