diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6121efcfa9..0bfb86bf1f 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs)
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.auth._auth_blocking
+
self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
@@ -68,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
@@ -105,7 +109,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
@@ -125,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -188,7 +192,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@@ -225,7 +229,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(macaroon.serialize())
+ )
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -250,7 +256,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
- user_info = yield self.auth.get_user_by_access_token(serialized)
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(serialized)
+ )
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -260,10 +268,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
- self.store.add_access_token_to_user = Mock()
+ self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
+ self.store.get_device = Mock(return_value=defer.succeed(None))
- token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
- USER_ID, "DEVICE", valid_until_ms=None
+ token = yield defer.ensureDeferred(
+ self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ USER_ID, "DEVICE", valid_until_ms=None
+ )
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
@@ -286,7 +297,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
@@ -301,7 +314,9 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@@ -310,22 +325,22 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_blocking_mau(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = False
+ self.auth_blocking._max_mau_value = 50
lots_of_users = 100
small_number_of_users = 1
# Ensure no error thrown
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -334,49 +349,54 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
- self.hs.config.max_mau_value = 50
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
- yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_reserved_threepid(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 1
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
- self.hs.config.mau_limits_reserved_threepids = [threepid]
+ self.auth_blocking._mau_limits_reserved_threepids = [threepid]
- yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(threepid=unknown_threepid)
+ )
- yield self.auth.check_auth_blocking(threepid=threepid)
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self):
- self.hs.config.hs_disabled = True
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._hs_disabled = True
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -388,20 +408,20 @@ class AuthTestCase(unittest.TestCase):
"""
# this should be the default, but we had a bug where the test was doing the wrong
# thing, so let's make it explicit
- self.hs.config.server_notices_mxid = None
+ self.auth_blocking._server_notices_mxid = None
- self.hs.config.hs_disabled = True
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._hs_disabled = True
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
- self.hs.config.hs_disabled = True
+ self.auth_blocking._hs_disabled = True
user = "@user:server"
- self.hs.config.server_notices_mxid = user
- self.hs.config.hs_disabled_message = "Reason for being disabled"
- yield self.auth.check_auth_blocking(user)
+ self.auth_blocking._server_notices_mxid = user
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index d3feafa1b7..be20a89682 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -27,8 +27,8 @@ class FrontendProxyTests(HomeserverTestCase):
return hs
- def default_config(self, name="test"):
- c = super().default_config(name)
+ def default_config(self):
+ c = super().default_config()
c["worker_app"] = "synapse.app.frontend_proxy"
return c
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 89fcc3889a..7364f9f1ec 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -29,8 +29,8 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
)
return hs
- def default_config(self, name="test"):
- conf = super().default_config(name)
+ def default_config(self):
+ conf = super().default_config()
# we're using FederationReaderServer, which uses a SlavedStore, so we
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
new file mode 100644
index 0000000000..2920279125
--- /dev/null
+++ b/tests/config/test_cache.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config, RootConfig
+from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.util.caches.lrucache import LruCache
+
+from tests.unittest import TestCase
+
+
+class FakeServer(Config):
+ section = "server"
+
+
+class TestConfig(RootConfig):
+ config_classes = [FakeServer, CacheConfig]
+
+
+class CacheConfigTests(TestCase):
+ def setUp(self):
+ # Reset caches before each test
+ TestConfig().caches.reset()
+
+ def test_individual_caches_from_environ(self):
+ """
+ Individual cache factors will be loaded from the environment.
+ """
+ config = {}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_NOT_CACHE": "BLAH",
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+
+ def test_config_overrides_environ(self):
+ """
+ Individual cache factors defined in the environment will take precedence
+ over those in the config.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_CACHE_FACTOR_FOO": 1,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(
+ dict(t.caches.cache_factors),
+ {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
+ )
+
+ def test_individual_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized once the config
+ is loaded.
+ """
+ cache = LruCache(100)
+
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"per_cache_factors": {"foo": 3}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 300)
+
+ def test_individual_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the per_cache_factor if
+ there is one.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 200)
+
+ def test_global_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized to the new
+ default cache size once the config is loaded.
+ """
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"global_factor": 4}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 400)
+
+ def test_global_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the global factor if there
+ is no per-cache factor.
+ """
+ config = {"caches": {"global_factor": 1.5}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 150)
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index 151d3006ac..f675bde68e 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -21,9 +21,9 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
- def test_database_configured_correctly_no_database_conf_param(self):
+ def test_database_configured_correctly(self):
conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", None)
+ DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
)
expected_database_conf = {
@@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase):
}
self.assertEqual(conf["database"], expected_database_conf)
-
- def test_database_configured_correctly_database_conf_param(self):
-
- database_conf = {
- "name": "my super fast datastore",
- "args": {
- "user": "matrix",
- "password": "synapse_database_password",
- "host": "synapse_database_host",
- "database": "matrix",
- },
- }
-
- conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
- )
-
- self.assertEqual(conf["database"], database_conf)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index b3e557bd6a..734a9983e8 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
with open(self.file, "r") as f:
contents = f.readlines()
- contents = [l for l in contents if needle not in l]
+ contents = [line for line in contents if needle not in line]
with open(self.file, "w") as f:
f.write("".join(contents))
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 34d5895f18..70c8e72303 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -34,6 +34,7 @@ from synapse.crypto.keyring import (
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.storage.keys import FetchKeyResult
@@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
def check_context(self, _, expected):
- self.assertEquals(
- getattr(LoggingContext.current_context(), "request", None), expected
- )
+ self.assertEquals(getattr(current_context(), "request", None), expected)
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
- self.assertEquals(LoggingContext.current_context().request, "11")
+ self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
return persp_resp
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
new file mode 100644
index 0000000000..640f5f3bce
--- /dev/null
+++ b/tests/events/test_snapshot.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+from tests.test_utils.event_injection import create_event
+
+
+class TestEventContext(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+
+ self.user_id = self.register_user("u1", "pass")
+ self.user_tok = self.login("u1", "pass")
+ self.room_id = self.helper.create_room_as(tok=self.user_tok)
+
+ def test_serialize_deserialize_msg(self):
+ """Test that an EventContext for a message event is the same after
+ serialize/deserialize.
+ """
+
+ event, context = create_event(
+ self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def test_serialize_deserialize_state_no_prev(self):
+ """Test that an EventContext for a state event (with not previous entry)
+ is the same after serialize/deserialize.
+ """
+ event, context = create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ state_key="",
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def test_serialize_deserialize_state_prev(self):
+ """Test that an EventContext for a state event (which replaces a
+ previous entry) is the same after serialize/deserialize.
+ """
+ event, context = create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.user_id,
+ state_key=self.user_id,
+ content={"membership": "leave"},
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def _check_serialize_deserialize(self, event, context):
+ serialized = self.get_success(context.serialize(event, self.store))
+
+ d_context = EventContext.deserialize(self.storage, serialized)
+
+ self.assertEqual(context.state_group, d_context.state_group)
+ self.assertEqual(context.rejected, d_context.rejected)
+ self.assertEqual(
+ context.state_group_before_event, d_context.state_group_before_event
+ )
+ self.assertEqual(context.prev_group, d_context.prev_group)
+ self.assertEqual(context.delta_ids, d_context.delta_ids)
+ self.assertEqual(context.app_service, d_context.app_service)
+
+ self.assertEqual(
+ self.get_success(context.get_current_state_ids()),
+ self.get_success(d_context.get_current_state_ids()),
+ )
+ self.assertEqual(
+ self.get_success(context.get_prev_state_ids()),
+ self.get_success(d_context.get_prev_state_ids()),
+ )
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 24fa8dbb45..94980733c4 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
login.register_servlets,
]
- def default_config(self, name="test"):
- config = super().default_config(name=name)
+ def default_config(self):
+ config = super().default_config()
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d456267b87..33105576af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,19 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from mock import Mock
+from signedjson import key, sign
+from signedjson.types import BaseKey, SigningKey
+
from twisted.internet import defer
-from synapse.types import ReadReceipt
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.types import JsonDict, ReadReceipt
from tests.unittest import HomeserverTestCase, override_config
-class FederationSenderTestCases(HomeserverTestCase):
+class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- return super(FederationSenderTestCases, self).setup_test_homeserver(
+ return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -147,3 +153,392 @@ class FederationSenderTestCases(HomeserverTestCase):
}
],
)
+
+
+class FederationSenderDevicesTestCases(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def default_config(self):
+ c = super().default_config()
+ c["send_federation"] = True
+ return c
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ mock_state_handler = hs.get_state_handler()
+ mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
+
+ # stub out get_users_who_share_room_with_user so that it claims that
+ # `@user2:host2` is in the room
+ def get_users_who_share_room_with_user(user_id):
+ return defer.succeed({"@user2:host2"})
+
+ hs.get_datastore().get_users_who_share_room_with_user = (
+ get_users_who_share_room_with_user
+ )
+
+ # whenever send_transaction is called, record the edu data
+ self.edus = []
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
+ def record_transaction(self, txn, json_cb):
+ data = json_cb()
+ self.edus.extend(data["edus"])
+ return defer.succeed({})
+
+ def test_send_device_updates(self):
+ """Basic case: each device update should result in an EDU"""
+ # create a device
+ u1 = self.register_user("user", "pass")
+ self.login(u1, "pass", device_id="D1")
+
+ # expect one edu
+ self.assertEqual(len(self.edus), 1)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # a second call should produce no new device EDUs
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+ self.assertEqual(self.edus, [])
+
+ # a second device
+ self.login("user", "pass", device_id="D2")
+
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ def test_upload_signatures(self):
+ """Uploading signatures on some devices should produce updates for that user"""
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # register two devices
+ u1 = self.register_user("user", "pass")
+ self.login(u1, "pass", device_id="D1")
+ self.login(u1, "pass", device_id="D2")
+
+ # expect two edus
+ self.assertEqual(len(self.edus), 2)
+ stream_id = None
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ # upload signing keys for each device
+ device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
+ device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
+
+ # expect two more edus
+ self.assertEqual(len(self.edus), 2)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ # upload master key and self-signing key
+ master_signing_key = generate_self_id_key()
+ master_key = {
+ "user_id": u1,
+ "usage": ["master"],
+ "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)},
+ }
+
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ selfsigning_signing_key = generate_self_id_key()
+ selfsigning_key = {
+ "user_id": u1,
+ "usage": ["self_signing"],
+ "keys": {
+ key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key)
+ },
+ }
+ sign.sign_json(selfsigning_key, u1, master_signing_key)
+
+ cross_signing_keys = {
+ "master_key": master_key,
+ "self_signing_key": selfsigning_key,
+ }
+
+ self.get_success(
+ e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
+ )
+
+ # expect signing key update edu
+ self.assertEqual(len(self.edus), 1)
+ self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
+
+ # sign the devices
+ d1_json = build_device_dict(u1, "D1", device1_signing_key)
+ sign.sign_json(d1_json, u1, selfsigning_signing_key)
+ d2_json = build_device_dict(u1, "D2", device2_signing_key)
+ sign.sign_json(d2_json, u1, selfsigning_signing_key)
+
+ ret = self.get_success(
+ e2e_handler.upload_signatures_for_device_keys(
+ u1, {u1: {"D1": d1_json, "D2": d2_json}},
+ )
+ )
+ self.assertEqual(ret["failures"], {})
+
+ # expect two edus, in one or two transactions. We don't know what order the
+ # devices will be updated.
+ self.assertEqual(len(self.edus), 2)
+ stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ if stream_id is not None:
+ self.assertEqual(c["prev_id"], [stream_id])
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2"}, devices)
+
+ def test_delete_devices(self):
+ """If devices are deleted, that should result in EDUs too"""
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # expect three edus
+ self.assertEqual(len(self.edus), 3)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ # expect three edus, in an unknown order
+ self.assertEqual(len(self.edus), 3)
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ self.assertGreaterEqual(
+ c.items(),
+ {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(),
+ )
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2", "D3"}, devices)
+
+ def test_unreachable_server(self):
+ """If the destination server is unreachable, all the updates should get sent on
+ recovery
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # for each device, there should be a single update
+ self.assertEqual(len(self.edus), 3)
+ stream_id = None
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
+ if stream_id is not None:
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2", "D3"}, devices)
+
+ def test_prune_outbound_device_pokes1(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server has never been reachable.
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # there should be a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def test_prune_outbound_device_pokes2(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server was reachable, but then goes
+ offline.
+ """
+
+ # create first device
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+
+ # expect the update EDU
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # now the server goes offline
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 3)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # ... and we should get a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def check_device_update_edu(
+ self,
+ edu: JsonDict,
+ user_id: str,
+ device_id: str,
+ prev_stream_id: Optional[int],
+ ) -> int:
+ """Check that the given EDU is an update for the given device
+ Returns the stream_id.
+ """
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ content = edu["content"]
+
+ expected = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_stream_id] if prev_stream_id is not None else [],
+ }
+
+ self.assertLessEqual(expected.items(), content.items())
+ if prev_stream_id is not None:
+ self.assertGreaterEqual(content["stream_id"], prev_stream_id)
+ return content["stream_id"]
+
+ def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
+ """Check that the txn has an EDU with a signing key update.
+ """
+ edus = txn["edus"]
+ self.assertEqual(len(edus), 1)
+
+ def generate_and_upload_device_signing_key(
+ self, user_id: str, device_id: str
+ ) -> SigningKey:
+ """Generate a signing keypair for the given device, and upload it"""
+ sk = key.generate_signing_key(device_id)
+
+ device_dict = build_device_dict(user_id, device_id, sk)
+
+ self.get_success(
+ self.hs.get_e2e_keys_handler().upload_keys_for_user(
+ user_id, device_id, {"device_keys": device_dict},
+ )
+ )
+ return sk
+
+
+def generate_self_id_key() -> SigningKey:
+ """generate a signing key whose version is its public key
+
+ ... as used by the cross-signing-keys.
+ """
+ k = key.generate_signing_key("x")
+ k.version = encode_pubkey(k)
+ return k
+
+
+def key_id(k: BaseKey) -> str:
+ return "%s:%s" % (k.alg, k.version)
+
+
+def encode_pubkey(sk: SigningKey) -> str:
+ """Encode the public key corresponding to the given signing key as base64"""
+ return key.encode_verify_key_base64(key.get_verify_key(sk))
+
+
+def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
+ """Build a dict representing the given device"""
+ return {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "keys": {
+ "curve25519:" + device_id: "curve25519+key",
+ key_id(sk): encode_pubkey(sk),
+ },
+ }
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b03103d96f..c01b04e1dc 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
+
# MAU tests
- self.hs.config.max_mau_value = 50
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth()._auth_blocking
+ self.auth_blocking._max_mau_value = 50
+
self.small_number_of_users = 1
self.large_number_of_users = 100
@@ -82,16 +87,16 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@defer.inlineCallbacks
@@ -99,8 +104,10 @@ class AuthTestCase(unittest.TestCase):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
self.assertEqual("a_user", user_id)
@@ -109,99 +116,121 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
- self.hs.config.limit_usage_by_mau = False
+ self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_parity(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
def _get_macaroon(self):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5e40adba52..00bb776271 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -102,6 +102,68 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+class TestCreateAlias(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_handlers().directory_handler
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = RoomAlias.from_string(self.test_alias)
+
+ # Create a test user.
+ self.test_user = self.register_user("user", "pass", admin=False)
+ self.test_user_tok = self.login("user", "pass")
+ self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
+
+ def test_create_alias_joined_room(self):
+ """A user can create an alias for a room they're in."""
+ self.get_success(
+ self.handler.create_association(
+ create_requester(self.test_user), self.room_alias, self.room_id,
+ )
+ )
+
+ def test_create_alias_other_room(self):
+ """A user cannot create an alias for a room they're NOT in."""
+ other_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.get_failure(
+ self.handler.create_association(
+ create_requester(self.test_user), self.room_alias, other_room_id,
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_create_alias_admin(self):
+ """An admin can create an alias for a room they're NOT in."""
+ other_room_id = self.helper.create_room_as(
+ self.test_user, tok=self.test_user_tok
+ )
+
+ self.get_success(
+ self.handler.create_association(
+ create_requester(self.admin_user), self.room_alias, other_room_id,
+ )
+ )
+
+
class TestDeleteAlias(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
new file mode 100644
index 0000000000..61963aa90d
--- /dev/null
+++ b/tests/handlers/test_oidc.py
@@ -0,0 +1,565 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from urllib.parse import parse_qs, urlparse
+
+from mock import Mock, patch
+
+import attr
+import pymacaroons
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+
+from synapse.handlers.oidc_handler import (
+ MappingException,
+ OidcError,
+ OidcHandler,
+ OidcMappingProvider,
+)
+from synapse.types import UserID
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+@attr.s
+class FakeResponse:
+ code = attr.ib()
+ body = attr.ib()
+ phrase = attr.ib()
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
+
+
+# These are a few constants that are used as config parameters in the tests.
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+
+# config for common cases
+COMMON_CONFIG = {
+ "discover": False,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+}
+
+
+# The cookie name and path don't really matter, just that it has to be coherent
+# between the callback & redirect handlers.
+COOKIE_NAME = b"oidc_session"
+COOKIE_PATH = "/_synapse/oidc"
+
+MockedMappingProvider = Mock(OidcMappingProvider)
+
+
+def simple_async_mock(return_value=None, raises=None):
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+async def get_json(url):
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+
+class OidcHandlerTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ config = self.default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = config.get("oidc_config", {})
+ oidc_config["enabled"] = True
+ oidc_config["client_id"] = CLIENT_ID
+ oidc_config["client_secret"] = CLIENT_SECRET
+ oidc_config["issuer"] = ISSUER
+ oidc_config["scopes"] = SCOPES
+ oidc_config["user_mapping_provider"] = {
+ "module": __name__ + ".MockedMappingProvider"
+ }
+ config["oidc_config"] = oidc_config
+
+ hs = self.setup_test_homeserver(
+ http_client=self.http_client,
+ proxied_http_client=self.http_client,
+ config=config,
+ )
+
+ self.handler = OidcHandler(hs)
+
+ return hs
+
+ def metadata_edit(self, values):
+ return patch.dict(self.handler._provider_metadata, values)
+
+ def assertRenderedError(self, error, error_description=None):
+ args = self.handler._render_error.call_args[0]
+ self.assertEqual(args[1], error)
+ if error_description is not None:
+ self.assertEqual(args[2], error_description)
+ # Reset the render_error mock
+ self.handler._render_error.reset_mock()
+
+ def test_config(self):
+ """Basic config correctly sets up the callback URL and client auth correctly."""
+ self.assertEqual(self.handler._callback_url, CALLBACK_URL)
+ self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+
+ @override_config({"oidc_config": {"discover": True}})
+ @defer.inlineCallbacks
+ def test_discovery(self):
+ """The handler should discover the endpoints from OIDC discovery document."""
+ # This would throw if some metadata were invalid
+ metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+
+ self.assertEqual(metadata.issuer, ISSUER)
+ self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
+ self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
+ self.assertEqual(metadata.jwks_uri, JWKS_URI)
+ # FIXME: it seems like authlib does not have that defined in its metadata models
+ # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+
+ # subsequent calls should be cached
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_no_discovery(self):
+ """When discovery is disabled, it should not try to load from discovery document."""
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_load_jwks(self):
+ """JWKS loading is done once (then cached) if used."""
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.assertEqual(jwks, {"keys": []})
+
+ # subsequent calls should be cached…
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_not_called()
+
+ # …unless forced
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+
+ # Throw if the JWKS uri is missing
+ with self.metadata_edit({"jwks_uri": None}):
+ with self.assertRaises(RuntimeError):
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+
+ # Return empty key set if JWKS are not used
+ self.handler._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ def test_validate_config(self):
+ """Provider metadatas are extensively validated."""
+ h = self.handler
+
+ # Default test config does not throw
+ h._validate_metadata()
+
+ with self.metadata_edit({"issuer": None}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "http://insecure/"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"authorization_endpoint": None}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"token_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": None}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"response_types_supported": ["id_token"]}):
+ self.assertRaisesRegex(
+ ValueError, "response_types_supported", h._validate_metadata
+ )
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
+ ):
+ # should not throw, as client_secret_basic is the default auth method
+ h._validate_metadata()
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
+ ):
+ self.assertRaisesRegex(
+ ValueError,
+ "token_endpoint_auth_methods_supported",
+ h._validate_metadata,
+ )
+
+ # Tests for configs that the userinfo endpoint
+ self.assertFalse(h._uses_userinfo)
+ h._scopes = [] # do not request the openid scope
+ self.assertTrue(h._uses_userinfo)
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+
+ with self.metadata_edit(
+ {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
+ ):
+ # Shouldn't raise with a valid userinfo, even without
+ h._validate_metadata()
+
+ @override_config({"oidc_config": {"skip_verification": True}})
+ def test_skip_verification(self):
+ """Provider metadata validation can be disabled by config."""
+ with self.metadata_edit({"issuer": "http://insecure"}):
+ # This should not throw
+ self.handler._validate_metadata()
+
+ @defer.inlineCallbacks
+ def test_redirect_request(self):
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["addCookie", "redirect", "finish"])
+ yield defer.ensureDeferred(
+ self.handler.handle_redirect_request(req, b"http://client/redirect")
+ )
+ url = req.redirect.call_args[0][0]
+ url = urlparse(url)
+ auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+
+ self.assertEqual(url.scheme, auth_endpoint.scheme)
+ self.assertEqual(url.netloc, auth_endpoint.netloc)
+ self.assertEqual(url.path, auth_endpoint.path)
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(params["response_type"], ["code"])
+ self.assertEqual(params["scope"], [" ".join(SCOPES)])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(len(params["state"]), 1)
+ self.assertEqual(len(params["nonce"]), 1)
+
+ # Check what is in the cookie
+ # note: python3.5 mock does not have the .called_once() method
+ calls = req.addCookie.call_args_list
+ self.assertEqual(len(calls), 1) # called once
+ # For some reason, call.args does not work with python3.5
+ args = calls[0][0]
+ kwargs = calls[0][1]
+ self.assertEqual(args[0], COOKIE_NAME)
+ self.assertEqual(kwargs["path"], COOKIE_PATH)
+ cookie = args[1]
+
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ state = self.handler._get_value_from_macaroon(macaroon, "state")
+ nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
+ redirect = self.handler._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ self.assertEqual(params["state"], [state])
+ self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(redirect, "http://client/redirect")
+
+ @defer.inlineCallbacks
+ def test_callback_error(self):
+ """Errors from the provider returned in the callback are displayed."""
+ self.handler._render_error = Mock()
+ request = Mock(args={})
+ request.args[b"error"] = [b"invalid_client"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "")
+
+ request.args[b"error_description"] = [b"some description"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "some description")
+
+ @defer.inlineCallbacks
+ def test_callback(self):
+ """Code callback works and display errors if something went wrong.
+
+ A lot of scenarios are tested here:
+ - when the callback works, with userinfo from ID token
+ - when the user mapping fails
+ - when ID token verification fails
+ - when the callback works, with userinfo fetched from the userinfo endpoint
+ - when the userinfo fetching fails
+ - when the code exchange fails
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "preferred_username": "bar",
+ }
+ user_id = UserID("foo", "domain.org")
+ self.handler._render_error = Mock(return_value=None)
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ code = "code"
+ state = "state"
+ nonce = "nonce"
+ client_redirect_url = "http://client/redirect"
+ session = self.handler._generate_oidc_session_token(
+ state=state, nonce=nonce, client_redirect_url=client_redirect_url,
+ )
+ request.getCookie.return_value = session
+
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_not_called()
+ self.handler._render_error.assert_not_called()
+
+ # Handle mapping errors
+ self.handler._map_userinfo_to_user = simple_async_mock(
+ raises=MappingException()
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+
+ # Handle ID token errors
+ self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_token")
+
+ self.handler._auth_handler.complete_sso_login.reset_mock()
+ self.handler._exchange_code.reset_mock()
+ self.handler._parse_id_token.reset_mock()
+ self.handler._map_userinfo_to_user.reset_mock()
+ self.handler._fetch_userinfo.reset_mock()
+
+ # With userinfo fetching
+ self.handler._scopes = [] # do not ask the "openid" scope
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_not_called()
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.handler._render_error.assert_not_called()
+
+ # Handle userinfo fetching error
+ self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("fetch_error")
+
+ # Handle code exchange failure
+ self.handler._exchange_code = simple_async_mock(
+ raises=OidcError("invalid_request")
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @defer.inlineCallbacks
+ def test_callback_session(self):
+ """The callback verifies the session presence and validity"""
+ self.handler._render_error = Mock(return_value=None)
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ # Missing cookie
+ request.args = {}
+ request.getCookie.return_value = None
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("missing_session", "No session cookie found")
+
+ # Missing session parameter
+ request.args = {}
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request", "State parameter is missing")
+
+ # Invalid cookie
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_session")
+
+ # Mismatching session
+ session = self.handler._generate_oidc_session_token(
+ state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+ )
+ request.args = {}
+ request.args[b"state"] = [b"mismatching state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mismatching_session")
+
+ # Valid session
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @defer.inlineCallbacks
+ def test_exchange_code(self):
+ """Code exchange behaves correctly and handles various error scenarios."""
+ token = {"type": "bearer"}
+ token_json = json.dumps(token).encode("utf-8")
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ )
+ code = "code"
+ ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ kwargs = self.http_client.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ # Test error handling
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad Request",
+ body=b'{"error": "foo", "error_description": "bar"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "foo")
+ self.assertEqual(exc.exception.error_description, "bar")
+
+ # Internal server error with no JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # Internal server error with JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b'{"error": "internal_server_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "internal_server_error")
+
+ # 4xx error without "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # 2xx error with "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "some_error")
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..8aa56f1496 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock
from twisted.internet import defer
import synapse.types
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
@@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_profile_handler()
+ self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
@@ -81,19 +82,58 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_name(self):
- yield self.handler.set_displayname(
- self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ yield defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ )
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
"Frank Jr.",
)
+ # Set displayname again
+ yield defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank"
+ )
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ )
+
+ @defer.inlineCallbacks
+ def test_set_my_name_if_disabled(self):
+ self.hs.config.enable_set_displayname = False
+
+ # Setting displayname for the first time is allowed
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ )
+
+ # Setting displayname a second time is forbidden
+ d = defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ )
+ )
+
+ yield self.assertFailure(d, SynapseError)
+
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
- d = self.handler.set_displayname(
- self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+ d = defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+ )
)
yield self.assertFailure(d, AuthError)
@@ -137,13 +177,54 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
- yield self.handler.set_avatar_url(
- self.frank,
- synapse.types.create_requester(self.frank),
- "http://my.server/pic.gif",
+ yield defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif",
+ )
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
"http://my.server/pic.gif",
)
+
+ # Set avatar again
+ yield defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/me.png",
+ )
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/me.png",
+ )
+
+ @defer.inlineCallbacks
+ def test_set_my_avatar_if_disabled(self):
+ self.hs.config.enable_set_avatar_url = False
+
+ # Setting displayname for the first time is allowed
+ yield self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/me.png",
+ )
+
+ # Set avatar a second time is forbidden
+ d = defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif",
+ )
+ )
+
+ yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e2915eb7b1..1b7935cef2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
def make_homeserver(self, reactor, clock):
- hs_config = self.default_config("test")
+ hs_config = self.default_config()
# some of the tests rely on us having a user consent version
hs_config["user_consent"] = {
@@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_real_user = Mock(return_value=False)
+ self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=1)
- self.store.is_real_user = Mock(return_value=True)
+ self.store.count_real_users = Mock(return_value=defer.succeed(1))
+ self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=2)
- self.store.is_real_user = Mock(return_value=True)
+ self.store.count_real_users = Mock(return_value=defer.succeed(2))
+ self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
- @defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ async def get_or_create_user(
+ self, requester, localpart, displayname, password_hash=None
+ ):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
- yield self.hs.get_auth().check_auth_blocking()
+ await self.hs.get_auth().check_auth_blocking()
need_register = True
try:
- yield self.handler.check_username(localpart)
+ await self.handler.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
@@ -288,21 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id)
if need_register:
- yield self.handler.register_with_store(
+ await self.handler.register_with_store(
user_id=user_id,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
- yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
- yield self.store.add_access_token_to_user(
+ await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)
- yield self.hs.get_profile_handler().set_displayname(
+ await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True
)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 4cbe9784ed..e178d7765b 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
- def test_wait_for_sync_for_user_auth_blocking(self):
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth()._auth_blocking
+ def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 1
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 1
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
- self.hs.config.hs_disabled = True
+ self.auth_blocking._hs_disabled = True
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.hs.config.hs_disabled = False
+ self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 7b92bdbc47..572df8d80b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -185,7 +185,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Allow all users.
return False
- spam_checker.spam_checker = AllowAll()
+ spam_checker.spam_checkers = [AllowAll()]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -198,7 +198,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# All users are spammy.
return True
- spam_checker.spam_checker = BlockAll()
+ spam_checker.spam_checkers = [BlockAll()]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index fdc1d918ff..562397cdda 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import (
WellKnownResolver,
_cache_period_from_headers,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
@@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
- _check_logcontext(LoggingContext.sentinel)
+ _check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
@@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
def _check_logcontext(context):
- current = LoggingContext.current_context()
+ current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index df034ab237..babc201643 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
result = yield resolve_d
# should have restored our context
- self.assertIs(LoggingContext.current_context(), ctx)
+ self.assertIs(current_context(), ctx)
return result
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 2b01f40a42..fff4f0cbf4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
def check_logcontext(context):
- current = LoggingContext.current_context()
+ current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
@@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
- check_logcontext(LoggingContext.sentinel)
+ check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
new file mode 100644
index 0000000000..9d4f0bbe44
--- /dev/null
+++ b/tests/replication/_base.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, List, Optional, Tuple
+
+import attr
+
+from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.task import LoopingCall
+from twisted.web.http import HTTPChannel
+
+from synapse.app.generic_worker import (
+ GenericWorkerReplicationHandler,
+ GenericWorkerServer,
+)
+from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
+from synapse.replication.tcp.handler import ReplicationCommandHandler
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import FakeTransport
+
+logger = logging.getLogger(__name__)
+
+
+class BaseStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests of the replication streams"""
+
+ servlets = [
+ streams.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # build a replication server
+ server_factory = ReplicationStreamProtocolFactory(hs)
+ self.streamer = hs.get_replication_streamer()
+ self.server = server_factory.buildProtocol(None)
+
+ # Make a new HomeServer object for the worker
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.worker_hs = self.setup_test_homeserver(
+ http_client=None,
+ homeserverToUse=GenericWorkerServer,
+ config=self._get_worker_hs_config(),
+ reactor=self.reactor,
+ )
+
+ # Since we use sqlite in memory databases we need to make sure the
+ # databases objects are the same.
+ self.worker_hs.get_datastore().db = hs.get_datastore().db
+
+ self.test_handler = self._build_replication_data_handler()
+ self.worker_hs.replication_data_handler = self.test_handler
+
+ repl_handler = ReplicationCommandHandler(self.worker_hs)
+ self.client = ClientReplicationStreamProtocol(
+ self.worker_hs, "client", "test", clock, repl_handler,
+ )
+
+ self._client_transport = None
+ self._server_transport = None
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def _build_replication_data_handler(self):
+ return TestReplicationDataHandler(self.worker_hs)
+
+ def reconnect(self):
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump(0.1)
+
+ def handle_http_replication_attempt(self) -> SynapseRequest:
+ """Asserts that a connection attempt was made to the master HS on the
+ HTTP replication port, then proxies it to the master HS object to be
+ handled.
+
+ Returns:
+ The request object received by master HS.
+ """
+
+ # We should have an outbound connection attempt.
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # The request will now be processed by `self.site` and the response
+ # streamed back.
+ self.reactor.advance(0)
+
+ # We tear down the connection so it doesn't get reused without our
+ # knowledge.
+ server_to_client_transport.loseConnection()
+ client_to_server_transport.loseConnection()
+
+ return request_factory.request
+
+ def assert_request_is_get_repl_stream_updates(
+ self, request: SynapseRequest, stream_name: str
+ ):
+ """Asserts that the given request is a HTTP replication request for
+ fetching updates for given stream.
+ """
+
+ self.assertRegex(
+ request.path,
+ br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
+ % (stream_name.encode("ascii"),),
+ )
+
+ self.assertEqual(request.method, b"GET")
+
+
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
+ """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+
+ # list of received (stream_name, token, row) tuples
+ self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
+
+ async def on_rdata(self, stream_name, instance_name, token, rows):
+ await super().on_rdata(stream_name, instance_name, token, rows)
+ for r in rows:
+ self.received_rdata_rows.append((stream_name, token, r))
+
+
+@attr.s()
+class OneShotRequestFactory:
+ """A simple request factory that generates a single `SynapseRequest` and
+ stores it for future use. Can only be used once.
+ """
+
+ request = attr.ib(default=None)
+
+ def __call__(self, *args, **kwargs):
+ assert self.request is None
+
+ self.request = SynapseRequest(*args, **kwargs)
+ return self.request
+
+
+class _PushHTTPChannel(HTTPChannel):
+ """A HTTPChannel that wraps pull producers to push producers.
+
+ This is a hack to get around the fact that HTTPChannel transparently wraps a
+ pull producer (which is what Synapse uses to reply to requests) with
+ `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
+ uses the standard reactor rather than letting us use our test reactor, which
+ makes it very hard to test.
+ """
+
+ def __init__(self, reactor: IReactorTime):
+ super().__init__()
+ self.reactor = reactor
+
+ self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
+
+ def registerProducer(self, producer, streaming):
+ # Convert pull producers to push producer.
+ if not streaming:
+ self._pull_to_push_producer = _PullToPushProducer(
+ self.reactor, producer, self
+ )
+ producer = self._pull_to_push_producer
+
+ super().registerProducer(producer, True)
+
+ def unregisterProducer(self):
+ if self._pull_to_push_producer:
+ # We need to manually stop the _PullToPushProducer.
+ self._pull_to_push_producer.stop()
+
+
+class _PullToPushProducer:
+ """A push producer that wraps a pull producer.
+ """
+
+ def __init__(
+ self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
+ ):
+ self._clock = Clock(reactor)
+ self._producer = producer
+ self._consumer = consumer
+
+ # While running we use a looping call with a zero delay to call
+ # resumeProducing on given producer.
+ self._looping_call = None # type: Optional[LoopingCall]
+
+ # We start writing next reactor tick.
+ self._start_loop()
+
+ def _start_loop(self):
+ """Start the looping call to
+ """
+
+ if not self._looping_call:
+ # Start a looping call which runs every tick.
+ self._looping_call = self._clock.looping_call(self._run_once, 0)
+
+ def stop(self):
+ """Stops calling resumeProducing.
+ """
+ if self._looping_call:
+ self._looping_call.stop()
+ self._looping_call = None
+
+ def pauseProducing(self):
+ """Implements IPushProducer
+ """
+ self.stop()
+
+ def resumeProducing(self):
+ """Implements IPushProducer
+ """
+ self._start_loop()
+
+ def stopProducing(self):
+ """Implements IPushProducer
+ """
+ self.stop()
+ self._producer.stopProducing()
+
+ def _run_once(self):
+ """Calls resumeProducing on producer once.
+ """
+
+ try:
+ self._producer.resumeProducing()
+ except Exception:
+ logger.exception("Failed to call resumeProducing")
+ try:
+ self._consumer.unregisterProducer()
+ except Exception:
+ pass
+
+ self.stopProducing()
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 2a1e7c7166..32cb04645f 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,22 +15,13 @@
from mock import Mock, NonCallableMock
-from synapse.replication.tcp.client import (
- ReplicationClientFactory,
- ReplicationClientHandler,
-)
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import make_conn
+from tests.replication._base import BaseStreamTestCase
-from tests import unittest
-from tests.server import FakeTransport
-
-class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "blue",
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
@@ -40,34 +31,13 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
- db_config = hs.config.database.get_single_database()
- self.master_store = self.hs.get_datastore()
- self.storage = hs.get_storage()
- database = hs.get_datastores().databases[0]
- self.slaved_store = self.STORE_TYPE(
- database, make_conn(db_config, database.engine), self.hs
- )
- self.event_id = 0
-
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
-
- handler_factory = Mock()
- self.replication_handler = ReplicationClientHandler(self.slaved_store)
- self.replication_handler.factory = handler_factory
+ self.reconnect()
- client_factory = ReplicationClientFactory(
- self.hs, "client_name", self.replication_handler
- )
-
- server = server_factory.buildProtocol(None)
- client = client_factory.buildProtocol(None)
-
- client.makeConnection(FakeTransport(server, reactor))
-
- self.server_to_client_transport = FakeTransport(client, reactor)
- server.makeConnection(self.server_to_client_transport)
+ self.master_store = hs.get_datastore()
+ self.slaved_store = self.worker_hs.get_datastore()
+ self.storage = hs.get_storage()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f0561b30e3..0fee8a71c4 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -24,10 +24,10 @@ from synapse.storage.roommember import RoomsForUser
from ._base import BaseSlavedStoreTestCase
-USER_ID = "@feeling:blue"
-USER_ID_2 = "@bright:blue"
+USER_ID = "@feeling:test"
+USER_ID_2 = "@bright:test"
OUTLIER = {"outlier": True}
-ROOM_ID = "!room:blue"
+ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
@@ -239,7 +239,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
# limit the replication rate
- repl_transport = self.server_to_client_transport
+ repl_transport = self._server_transport
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
deleted file mode 100644
index e96ad4ca4e..0000000000
--- a/tests/replication/tcp/streams/_base.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from mock import Mock
-
-from synapse.replication.tcp.commands import ReplicateCommand
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-
-from tests import unittest
-from tests.server import FakeTransport
-
-
-class BaseStreamTestCase(unittest.HomeserverTestCase):
- """Base class for tests of the replication streams"""
-
- def prepare(self, reactor, clock, hs):
- # build a replication server
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
-
- # build a replication client, with a dummy handler
- handler_factory = Mock()
- self.test_handler = TestReplicationClientHandler()
- self.test_handler.factory = handler_factory
- self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
- )
-
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
-
- def replicate(self):
- """Tell the master side of replication that something has happened, and then
- wait for the replication to occur.
- """
- self.streamer.on_notifier_poke()
- self.pump(0.1)
-
- def replicate_stream(self, stream, token="NOW"):
- """Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand(stream, token))
-
-
-class TestReplicationClientHandler(object):
- """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
-
- def __init__(self):
- self.received_rdata_rows = []
-
- def get_streams_to_replicate(self):
- return {}
-
- def get_currently_syncing_users(self):
- return []
-
- def update_connection(self, connection):
- pass
-
- def finished_connecting(self):
- pass
-
- async def on_rdata(self, stream_name, token, rows):
- for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
new file mode 100644
index 0000000000..51bf0ef4e9
--- /dev/null
+++ b/tests/replication/tcp/streams/test_events.py
@@ -0,0 +1,425 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
+from synapse.replication.tcp.streams.events import (
+ EventsStreamCurrentStateRow,
+ EventsStreamEventRow,
+ EventsStreamRow,
+)
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseStreamTestCase
+from tests.test_utils.event_injection import inject_event, inject_member_event
+
+
+class EventsStreamTestCase(BaseStreamTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
+ self.user_id = self.register_user("u1", "pass")
+ self.user_tok = self.login("u1", "pass")
+
+ self.reconnect()
+
+ self.room_id = self.helper.create_room_as(tok=self.user_tok)
+ self.test_handler.received_rdata_rows.clear()
+
+ def test_update_function_event_row_limit(self):
+ """Test replication with many non-state events
+
+ Checks that all events are correctly replicated when there are lots of
+ event rows to be replicated.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # generate lots of non-state events. We inject them using inject_event
+ # so that they are not send out over replication until we call self.replicate().
+ events = [
+ self._inject_test_event()
+ for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
+ ]
+
+ # also one state event
+ state_event = self._inject_state_event()
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ for event in events:
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, state_event.event_id)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ self.assertEqual(row.data.event_id, state_event.event_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_huge_state_change(self):
+ """Test replication with many state events
+
+ Ensures that all events are correctly replicated when there are lots of
+ state change rows to be replicated.
+ """
+
+ # we want to generate lots of state changes at a single stream ID.
+ #
+ # We do this by having two branches in the DAG. On one, we have a moderator
+ # which that generates lots of state; on the other, we de-op the moderator,
+ # thus invalidating all the state.
+
+ OTHER_USER = "@other_user:localhost"
+
+ # have the user join
+ inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+
+ # Update existing power levels with mod at PL50
+ pls = self.helper.get_state(
+ self.room_id, EventTypes.PowerLevels, tok=self.user_tok
+ )
+ pls["users"][OTHER_USER] = 50
+ self.helper.send_state(
+ self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ )
+
+ # this is the point in the DAG where we make a fork
+ fork_point = self.get_success(
+ self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
+ ) # type: List[str]
+
+ events = [
+ self._inject_state_event(sender=OTHER_USER)
+ for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
+ ]
+
+ self.replicate()
+ # all those events and state changes should have landed
+ self.assertGreaterEqual(
+ len(self.test_handler.received_rdata_rows), 2 * len(events)
+ )
+
+ # disconnect, so that we can stack up the changes
+ self.disconnect()
+ self.test_handler.received_rdata_rows.clear()
+
+ # a state event which doesn't get rolled back, to check that the state
+ # before the huge update comes through ok
+ state1 = self._inject_state_event()
+
+ # roll back all the state by de-modding the user
+ prev_events = fork_point
+ pls["users"][OTHER_USER] = 0
+ pl_event = inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
+
+ # one more bit of state that doesn't get rolled back
+ state2 = self._inject_state_event()
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ #
+ # we expect:
+ #
+ # - two rows for state1
+ # - the PL event row, plus state rows for the PL event and each
+ # of the states that got reverted.
+ # - two rows for state2
+
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # first check the first two rows, which should be state1
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, state1.event_id)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ self.assertEqual(row.data.event_id, state1.event_id)
+
+ # now the last two rows, which should be state2
+ stream_name, token, row = received_rows.pop(-2)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, state2.event_id)
+
+ stream_name, token, row = received_rows.pop(-1)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ self.assertEqual(row.data.event_id, state2.event_id)
+
+ # that should leave us with the rows for the PL event
+ self.assertEqual(len(received_rows), len(events) + 2)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_event.event_id)
+
+ # the state rows are unsorted
+ state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ for stream_name, token, row in received_rows:
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ state_rows.append(row.data)
+
+ state_rows.sort(key=lambda r: r.state_key)
+
+ sr = state_rows.pop(0)
+ self.assertEqual(sr.type, EventTypes.PowerLevels)
+ self.assertEqual(sr.event_id, pl_event.event_id)
+ for sr in state_rows:
+ self.assertEqual(sr.type, "test_state_event")
+ # "None" indicates the state has been deleted
+ self.assertIsNone(sr.event_id)
+
+ def test_update_function_state_row_limit(self):
+ """Test replication with many state events over several stream ids.
+ """
+
+ # we want to generate lots of state changes, but for this test, we want to
+ # spread out the state changes over a few stream IDs.
+ #
+ # We do this by having two branches in the DAG. On one, we have four moderators,
+ # each of which that generates lots of state; on the other, we de-op the users,
+ # thus invalidating all the state.
+
+ NUM_USERS = 4
+ STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
+
+ user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
+
+ # have the users join
+ for u in user_ids:
+ inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+
+ # Update existing power levels with mod at PL50
+ pls = self.helper.get_state(
+ self.room_id, EventTypes.PowerLevels, tok=self.user_tok
+ )
+ pls["users"].update({u: 50 for u in user_ids})
+ self.helper.send_state(
+ self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ )
+
+ # this is the point in the DAG where we make a fork
+ fork_point = self.get_success(
+ self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
+ ) # type: List[str]
+
+ events = [] # type: List[EventBase]
+ for user in user_ids:
+ events.extend(
+ self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
+ )
+
+ self.replicate()
+
+ # all those events and state changes should have landed
+ self.assertGreaterEqual(
+ len(self.test_handler.received_rdata_rows), 2 * len(events)
+ )
+
+ # disconnect, so that we can stack up the changes
+ self.disconnect()
+ self.test_handler.received_rdata_rows.clear()
+
+ # now roll back all that state by de-modding the users
+ prev_events = fork_point
+ pl_events = []
+ for u in user_ids:
+ pls["users"][u] = 0
+ e = inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
+ prev_events = [e.event_id]
+ pl_events.append(e)
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertGreaterEqual(len(received_rows), len(events))
+ for i in range(NUM_USERS):
+ # for each user, we expect the PL event row, followed by state rows for
+ # the PL event and each of the states that got reverted.
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_events[i].event_id)
+
+ # the state rows are unsorted
+ state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ for j in range(STATES_PER_USER + 1):
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ state_rows.append(row.data)
+
+ state_rows.sort(key=lambda r: r.state_key)
+
+ sr = state_rows.pop(0)
+ self.assertEqual(sr.type, EventTypes.PowerLevels)
+ self.assertEqual(sr.event_id, pl_events[i].event_id)
+ for sr in state_rows:
+ self.assertEqual(sr.type, "test_state_event")
+ # "None" indicates the state has been deleted
+ self.assertIsNone(sr.event_id)
+
+ self.assertEqual([], received_rows)
+
+ event_count = 0
+
+ def _inject_test_event(
+ self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+ ) -> EventBase:
+ if sender is None:
+ sender = self.user_id
+
+ if body is None:
+ body = "event %i" % (self.event_count,)
+ self.event_count += 1
+
+ return inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_event",
+ content={"body": body},
+ **kwargs
+ )
+
+ def _inject_state_event(
+ self,
+ body: Optional[str] = None,
+ state_key: Optional[str] = None,
+ sender: Optional[str] = None,
+ ) -> EventBase:
+ if sender is None:
+ sender = self.user_id
+
+ if state_key is None:
+ state_key = "state_%i" % (self.event_count,)
+ self.event_count += 1
+
+ if body is None:
+ body = "state event %s" % (state_key,)
+
+ return inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_state_event",
+ state_key=state_key,
+ content={"body": body},
+ )
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..2babea4e3e
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index d5a99f6caa..56b062ecc1 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -12,35 +12,73 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams._base import ReceiptsStreamRow
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+# type: ignore
+
+from mock import Mock
+
+from synapse.replication.tcp.streams._base import ReceiptsStream
+
+from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
+ def _build_replication_data_handler(self):
+ return Mock(wraps=super()._build_replication_data_handler())
+
def test_receipt(self):
- # make the client subscribe to the receipts stream
- self.replicate_stream("receipts", "NOW")
+ self.reconnect()
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
- rdata_rows = self.test_handler.received_rdata_rows
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- self.assertEqual(rdata_rows[0][0], "receipts")
- row = rdata_rows[0][2] # type: ReceiptsStreamRow
- self.assertEqual(ROOM_ID, row.room_id)
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
- self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ )
+ )
+ self.replicate()
+
+ # Nothing should have happened as we are disconnected
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now have caught up and get the missing data
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
+ self.assertEqual(token, 3)
+ self.assertEqual(1, len(rdata_rows))
+
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room2:blue", row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual("$event2:foo", row.event_id)
+ self.assertEqual({"a": 2}, row.data)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
new file mode 100644
index 0000000000..fd62b26356
--- /dev/null
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.handlers.typing import RoomMember
+from synapse.replication.tcp.streams import TypingStream
+
+from tests.replication._base import BaseStreamTestCase
+
+USER_ID = "@feeling:blue"
+
+
+class TypingStreamTestCase(BaseStreamTestCase):
+ def _build_replication_data_handler(self):
+ return Mock(wraps=super()._build_replication_data_handler())
+
+ def test_typing(self):
+ typing = self.hs.get_typing_handler()
+
+ room_id = "!bar:blue"
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(room_id, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # The from token should be the token from the last RDATA we got.
+ self.assertEqual(int(request.args[b"from_token"][0]), token)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(room_id, row.room_id)
+ self.assertEqual([], row.user_ids)
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
new file mode 100644
index 0000000000..7ddfd0a733
--- /dev/null
+++ b/tests/replication/tcp/test_commands.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.replication.tcp.commands import (
+ RdataCommand,
+ ReplicateCommand,
+ parse_command_from_line,
+)
+
+from tests.unittest import TestCase
+
+
+class ParseCommandTestCase(TestCase):
+ def test_parse_one_word_command(self):
+ line = "REPLICATE"
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, ReplicateCommand)
+
+ def test_parse_rdata(self):
+ line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "events")
+ self.assertEqual(cmd.instance_name, "master")
+ self.assertEqual(cmd.token, 6287863)
+
+ def test_parse_rdata_batch(self):
+ line = 'RDATA presence master batch ["@foo:example.com", "online"]'
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "presence")
+ self.assertEqual(cmd.instance_name, "master")
+ self.assertIsNone(cmd.token)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
new file mode 100644
index 0000000000..d1c15caeb0
--- /dev/null
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+from twisted.internet.interfaces import IProtocol
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from tests.unittest import HomeserverTestCase
+
+
+class RemoteServerUpTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.factory = ReplicationStreamProtocolFactory(hs)
+
+ def _make_client(self) -> Tuple[IProtocol, StringTransport]:
+ """Create a new direct TCP replication connection
+ """
+
+ proto = self.factory.buildProtocol(("127.0.0.1", 0))
+ transport = StringTransport()
+ proto.makeConnection(transport)
+
+ # We can safely ignore the commands received during connection.
+ self.pump()
+ transport.clear()
+
+ return proto, transport
+
+ def test_relay(self):
+ """Test that Synapse will relay REMOTE_SERVER_UP commands to all
+ other connections, but not the one that sent it.
+ """
+
+ proto1, transport1 = self._make_client()
+
+ # We shouldn't receive an echo.
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+ self.pump()
+ self.assertEqual(transport1.value(), b"")
+
+ # But we should see an echo if we connect another client
+ proto2, transport2 = self._make_client()
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+
+ self.pump()
+ self.assertEqual(transport1.value(), b"")
+ self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0342aed416..977615ebef 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -17,7 +17,6 @@ import json
import os
import urllib.parse
from binascii import unhexlify
-from typing import List, Optional
from mock import Mock
@@ -27,7 +26,7 @@ import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
-from synapse.rest.client.v1 import directory, events, login, room
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
@@ -51,129 +50,6 @@ class VersionTestCase(unittest.HomeserverTestCase):
)
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
class DeleteGroupTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -273,86 +149,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
return channel.json_body["groups"]
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
-
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API.
"""
@@ -691,389 +487,3 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
% server_and_media_id_2
),
)
-
-
-class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- directory.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- # Create user
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_list_rooms(self):
- """Test that we can list rooms"""
- # Create 3 test rooms
- total_rooms = 3
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- # Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
-
- # Check that response json body contains a "rooms" key
- self.assertTrue(
- "rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
- )
-
- # Check that 3 rooms were returned
- self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
-
- # Check their room_ids match
- returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
- self.assertEqual(room_ids, returned_room_ids)
-
- # Check that all fields are available
- for r in channel.json_body["rooms"]:
- self.assertIn("name", r)
- self.assertIn("canonical_alias", r)
- self.assertIn("joined_members", r)
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # Should be 0 as we aren't paginating
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that the prev_batch parameter is not present
- self.assertNotIn("prev_batch", channel.json_body)
-
- # We shouldn't receive a next token here as there's no further rooms to show
- self.assertNotIn("next_batch", channel.json_body)
-
- def test_list_rooms_pagination(self):
- """Test that we can get a full list of rooms through pagination"""
- # Create 5 test rooms
- total_rooms = 5
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Set the name of the rooms so we get a consistent returned ordering
- for idx, room_id in enumerate(room_ids):
- self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- returned_room_ids = []
- start = 0
- limit = 2
-
- run_count = 0
- should_repeat = True
- while should_repeat:
- run_count += 1
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
- start,
- limit,
- "alphabetical",
- )
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- self.assertTrue("rooms" in channel.json_body)
- for r in channel.json_body["rooms"]:
- returned_room_ids.append(r["room_id"])
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # We're only getting 2 rooms each page, so should be 2 * last run_count
- self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
-
- if run_count > 1:
- # Check the value of prev_batch is correct
- self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
-
- if "next_batch" not in channel.json_body:
- # We have reached the end of the list
- should_repeat = False
- else:
- # Make another query with an updated start value
- start = channel.json_body["next_batch"]
-
- # We should've queried the endpoint 3 times
- self.assertEqual(
- run_count,
- 3,
- msg="Should've queried 3 times for 5 rooms with limit 2 per query",
- )
-
- # Check that we received all of the room ids
- self.assertEqual(room_ids, returned_room_ids)
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- def test_correct_room_attributes(self):
- """Test the correct attributes for a room are returned"""
- # Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- test_alias = "#test:test"
- test_room_name = "something"
-
- # Have another user join the room
- user_2 = self.register_user("user4", "pass")
- user_tok_2 = self.login("user4", "pass")
- self.helper.join(room_id, user_2, tok=user_tok_2)
-
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=self.admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=self.admin_user_tok,
- )
-
- # Set a name for the room
- self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that only one room was returned
- self.assertEqual(len(rooms), 1)
-
- # And that the value of the total_rooms key was correct
- self.assertEqual(channel.json_body["total_rooms"], 1)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that all provided attributes are set
- r = rooms[0]
- self.assertEqual(room_id, r["room_id"])
- self.assertEqual(test_room_name, r["name"])
- self.assertEqual(test_alias, r["canonical_alias"])
-
- def test_room_list_sort_order(self):
- """Test room list sort ordering. alphabetical versus number of members,
- reversing the order, etc.
- """
- # Create 3 test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
- )
-
- # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
- user_1 = self.register_user("bob1", "pass")
- user_1_tok = self.login("bob1", "pass")
- self.helper.join(room_id_2, user_1, tok=user_1_tok)
-
- user_2 = self.register_user("bob2", "pass")
- user_2_tok = self.login("bob2", "pass")
- self.helper.join(room_id_3, user_2, tok=user_2_tok)
-
- user_3 = self.register_user("bob3", "pass")
- user_3_tok = self.login("bob3", "pass")
- self.helper.join(room_id_3, user_3, tok=user_3_tok)
-
- def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
- ):
- """Request the list of rooms in a certain order. Assert that order is what
- we expect
-
- Args:
- order_type: The type of ordering to give the server
- expected_room_list: The list of room_ids in the order we expect to get
- back from the server
- """
- # Request the list of rooms in the given order
- url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
- if reverse:
- url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check for the correct total_rooms value
- self.assertEqual(channel.json_body["total_rooms"], 3)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that rooms were returned in alphabetical order
- returned_order = [r["room_id"] for r in rooms]
- self.assertListEqual(expected_room_list, returned_order) # order is checked
-
- # Test different sort orders, with forward and reverse directions
- _order_test("alphabetical", [room_id_1, room_id_2, room_id_3])
- _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("size", [room_id_3, room_id_2, room_id_1])
- _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- def test_search_term(self):
- """Test that searching for a room works correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- def _search_test(
- expected_room_id: Optional[str],
- search_term: str,
- expected_http_code: int = 200,
- ):
- """Search for a room and check that the returned room's id is a match
-
- Args:
- expected_room_id: The room_id expected to be returned by the API. Set
- to None to expect zero results for the search
- search_term: The term to search for room names with
- expected_http_code: The expected http code for the request
- """
- url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-
- if expected_http_code != 200:
- return
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that the expected number of rooms were returned
- expected_room_count = 1 if expected_room_id else 0
- self.assertEqual(len(rooms), expected_room_count)
- self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- if expected_room_id:
- # Check that the first returned room id is correct
- r = rooms[0]
- self.assertEqual(expected_room_id, r["room_id"])
-
- # Perform search tests
- _search_test(room_id_1, "something")
- _search_test(room_id_1, "thing")
-
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
new file mode 100644
index 0000000000..54cd24bf64
--- /dev/null
+++ b/tests/rest/admin/test_room.py
@@ -0,0 +1,1007 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert there is now no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there's no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms each page, so should be 2 * last run_count
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in alphabetical order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set room canonical room aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # Different sort order of SQlite and PostreSQL
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a parameter is missing, return an error
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local user can join rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names, return an error 400.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when server admin is not member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Validate if server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is owner of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6416fb5d2a..6c88ab06e2 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -360,6 +360,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(3, len(channel.json_body["users"]))
+ self.assertEqual(3, channel.json_body["total"])
class UserRestTestCase(unittest.HomeserverTestCase):
@@ -434,6 +435,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": True,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": None,
}
)
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
new file mode 100644
index 0000000000..913ea3c98e
--- /dev/null
+++ b/tests/rest/client/test_power_levels.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests.unittest import HomeserverTestCase
+
+
+class PowerLevelsTestCase(HomeserverTestCase):
+ """Tests that power levels are enforced in various situations"""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ # register a room admin, moderator and regular user
+ self.admin_user_id = self.register_user("admin", "pass")
+ self.admin_access_token = self.login("admin", "pass")
+ self.mod_user_id = self.register_user("mod", "pass")
+ self.mod_access_token = self.login("mod", "pass")
+ self.user_user_id = self.register_user("user", "pass")
+ self.user_access_token = self.login("user", "pass")
+
+ # Create a room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # Invite the other users
+ self.helper.invite(
+ room=self.room_id,
+ src=self.admin_user_id,
+ tok=self.admin_access_token,
+ targ=self.mod_user_id,
+ )
+ self.helper.invite(
+ room=self.room_id,
+ src=self.admin_user_id,
+ tok=self.admin_access_token,
+ targ=self.user_user_id,
+ )
+
+ # Make the other users join the room
+ self.helper.join(
+ room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token
+ )
+ self.helper.join(
+ room=self.room_id, user=self.user_user_id, tok=self.user_access_token
+ )
+
+ # Mod the mod
+ room_power_levels = self.helper.get_state(
+ self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with mod at PL50
+ room_power_levels["users"].update({self.mod_user_id: 50})
+
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ )
+
+ def test_non_admins_cannot_enable_room_encryption(self):
+ # have the mod try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_non_admins_cannot_send_server_acl(self):
+ # have the mod try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_non_admins_cannot_tombstone_room(self):
+ # Create another room that will serve as our "upgraded room"
+ self.upgraded_room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # have the mod try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_admins_can_enable_room_encryption(self):
+ # have the admin try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
+
+ def test_admins_can_send_server_acl(self):
+ # have the admin try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
+
+ def test_admins_can_tombstone_room(self):
+ # Create another room that will serve as our "upgraded room"
+ self.upgraded_room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # have the admin try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index a3d7e3c046..171632e195 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -2,7 +2,7 @@ from mock import Mock, call
from twisted.internet import defer, reactor
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock
@@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test():
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
- self.assertIs(LoggingContext.current_context(), c1)
+ self.assertIs(current_context(), c1)
self.assertEqual(res, "yay")
# run the test twice in parallel
d = defer.gatherResults([test(), test()])
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
yield d
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
@defer.inlineCallbacks
def test_does_not_cache_exceptions(self):
@@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_does_not_cache_failures(self):
@@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self):
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index ffb2de1505..b54b06482b 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -50,7 +50,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
return hs
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor, clock, hs):
# register an account
self.user_id = self.register_user("sid1", "pass")
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index da2c9bfa1e..eb8f6264fd 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -4,7 +4,7 @@ import urllib.parse
from mock import Mock
import synapse.rest.admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
@@ -20,6 +20,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
@@ -256,8 +257,74 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard logout this session
+ request, channel = self.make_request(
+ b"POST", "/logout", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard log out all of the user's sessions
+ request, channel = self.make_request(
+ b"POST", "/logout/all", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
-class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
+
+class CASTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -274,6 +341,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
"service_url": "https://matrix.goodserver.com:8448",
}
+ cas_user_id = "username"
+ self.user_id = "@%s:test" % cas_user_id
+
async def get_raw(uri, args):
"""Return an example response payload from a call to the `/proxyValidate`
endpoint of a CAS server, copied from
@@ -282,10 +352,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
This needs to be returned by an async function (as opposed to set as the
mock's return value) because the corresponding Synapse code awaits on it.
"""
- return """
+ return (
+ """
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
- <cas:user>username</cas:user>
+ <cas:user>%s</cas:user>
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
<cas:proxies>
<cas:proxy>https://proxy2/pgtUrl</cas:proxy>
@@ -294,6 +365,8 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
</cas:authenticationSuccess>
</cas:serviceResponse>
"""
+ % cas_user_id
+ )
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
@@ -304,6 +377,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.deactivate_account_handler = hs.get_deactivate_account_handler()
+
def test_cas_redirect_confirm(self):
"""Tests that the SSO login flow serves a confirmation page before redirecting a
user to the redirect URL.
@@ -350,7 +426,14 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
def test_cas_redirect_whitelisted(self):
"""Tests that the SSO login flow serves a redirect to a whitelisted url
"""
- redirect_url = "https://legit-site.com/"
+ self._test_redirect("https://legit-site.com/")
+
+ @override_config({"public_baseurl": "https://example.com"})
+ def test_cas_redirect_login_fallback(self):
+ self._test_redirect("https://example.com/_matrix/static/client/login")
+
+ def _test_redirect(self, redirect_url):
+ """Tests that the SSO login flow serves a redirect for the given redirect URL."""
cas_ticket_url = (
"/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
% (urllib.parse.quote(redirect_url))
@@ -363,3 +446,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
+
+ @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
+ def test_deactivated_user(self):
+ """Logging in as a deactivated account should error."""
+ redirect_url = "https://legit-site.com/"
+
+ # First login (to create the user).
+ self._test_redirect(redirect_url)
+
+ # Deactivate the account.
+ self.get_success(
+ self.deactivate_account_handler.deactivate_account(self.user_id, False)
+ )
+
+ # Request the CAS ticket.
+ cas_ticket_url = (
+ "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
+ % (urllib.parse.quote(redirect_url))
+ )
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ # Because the user is deactivated they are served an error template.
+ self.assertEqual(channel.code, 403)
+ self.assertIn(b"SSO account deactivated", channel.result["body"])
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 873d5ef99c..22d734e763 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -18,6 +18,7 @@
import json
import time
+from typing import Any, Dict, Optional
import attr
@@ -38,7 +39,7 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator, is_public=True, tok=None):
+ def create_room_as(self, room_creator=None, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -142,7 +143,34 @@ class RestHelper(object):
return channel.json_body
- def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+ def _read_write_state(
+ self,
+ room_id: str,
+ event_type: str,
+ body: Optional[Dict[str, Any]],
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ method: str = "GET",
+ ) -> Dict:
+ """Read or write some state from a given room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ body: Body that is sent when making the request. The content of the state event.
+ If None, the request to the server will have an empty body
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+ method: "GET" or "PUT" for reading or writing state, respectively
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
room_id,
event_type,
@@ -151,9 +179,13 @@ class RestHelper(object):
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8")
- )
+ # Set request body if provided
+ content = b""
+ if body is not None:
+ content = json.dumps(body).encode("utf8")
+
+ request, channel = make_request(self.hs.get_reactor(), method, path, content)
+
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
@@ -163,6 +195,62 @@ class RestHelper(object):
return channel.json_body
+ def get_state(
+ self,
+ room_id: str,
+ event_type: str,
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ ):
+ """Gets some state from a room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
+ return self._read_write_state(
+ room_id, event_type, None, tok, expect_code, state_key, method="GET"
+ )
+
+ def send_state(
+ self,
+ room_id: str,
+ event_type: str,
+ body: Dict[str, Any],
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ ):
+ """Set some state in a room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ body: Body that is sent when making the request. The content of the state event.
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
+ return self._read_write_state(
+ room_id, event_type, body, tok, expect_code, state_key, method="PUT"
+ )
+
def upload_media(
self,
resource: Resource,
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index c3facc00eb..0d6936fd36 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,6 +24,7 @@ import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
+from synapse.api.errors import Codes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
@@ -178,6 +179,22 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the new password
self.attempt_wrong_password_login("kermit", new_password)
+ @unittest.override_config({"request_token_inhibit_3pid_errors": True})
+ def test_password_reset_bad_email_inhibit_error(self):
+ """Test that triggering a password reset with an email address that isn't bound
+ to an account doesn't leak the lack of binding for that address if configured
+ that way.
+ """
+ self.register_user("kermit", "monkey")
+ self.login("kermit", "monkey")
+
+ email = "test@example.com"
+
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret)
+
+ self.assertIsNotNone(session_id)
+
def _request_token(self, email, client_secret):
request, channel = self.make_request(
"POST",
@@ -325,3 +342,304 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(request.code, 200)
+
+
+class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ account.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Email config.
+ self.email_attempts = []
+
+ def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ self.email_attempts.append(msg)
+
+ config["email"] = {
+ "enable_notifs": False,
+ "template_dir": os.path.abspath(
+ pkg_resources.resource_filename("synapse", "res/templates")
+ ),
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ }
+ config["public_baseurl"] = "https://example.com"
+
+ self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.user_id = self.register_user("kermit", "test")
+ self.user_id_tok = self.login("kermit", "test")
+ self.email = "test@example.com"
+ self.url_3pid = b"account/3pid"
+
+ def test_add_email(self):
+ """Test adding an email to profile
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+
+ def test_add_email_if_disabled(self):
+ """Test adding email to profile when doing so is disallowed
+ """
+ self.hs.config.enable_3pid_changes = False
+
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_delete_email(self):
+ """Test deleting an email from profile
+ """
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address=self.email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/delete",
+ {"medium": "email", "address": self.email},
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_delete_email_if_disabled(self):
+ """Test deleting an email from profile when disallowed
+ """
+ self.hs.config.enable_3pid_changes = False
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address=self.email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/delete",
+ {"medium": "email", "address": self.email},
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+
+ def test_cant_add_email_without_clicking_link(self):
+ """Test that we do actually need to click the link in the email
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+
+ # Attempt to add email without clicking the link
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_no_valid_token(self):
+ """Test that we do actually need to request a token and can't just
+ make a session up.
+ """
+ client_secret = "foobar"
+ session_id = "weasle"
+
+ # Attempt to add email without even requesting an email
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def _request_token(self, email, client_secret):
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ return channel.json_body["sid"]
+
+ def _validate_token(self, link):
+ # Remove the host
+ path = link.replace("https://example.com", "")
+
+ request, channel = self.make_request("GET", path, shorthand=False)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def _get_link_from_email(self):
+ assert self.email_attempts, "No emails have been sent"
+
+ raw_msg = self.email_attempts[-1].decode("UTF-8")
+ mail = Parser().parsestr(raw_msg)
+
+ text = None
+ for part in mail.walk():
+ if part.get_content_type() == "text/plain":
+ text = part.get_payload(decode=True).decode("UTF-8")
+ break
+
+ if not text:
+ self.fail("Could not find text portion of email to parse")
+
+ match = re.search(r"https://example.com\S+", text)
+ assert match, "Could not find link in email"
+
+ return match.group(0)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b6df1396ad..293ccfba2b 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,16 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import List, Union
from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.rest.client.v2_alpha import auth, register
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import auth, devices, register
+from synapse.types import JsonDict
from tests import unittest
+from tests.server import FakeChannel
class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -34,11 +38,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)
+class DummyPasswordChecker(UserInteractiveAuthChecker):
+ def check_auth(self, authdict, clientip):
+ return succeed(authdict["identifier"]["user"])
+
+
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
auth.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
register.register_servlets,
]
hijack_auth = False
@@ -59,59 +67,250 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
- @unittest.INFO
- def test_fallback_captcha(self):
-
+ def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
+ """Make a register request."""
request, channel = self.make_request(
- "POST",
- "register",
- {"username": "user", "type": "m.login.password", "password": "bar"},
- )
+ "POST", "register", body
+ ) # type: SynapseRequest, FakeChannel
self.render(request)
- # Returns a 401 as per the spec
- self.assertEqual(request.code, 401)
- # Grab the session
- session = channel.json_body["session"]
- # Assert our configured public key is being given
- self.assertEqual(
- channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
- )
+ self.assertEqual(request.code, expected_response)
+ return channel
+
+ def recaptcha(
+ self, session: str, expected_post_response: int, post_session: str = None
+ ) -> None:
+ """Get and respond to a fallback recaptcha. Returns the second request."""
+ if post_session is None:
+ post_session = session
request, channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
- )
+ ) # type: SynapseRequest, FakeChannel
self.render(request)
self.assertEqual(request.code, 200)
request, channel = self.make_request(
"POST",
"auth/m.login.recaptcha/fallback/web?session="
- + session
+ + post_session
+ "&g-recaptcha-response=a",
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, expected_post_response)
# The recaptcha handler is called with the response given
attempts = self.recaptcha_checker.recaptcha_attempts
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
- # also complete the dummy auth
- request, channel = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ @unittest.INFO
+ def test_fallback_captcha(self):
+ """Ensure that fallback auth via a captcha works."""
+ # Returns a 401 as per the spec
+ channel = self.register(
+ 401, {"username": "user", "type": "m.login.password", "password": "bar"},
)
- self.render(request)
- # Now we should have fufilled a complete auth flow, including
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ # Complete the recaptcha step.
+ self.recaptcha(session, 200)
+
+ # also complete the dummy auth
+ self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
+
+ # Now we should have fulfilled a complete auth flow, including
# the recaptcha fallback step, we can then send a
# request to the register API with the session in the authdict.
- request, channel = self.make_request(
- "POST", "register", {"auth": {"session": session}}
- )
- self.render(request)
- self.assertEqual(channel.code, 200)
+ channel = self.register(200, {"auth": {"session": session}})
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ def test_complete_operation_unknown_session(self):
+ """
+ Attempting to mark an invalid session as complete should error.
+ """
+ # Make the initial request to register. (Later on a different password
+ # will be used.)
+ # Returns a 401 as per the spec
+ channel = self.register(
+ 401, {"username": "user", "type": "m.login.password", "password": "bar"}
+ )
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ # Attempt to complete the recaptcha step with an unknown session.
+ # This results in an error.
+ self.recaptcha(session, 400, session + "unknown")
+
+
+class UIAuthTests(unittest.HomeserverTestCase):
+ servlets = [
+ auth.register_servlets,
+ devices.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ register.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ auth_handler = hs.get_auth_handler()
+ auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
+
+ self.user_pass = "pass"
+ self.user = self.register_user("test", self.user_pass)
+ self.user_tok = self.login("test", self.user_pass)
+
+ def get_device_ids(self) -> List[str]:
+ # Get the list of devices so one can be deleted.
+ request, channel = self.make_request(
+ "GET", "devices", access_token=self.user_tok,
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
+
+ # Get the ID of the device.
+ self.assertEqual(request.code, 200)
+ return [d["device_id"] for d in channel.json_body["devices"]]
+
+ def delete_device(
+ self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ ) -> FakeChannel:
+ """Delete an individual device."""
+ request, channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=self.user_tok
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
+
+ # Ensure the response is sane.
+ self.assertEqual(request.code, expected_response)
+
+ return channel
+
+ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel:
+ """Delete 1 or more devices."""
+ # Note that this uses the delete_devices endpoint so that we can modify
+ # the payload half-way through some tests.
+ request, channel = self.make_request(
+ "POST", "delete_devices", body, access_token=self.user_tok,
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
+
+ # Ensure the response is sane.
+ self.assertEqual(request.code, expected_response)
+
+ return channel
+
+ def test_ui_auth(self):
+ """
+ Test user interactive authentication outside of registration.
+ """
+ device_id = self.get_device_ids()[0]
+
+ # Attempt to delete this device.
+ # Returns a 401 as per the spec
+ channel = self.delete_device(device_id, 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ device_id,
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
+ def test_can_change_body(self):
+ """
+ The client dict can be modified during the user interactive authentication session.
+
+ Note that it is not spec compliant to modify the client dict during a
+ user interactive authentication session, but many clients currently do.
+
+ When Synapse is updated to be spec compliant, the call to re-use the
+ session ID should be rejected.
+ """
+ # Create a second login.
+ self.login("test", self.user_pass)
+
+ device_ids = self.get_device_ids()
+ self.assertEqual(len(device_ids), 2)
+
+ # Attempt to delete the first device.
+ # Returns a 401 as per the spec
+ channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow, but try to delete the
+ # second device.
+ self.delete_devices(
+ 200,
+ {
+ "devices": [device_ids[1]],
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
+ def test_cannot_change_uri(self):
+ """
+ The initial requested URI cannot be modified during the user interactive authentication session.
+ """
+ # Create a second login.
+ self.login("test", self.user_pass)
+
+ device_ids = self.get_device_ids()
+ self.assertEqual(len(device_ids), 2)
+
+ # Attempt to delete the first device.
+ # Returns a 401 as per the spec
+ channel = self.delete_device(device_ids[0], 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow, but try to delete the
+ # second device. This results in an error.
+ self.delete_device(
+ device_ids[1],
+ 403,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
new file mode 100644
index 0000000000..c57072f50c
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, password_policy, register
+
+from tests import unittest
+
+
+class PasswordPolicyTestCase(unittest.HomeserverTestCase):
+ """Tests the password policy feature and its compliance with MSC2000.
+
+ When validating a password, Synapse does the necessary checks in this order:
+
+ 1. Password is long enough
+ 2. Password contains digit(s)
+ 3. Password contains symbol(s)
+ 4. Password contains uppercase letter(s)
+ 5. Password contains lowercase letter(s)
+
+ For each test below that checks whether a password triggers the right error code,
+ that test provides a password good enough to pass the previous tests, but not the
+ one it is currently testing (nor any test that comes afterward).
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ password_policy.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.register_url = "/_matrix/client/r0/register"
+ self.policy = {
+ "enabled": True,
+ "minimum_length": 10,
+ "require_digit": True,
+ "require_symbol": True,
+ "require_lowercase": True,
+ "require_uppercase": True,
+ }
+
+ config = self.default_config()
+ config["password_config"] = {
+ "policy": self.policy,
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_get_policy(self):
+ """Tests if the /password_policy endpoint returns the configured policy."""
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
+
+ def test_password_too_short(self):
+ request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+ )
+
+ def test_password_no_digit(self):
+ request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+ )
+
+ def test_password_no_symbol(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+ )
+
+ def test_password_no_uppercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+ )
+
+ def test_password_no_lowercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+ )
+
+ def test_password_compliant(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ # Getting a 401 here means the password has passed validation and the server has
+ # responded with a list of registration flows.
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_password_change(self):
+ """This doesn't test every possible use case, only that hitting /account/password
+ triggers the password validation code.
+ """
+ compliant_password = "C0mpl!antpassword"
+ not_compliant_password = "notcompliantpassword"
+
+ user_id = self.register_user("kermit", compliant_password)
+ tok = self.login("kermit", compliant_password)
+
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
+ )
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/account/password",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d0c997e385..5637ce2090 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -25,7 +25,7 @@ import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest
@@ -33,11 +33,15 @@ from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
- servlets = [register.register_servlets]
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets,
+ ]
url = b"/_matrix/client/r0/register"
- def default_config(self, name="test"):
- config = super().default_config(name)
+ def default_config(self):
+ config = super().default_config()
config["allow_guest_access"] = True
return config
@@ -260,6 +264,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
[["m.login.email.identity"]], (f["stages"] for f in flows)
)
+ @unittest.override_config(
+ {
+ "request_token_inhibit_3pid_errors": True,
+ "public_baseurl": "https://test_server",
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_request_token_existing_email_inhibit_error(self):
+ """Test that requesting a token via this endpoint doesn't leak existing
+ associations if configured that way.
+ """
+ user_id = self.register_user("kermit", "monkey")
+ self.login("kermit", "monkey")
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.hs.get_datastore().user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ self.assertIsNotNone(channel.json_body.get("sid"))
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
@@ -268,6 +313,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
sync.register_servlets,
+ logout.register_servlets,
account_validity.register_servlets,
]
@@ -360,6 +406,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
+ def test_logging_out_expired_user(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Try to log the user out
+ request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Log the user in again (allowed for expired accounts)
+ tok = self.login("kermit", "monkey")
+
+ # Try to log out all of the user's sessions
+ request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6776a56cad..99eb477149 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible.
"""
- def default_config(self, *args, **kwargs):
- config = super().default_config(*args, **kwargs)
+ def default_config(self):
+ config = super().default_config()
# replace the signing key with our own
self.hs_signing_key = signedjson.key.generate_signing_key("kssk")
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 852b8ab11c..2826211f32 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -74,6 +74,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
config["url_preview_url_blacklist"] = []
+ config["url_preview_accept_language"] = [
+ "en-UK",
+ "en-US;q=0.9",
+ "fr;q=0.8",
+ "*;q=0.7",
+ ]
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
@@ -507,3 +513,52 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
+
+ def test_accept_language_config_option(self):
+ """
+ Accept-Language header is sent to the remote server
+ """
+ self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+
+ # Build and make a request to the server
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # Extract Synapse's tcp client
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+ # Build a fake remote server to reply with
+ server = AccumulatingProtocol()
+
+ # Connect the two together
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ # Tell Synapse that it has received some data from the remote server
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ # Move the reactor along until we get a response on our original channel
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check that the server received the Accept-Language header as part
+ # of the request from Synapse
+ self.assertIn(
+ (
+ b"Accept-Language: en-UK\r\n"
+ b"Accept-Language: en-US;q=0.9\r\n"
+ b"Accept-Language: fr;q=0.8\r\n"
+ b"Accept-Language: *;q=0.7"
+ ),
+ server.data,
+ )
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index eb540e34f6..406f29a7c0 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -19,6 +19,9 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
@@ -28,7 +31,7 @@ from tests import unittest
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs_config = self.default_config("test")
+ hs_config = self.default_config()
hs_config["server_notices"] = {
"system_mxid_localpart": "server",
"system_mxid_display_name": "test display name",
@@ -52,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
- self._send_notice = self._rlsn._server_notices_manager.send_notice
- self._rlsn._server_notices_manager.send_notice = Mock()
- self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
-
+ self._rlsn._server_notices_manager.send_notice = Mock(
+ return_value=defer.succeed(Mock())
+ )
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
- # self.server_notices_mxid = "@server:test"
- # self.server_notices_mxid_display_name = None
- # self.server_notices_mxid_avatar_url = None
- # self.server_notices_room_name = "Server Notices"
-
- self._rlsn._server_notices_manager.get_notice_room_for_user = Mock(
- returnValue=""
+ self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
+ return_value=defer.succeed("!something:localhost")
)
- self._rlsn._store.add_tag_to_room = Mock()
- self._rlsn._store.get_tags_for_room = Mock(return_value={})
+ self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+ self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
self.hs.config.admin_contact = "mailto:user@test.com"
def test_maybe_send_server_notice_to_user_flag_off(self):
@@ -92,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
-
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once()
@@ -109,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
- side_effect=ResourceLimitError(403, "foo")
+ return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
mock_event = Mock(
@@ -118,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
+
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@@ -126,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
-
self._rlsn._auth.check_auth_blocking = Mock(
- side_effect=ResourceLimitError(403, "foo")
+ return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -139,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -150,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
@@ -164,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self.hs.config.mau_limit_alerting = False
+
self._rlsn._auth.check_auth_blocking = Mock(
+ return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
- )
+ ),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
- self.assertTrue(self._send_notice.call_count == 0)
+ self.assertEqual(self._send_notice.call_count, 0)
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self.hs.config.mau_limit_alerting = False
+
self._rlsn._auth.check_auth_blocking = Mock(
+ return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
- )
+ ),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -195,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
+ return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
- )
+ ),
)
+
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, []))
)
@@ -215,6 +216,26 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def default_config(self):
+ c = super().default_config()
+ c["server_notices"] = {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": None,
+ "system_mxid_avatar_url": None,
+ "room_name": "Test Server Notice Room",
+ }
+ c["limit_usage_by_mau"] = True
+ c["max_mau_value"] = 5
+ c["admin_contact"] = "mailto:user@test.com"
+ return c
+
def prepare(self, reactor, clock, hs):
self.store = self.hs.get_datastore()
self.server_notices_sender = self.hs.get_server_notices_sender()
@@ -228,22 +249,14 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.hs_disabled = False
- self.hs.config.max_mau_value = 5
- self.hs.config.server_notices_mxid = "@server:test"
- self.hs.config.server_notices_mxid_display_name = None
- self.hs.config.server_notices_mxid_avatar_url = None
- self.hs.config.server_notices_room_name = "Test Server Notice Room"
-
self.user_id = "@user_id:test"
- self.hs.config.admin_contact = "mailto:user@test.com"
-
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
- self.store.user_last_seen_monthly_active = Mock(return_value=1000)
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(1000)
+ )
# Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -253,7 +266,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Now lets get the last load of messages in the service notice room and
# check that there is only one server notice
room_id = self.get_success(
- self.server_notices_manager.get_notice_room_for_user(self.user_id)
+ self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
)
token = self.get_success(self.event_source.get_current_token())
@@ -273,3 +286,86 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
count += 1
self.assertEqual(count, 1)
+
+ def test_no_invite_without_notice(self):
+ """Tests that a user doesn't get invited to a server notices room without a
+ server notice being sent.
+
+ The scenario for this test is a single user on a server where the MAU limit
+ hasn't been reached (since it's the only user and the limit is 5), so users
+ shouldn't receive a server notice.
+ """
+ self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ self.render(request)
+
+ invites = channel.json_body["rooms"]["invite"]
+ self.assertEqual(len(invites), 0, invites)
+
+ def test_invite_with_notice(self):
+ """Tests that, if the MAU limit is hit, the server notices user invites each user
+ to a room in which it has sent a notice.
+ """
+ user_id, tok, room_id = self._trigger_notice_and_join()
+
+ # Sync again to retrieve the events in the room, so we can check whether this
+ # room has a notice in it.
+ request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ self.render(request)
+
+ # Scan the events in the room to search for a message from the server notices
+ # user.
+ events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+ notice_in_room = False
+ for event in events:
+ if (
+ event["type"] == EventTypes.Message
+ and event["sender"] == self.hs.config.server_notices_mxid
+ ):
+ notice_in_room = True
+
+ self.assertTrue(notice_in_room, "No server notice in room")
+
+ def _trigger_notice_and_join(self):
+ """Creates enough active users to hit the MAU limit and trigger a system notice
+ about it, then joins the system notices room with one of the users created.
+
+ Returns:
+ user_id (str): The ID of the user that joined the room.
+ tok (str): The access token of the user that joined the room.
+ room_id (str): The ID of the room that's been joined.
+ """
+ user_id = None
+ tok = None
+ invites = []
+
+ # Register as many users as the MAU limit allows.
+ for i in range(self.hs.config.max_mau_value):
+ localpart = "user%d" % i
+ user_id = self.register_user(localpart, "password")
+ tok = self.login(localpart, "password")
+
+ # Sync with the user's token to mark the user as active.
+ request, channel = self.make_request(
+ "GET", "/sync?timeout=0", access_token=tok,
+ )
+ self.render(request)
+
+ # Also retrieves the list of invites for this user. We don't care about that
+ # one except if we're processing the last user, which should have received an
+ # invite to a room with a server notice about the MAU limit being reached.
+ # We could also pick another user and sync with it, which would return an
+ # invite to a system notices room, but it doesn't matter which user we're
+ # using so we use the last one because it saves us an extra sync.
+ invites = channel.json_body["rooms"]["invite"]
+
+ # Make sure we have an invite to process.
+ self.assertEqual(len(invites), 1, invites)
+
+ # Join the room.
+ room_id = list(invites.keys())[0]
+ self.helper.join(room=room_id, user=user_id, tok=tok)
+
+ return user_id, tok, room_id
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index e37260a820..5a50e4fdd4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached
from tests import unittest
-class CacheTestCase(unittest.TestCase):
- def setUp(self):
+class CacheTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
@@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
-class CacheDecoratorTestCase(unittest.TestCase):
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A(object):
@@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
- @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
+ @cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 31710949a8..ef296e7dab 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_token = "token1"
@@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_list = [
@@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
database = hs.get_datastores().databases[0]
@@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
@@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index ae14fb407d..940b166129 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
# the base test class should have run the real bg updates for us
- self.assertTrue(self.updates.has_completed_background_updates())
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
self.update_handler = Mock()
self.updates.register_background_update_handler(
@@ -25,12 +27,20 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the target runtime for each bg update
target_background_update_duration_ms = 50000
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
# first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield self.hs.get_datastore().db.runInteraction(
+ yield store.db.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
@@ -39,10 +49,6 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
return count
self.update_handler.side_effect = update
-
- self.get_success(
- self.updates.start_background_update("test_update", {"my_key": 1})
- )
self.update_handler.reset_mock()
res = self.get_success(
self.updates.do_next_background_update(
@@ -50,7 +56,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
),
by=0.1,
)
- self.assertIsNotNone(res)
+ self.assertFalse(res)
# on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
@@ -73,7 +79,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
result = self.get_success(
self.updates.do_next_background_update(target_background_update_duration_ms)
)
- self.assertIsNotNone(result)
+ self.assertFalse(result)
self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
@@ -81,5 +87,5 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
result = self.get_success(
self.updates.do_next_background_update(target_background_update_duration_ms)
)
- self.assertIsNone(result)
+ self.assertTrue(result)
self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index cdee0a9e60..278961c331 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
- config.event_cache_size = 1
+ config.caches = Mock()
+ config.caches.event_cache_size = 1
hs = TestHomeServer("test", config=config)
sqlite_config = {"name": "sqlite3"}
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
new file mode 100644
index 0000000000..5a77c84962
--- /dev/null
+++ b/tests/storage/test_database.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.database import make_tuple_comparison_clause
+from synapse.storage.engines import BaseDatabaseEngine
+
+from tests import unittest
+
+
+def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
+ # returns a DatabaseEngine, circumventing the abc mechanism
+ # any kwargs are set as attributes on the class before instantiating it
+ t = type(
+ "TestBaseDatabaseEngine",
+ (BaseDatabaseEngine,),
+ dict(BaseDatabaseEngine.__dict__),
+ )
+ # defeat the abc mechanism
+ t.__abstractmethods__ = set()
+ for k, v in kwargs.items():
+ setattr(t, k, v)
+ return t(None, None)
+
+
+class TupleComparisonClauseTestCase(unittest.TestCase):
+ def test_native_tuple_comparison(self):
+ db_engine = _stub_db_engine(supports_tuple_comparison=True)
+ clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)])
+ self.assertEqual(clause, "(a,b) > (?,?)")
+ self.assertEqual(args, [1, 2])
+
+ def test_emulated_tuple_comparison(self):
+ db_engine = _stub_db_engine(supports_tuple_comparison=False)
+ clause, args = make_tuple_comparison_clause(
+ db_engine, [("a", 1), ("b", 2), ("c", 3)]
+ )
+ self.assertEqual(
+ clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))"
+ )
+ self.assertEqual(args, [1, 1, 2, 2, 3])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6f8d990959..c2539b353a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- @defer.inlineCallbacks
- def test_get_device_updates_by_remote_limited(self):
- # Test breaking the update limit in 1, 101, and 1 device_id segments
-
- # first add one device
- device_ids1 = ["device_id0"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids1, ["someotherhost"]
- )
-
- # then add 101
- device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids2, ["someotherhost"]
- )
-
- # then one more
- device_ids3 = ["newdevice"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids3, ["someotherhost"]
- )
-
- #
- # now read them back.
- #
-
- # first we should get a single update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", -1, limit=100
- )
- self._check_devices_in_updates(device_ids1, device_updates)
-
- # Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self.assertEqual(len(device_updates), 0)
-
- # The 101 devices should've been cleared, so we should now just get one device
- # update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self._check_devices_in_updates(device_ids3, device_updates)
-
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index d4bcf1821e..b45bc9c115 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
@@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
yield self.store.db.runInteraction(
"",
- self.store._set_push_actions_for_event_and_users_txn,
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
[(event, None)],
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
+ def _insert_rows(self, instance_name: str, number: int):
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
new file mode 100644
index 0000000000..ab0df5ea93
--- /dev/null
+++ b/tests/storage/test_main.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Awesome Technologies Innovationslabor GmbH
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.types import UserID
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+
+class DataStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield setup_test_homeserver(self.addCleanup)
+
+ self.store = hs.get_datastore()
+
+ self.user = UserID.from_string("@abcde:test")
+ self.displayname = "Frank"
+
+ @defer.inlineCallbacks
+ def test_get_users_paginate(self):
+ yield self.store.register_user(self.user.to_string(), "pass")
+ yield self.store.create_profile(self.user.localpart)
+ yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+
+ users, total = yield self.store.get_users_paginate(
+ 0, 10, name="bc", guests=False
+ )
+
+ self.assertEquals(1, total)
+ self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 086adeb8fd..3b78d48896 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats(self):
+ self.assertDictContainsSubset(
+ {
+ "room_id": self.room.to_string(),
+ "creator": self.u_creator.to_string(),
+ "public": True,
+ },
+ (yield self.store.get_room_with_stats(self.room.to_string())),
+ )
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 00df0ea68e..5dd46005e6 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -22,6 +22,8 @@ from synapse.rest.client.v1 import login, room
from synapse.types import Requester, UserID
from tests import unittest
+from tests.test_utils import event_injection
+from tests.utils import TestHomeServer
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -38,7 +40,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
@@ -114,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
+ def test_get_joined_users_from_context(self):
+ room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ bob_event = event_injection.inject_member_event(
+ self.hs, room, self.u_bob, Membership.JOIN
+ )
+
+ # first, create a regular event
+ event, context = event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.1",
+ content={},
+ )
+
+ users = self.get_success(
+ self.store.get_joined_users_from_context(event, context)
+ )
+ self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+
+ # Regression test for #7376: create a state event whose key matches bob's
+ # user_id, but which is *not* a membership event, and persist that; then check
+ # that `get_joined_users_from_context` returns the correct users for the next event.
+ non_member_event = event_injection.inject_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_bob,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.2",
+ state_key=self.u_bob,
+ content={},
+ )
+ event, context = event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[non_member_event.event_id],
+ type="m.test.3",
+ content={},
+ )
+ users = self.get_success(
+ self.store.get_joined_users_from_context(event, context)
+ )
+ self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 6c2351cf55..f2def601fb 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -165,6 +165,39 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False,
)
+ def test_msc2209(self):
+ """
+ Notifications power levels get checked due to MSC2209.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(
+ creator, {"state_default": "30", "users": {pleb: "30"}}
+ ),
+ ("m.room.member", pleb): _join_event(pleb),
+ }
+
+ # pleb should be able to modify the notifications power level.
+ event_auth.check(
+ RoomVersions.V1,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # But an MSC2209 room rejects this change.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC2209_DEV,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
+ auth_events,
+ do_sig_check=False,
+ )
+
# helpers for making events
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 9b5cf562f3..f297de95f1 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room = room_creator.create_room(
- our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ room = ensureDeferred(
+ room_creator.create_room(
+ our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ )
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 1fbe0d51ff..eb159e3ba5 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
from mock import Mock
+from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.hs.config.limit_usage_by_mau = True
self.hs.config.hs_disabled = False
self.hs.config.max_mau_value = 2
- self.hs.config.mau_trial_days = 0
self.hs.config.server_notices_mxid = "@server:red"
self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room"
+ self.hs.config.mau_trial_days = 0
+
+ # AuthBlocking reads config options during hs creation. Recreate the
+ # hs' copy of AuthBlocking after we've updated config values above
+ self.auth_blocking = AuthBlocking(self.hs)
+ self.hs.get_auth()._auth_blocking = self.auth_blocking
+
return self.hs
def test_simple_deny_mau(self):
@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_trial_users_cant_come_back(self):
+ self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_tracked_but_not_limited(self):
- self.hs.config.max_mau_value = 1 # should not matter
- self.hs.config.limit_usage_by_mau = False
+ self.auth_blocking._max_mau_value = 1 # should not matter
+ self.auth_blocking._limit_usage_by_mau = False
self.hs.config.mau_stats_only = True
# Simply being able to create 2 users indicates that the
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 270f853d60..f5f63d8ed6 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,6 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
+from synapse.util.caches.descriptors import Cache
from tests import unittest
@@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase):
self.assertTrue(b"osversion=" in items[0])
self.assertTrue(b"pythonversion=" in items[0])
self.assertTrue(b"version=" in items[0])
+
+
+class CacheMetricsTests(unittest.HomeserverTestCase):
+ def test_cache_metric(self):
+ """
+ Caches produce metrics reflecting their state when scraped.
+ """
+ CACHE_NAME = "cache_metrics_test_fgjkbdfg"
+ cache = Cache(CACHE_NAME, max_entries=777)
+
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "0.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+ cache.prefill("1", "hi")
+
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 5ec5d2b358..5c2817cf28 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -28,8 +28,8 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
- def default_config(self, name="test"):
- config = super().default_config(name)
+ def default_config(self):
+ config = super().default_config()
config.update(
{
"public_baseurl": "https://example.org/",
@@ -53,7 +53,8 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
# Do a UI auth request
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a7310cf12a..7b345b03bb 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,3 +17,22 @@
"""
Utilities for running the unit tests
"""
+from typing import Awaitable, TypeVar
+
+TV = TypeVar("TV")
+
+
+def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
+ """Get the result from an Awaitable which should have completed
+
+ Asserts that the given awaitable has a result ready, and returns its value
+ """
+ i = awaitable.__await__()
+ try:
+ next(i)
+ except StopIteration as e:
+ # awaitable returned a result
+ return e.value
+
+ # if next didn't raise, the awaitable hasn't completed.
+ raise Exception("awaitable has not yet completed")
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
new file mode 100644
index 0000000000..431e9f8e5e
--- /dev/null
+++ b/tests/test_utils/event_injection.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import synapse.server
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection
+
+from tests.test_utils import get_awaitable_result
+
+
+"""
+Utility functions for poking events into the storage of the server under test.
+"""
+
+
+def inject_member_event(
+ hs: synapse.server.HomeServer,
+ room_id: str,
+ sender: str,
+ membership: str,
+ target: Optional[str] = None,
+ extra_content: Optional[dict] = None,
+ **kwargs
+) -> EventBase:
+ """Inject a membership event into a room."""
+ if target is None:
+ target = sender
+
+ content = {"membership": membership}
+ if extra_content:
+ content.update(extra_content)
+
+ return inject_event(
+ hs,
+ room_id=room_id,
+ type=EventTypes.Member,
+ sender=sender,
+ state_key=target,
+ content=content,
+ **kwargs
+ )
+
+
+def inject_event(
+ hs: synapse.server.HomeServer,
+ room_version: Optional[str] = None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ **kwargs
+) -> EventBase:
+ """Inject a generic event into a room
+
+ Args:
+ hs: the homeserver under test
+ room_version: the version of the room we're inserting into.
+ if not specified, will be looked up
+ prev_event_ids: prev_events for the event. If not specified, will be looked up
+ kwargs: fields for the event to be created
+ """
+ test_reactor = hs.get_reactor()
+
+ event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+
+ d = hs.get_storage().persistence.persist_event(event, context)
+ test_reactor.advance(0)
+ get_awaitable_result(d)
+
+ return event
+
+
+def create_event(
+ hs: synapse.server.HomeServer,
+ room_version: Optional[str] = None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ **kwargs
+) -> Tuple[EventBase, EventContext]:
+ test_reactor = hs.get_reactor()
+
+ if room_version is None:
+ d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
+ test_reactor.advance(0)
+ room_version = get_awaitable_result(d)
+
+ builder = hs.get_event_builder_factory().for_room_version(
+ KNOWN_ROOM_VERSIONS[room_version], kwargs
+ )
+ d = hs.get_event_creation_handler().create_new_client_event(
+ builder, prev_event_ids=prev_event_ids
+ )
+ test_reactor.advance(0)
+ event, context = get_awaitable_result(d)
+
+ return event, context
diff --git a/tests/unittest.py b/tests/unittest.py
index 8816a4d152..6b6f224e9c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -32,13 +32,17 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ current_context,
+ set_current_context,
+)
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -50,6 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
+from tests.test_utils import event_injection
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -97,10 +102,10 @@ class TestCase(unittest.TestCase):
def setUp(orig):
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
- if LoggingContext.current_context() is not LoggingContext.sentinel:
+ if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
- % (LoggingContext.current_context(),)
+ % (current_context(),)
)
old_level = logging.getLogger().level
@@ -122,7 +127,7 @@ class TestCase(unittest.TestCase):
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
gc.collect()
- LoggingContext.set_current_context(LoggingContext.sentinel)
+ set_current_context(SENTINEL_CONTEXT)
return ret
@@ -311,14 +316,11 @@ class HomeserverTestCase(TestCase):
return resource
- def default_config(self, name="test"):
+ def default_config(self):
"""
Get a default HomeServer config dict.
-
- Args:
- name (str): The homeserver name/domain.
"""
- config = default_config(name)
+ config = default_config("test")
# apply any additional config which was specified via the override_config
# decorator.
@@ -418,15 +420,17 @@ class HomeserverTestCase(TestCase):
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj
+ async def run_bg_updates():
+ with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ while not await stor.db.updates.has_completed_background_updates():
+ await stor.db.updates.do_next_background_update(1)
+
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
# Run the database background updates, when running against "master".
if hs.__class__.__name__ == "TestHomeServer":
- while not self.get_success(
- stor.db.updates.has_completed_background_updates()
- ):
- self.get_success(stor.db.updates.do_next_background_update(1))
+ self.get_success(run_bg_updates())
return hs
@@ -493,6 +497,7 @@ class HomeserverTestCase(TestCase):
"password": password,
"admin": admin,
"mac": want_mac,
+ "inhibit_login": True,
}
)
request, channel = self.make_request(
@@ -591,36 +596,14 @@ class HomeserverTestCase(TestCase):
"""
Inject a membership event into a room.
+ Deprecated: use event_injection.inject_room_member directly
+
Args:
room: Room ID to inject the event into.
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
- event_builder_factory = self.hs.get_event_builder_factory()
- event_creation_handler = self.hs.get_event_creation_handler()
-
- room_version = self.get_success(
- self.hs.get_datastore().get_room_version_id(room)
- )
-
- builder = event_builder_factory.for_room_version(
- KNOWN_ROOM_VERSIONS[room_version],
- {
- "type": EventTypes.Member,
- "sender": user,
- "state_key": user,
- "room_id": room,
- "content": {"membership": membership},
- },
- )
-
- event, context = self.get_success(
- event_creation_handler.create_new_client_event(builder)
- )
-
- self.get_success(
- self.hs.get_storage().persistence.persist_event(event, context)
- )
+ event_injection.inject_member_event(self.hs, room, user, membership)
class FederatingHomeserverTestCase(HomeserverTestCase):
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 39e360fe24..4d2b9e0d64 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -22,8 +22,10 @@ from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError
from synapse.logging.context import (
+ SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
@@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase):
with LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
return r
def check_result(r):
@@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
d2.addCallback(check_result)
# let the lookup complete
@@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase):
try:
d = obj.fn(1)
self.assertEqual(
- LoggingContext.current_context(), LoggingContext.sentinel
+ current_context(), SENTINEL_CONTEXT,
)
yield d
self.fail("No exception thrown")
except SynapseError:
pass
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
@@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
return d1
@@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
- assert LoggingContext.current_context().request == "c1"
+ assert current_context().request == "c1"
# we want this to behave like an asynchronous function
yield run_on_reactor()
- assert LoggingContext.current_context().request == "c1"
+ assert current_context().request == "c1"
return self.mock(args1, arg2)
with LoggingContext() as c1:
@@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
index f60918069a..17fd86d02d 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_utils.py
@@ -16,7 +16,12 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ PreserveLoggingContext,
+ current_context,
+)
from synapse.util.async_helpers import timeout_deferred
from tests.unittest import TestCase
@@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase):
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
- LoggingContext.current_context(),
+ current_context(),
context_one,
"errback %s run in unexpected logcontext %s"
- % (deferred_name, LoggingContext.current_context()),
+ % (deferred_name, current_context()),
)
return res
@@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase):
original_deferred.addErrback(errback, "orig")
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
self.assertNoResult(timing_out_d)
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
timing_out_d.addErrback(errback, "timingout")
self.clock.pump((1.0,))
@@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase):
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
- self.assertIs(LoggingContext.current_context(), context_one)
+ self.assertIs(current_context(), context_one)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 50bc7702d2..49ffeebd0e 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -21,7 +21,7 @@ from tests.utils import MockClock
from .. import unittest
-class ExpiringCacheTestCase(unittest.TestCase):
+class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 0ec8ef90ce..852ef23185 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,7 +19,7 @@ from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import LoggingContext, current_context
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase):
def func(i, sleep=False):
with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
- self.assertEqual(LoggingContext.current_context(), lc)
+ self.assertEqual(current_context(), lc)
if sleep:
yield Clock(reactor).sleep(0)
- self.assertEqual(LoggingContext.current_context(), lc)
+ self.assertEqual(current_context(), lc)
func(0, sleep=True)
for i in range(1, 100):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 281b32c4b8..95301c013c 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -2,8 +2,10 @@ import twisted.python.failure
from twisted.internet import defer, reactor
from synapse.logging.context import (
+ SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
nested_logging_context,
run_in_background,
@@ -15,7 +17,7 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(LoggingContext.current_context().request, value)
+ self.assertEquals(current_context().request, value)
def test_with_context(self):
with LoggingContext() as context_one:
@@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def _test_run_in_background(self, function):
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
callback_completed = [False]
@@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase):
# make sure that the context was reset before it got thrown back
# into the reactor
try:
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
d2.callback(None)
except BaseException:
d2.errback(twisted.python.failure.Failure())
@@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc():
self._check_test_key("one")
d = Clock(reactor).sleep(0)
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("one")
@@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
return d
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
@@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self):
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
@@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
await d
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 786947375d..0adb2174af 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache
from .. import unittest
-class LruCacheTestCase(unittest.TestCase):
+class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
cache = LruCache(1)
cache["key"] = "value"
@@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(len(cache), 0)
-class LruCacheCallbacksTestCase(unittest.TestCase):
+class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
@@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m3.call_count, 1)
-class LruCacheSizedTestCase(unittest.TestCase):
+class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 72a9de5370..13b753e367 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,11 +1,9 @@
-from mock import patch
-
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest
-class StreamChangeCacheTests(unittest.TestCase):
+class StreamChangeCacheTests(unittest.HomeserverTestCase):
"""
Tests for StreamChangeCache.
"""
@@ -28,26 +26,33 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.entity_has_changed("user@foo.com", 6)
cache.entity_has_changed("bar@baz.net", 7)
+ # also test multiple things changing on the same stream ID
+ cache.entity_has_changed("user2@foo.com", 8)
+ cache.entity_has_changed("bar2@baz.net", 8)
+
# If it's been changed after that stream position, return True
self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("user2@foo.com", 4))
# If it's been changed at that stream position, return False
self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+ self.assertFalse(cache.has_entity_changed("user2@foo.com", 8))
# If there's no changes after that stream position, return False
self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+ self.assertFalse(cache.has_entity_changed("user2@foo.com", 9))
# If the entity does not exist, return False.
- self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+ self.assertFalse(cache.has_entity_changed("not@here.website", 9))
# If we request before the stream cache's earliest known position,
# return True, whether it's a known entity or not.
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
- @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
- def test_has_entity_changed_pops_off_start(self):
+ def test_entity_has_changed_pops_off_start(self):
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@@ -64,11 +69,20 @@ class StreamChangeCacheTests(unittest.TestCase):
# The oldest item has been popped off
self.assertTrue("user@foo.com" not in cache._entity_to_key)
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
+ )
+ self.assertIsNone(cache.get_all_entities_changed(1))
+
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
self.assertEqual(
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
+ )
+ self.assertIsNone(cache.get_all_entities_changed(1))
def test_get_all_entities_changed(self):
"""
@@ -80,18 +94,52 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.entity_has_changed("user@foo.com", 2)
cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("anotheruser@foo.com", 3)
cache.entity_has_changed("user@elsewhere.org", 4)
- self.assertEqual(
- cache.get_all_entities_changed(1),
- ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
- )
- self.assertEqual(
- cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
- )
+ r = cache.get_all_entities_changed(1)
+
+ # either of these are valid
+ ok1 = [
+ "user@foo.com",
+ "bar@baz.net",
+ "anotheruser@foo.com",
+ "user@elsewhere.org",
+ ]
+ ok2 = [
+ "user@foo.com",
+ "anotheruser@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ ]
+ self.assertTrue(r == ok1 or r == ok2)
+
+ r = cache.get_all_entities_changed(2)
+ self.assertTrue(r == ok1[1:] or r == ok2[1:])
+
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
self.assertEqual(cache.get_all_entities_changed(0), None)
+ # ... later, things gest more updates
+ cache.entity_has_changed("user@foo.com", 5)
+ cache.entity_has_changed("bar@baz.net", 5)
+ cache.entity_has_changed("anotheruser@foo.com", 6)
+
+ ok1 = [
+ "user@elsewhere.org",
+ "user@foo.com",
+ "bar@baz.net",
+ "anotheruser@foo.com",
+ ]
+ ok2 = [
+ "user@elsewhere.org",
+ "bar@baz.net",
+ "user@foo.com",
+ "anotheruser@foo.com",
+ ]
+ r = cache.get_all_entities_changed(3)
+ self.assertTrue(r == ok1 or r == ok2)
+
def test_has_any_entity_changed(self):
"""
StreamChangeCache.has_any_entity_changed will return True if any
diff --git a/tests/utils.py b/tests/utils.py
index 513f358f4f..59c020a051 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -74,7 +74,10 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
- cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
+ cur.execute(
+ "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
+ "template=template0;" % (POSTGRES_BASE_DB,)
+ )
cur.close()
db_conn.close()
@@ -164,6 +167,7 @@ def default_config(name, parse=False):
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
"update_user_directory": False,
+ "caches": {"global_factor": 1},
}
if parse:
@@ -332,10 +336,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().validate_hash = (
- lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
- )
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -493,10 +502,10 @@ class MockClock(object):
return self.time() * 1000
def call_later(self, delay, callback, *args, **kwargs):
- current_context = LoggingContext.current_context()
+ ctx = current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
+ set_current_context(ctx)
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
@@ -504,8 +513,8 @@ class MockClock(object):
return t
- def looping_call(self, function, interval):
- self.loopers.append([function, interval / 1000.0, self.now])
+ def looping_call(self, function, interval, *args, **kwargs):
+ self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -535,9 +544,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
- func, interval, last = looped
+ func, interval, last, args, kwargs = looped
if last + interval < self.now:
- func()
+ func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):
|