diff --git a/tests/__init__.py b/tests/__init__.py
index f7fc502f01..ed805db1c2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,9 +16,9 @@
from twisted.trial import util
-import tests.patch_inline_callbacks
+from synapse.util.patch_inline_callbacks import do_patch
# attempt to do the patch before we load any synapse code
-tests.patch_inline_callbacks.do_patch()
+do_patch()
util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6121efcfa9..0bfb86bf1f 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs)
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.auth._auth_blocking
+
self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
@@ -68,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
@@ -105,7 +109,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
@@ -125,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -188,7 +192,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@@ -225,7 +229,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(macaroon.serialize())
+ )
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -250,7 +256,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
- user_info = yield self.auth.get_user_by_access_token(serialized)
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(serialized)
+ )
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -260,10 +268,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
- self.store.add_access_token_to_user = Mock()
+ self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
+ self.store.get_device = Mock(return_value=defer.succeed(None))
- token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
- USER_ID, "DEVICE", valid_until_ms=None
+ token = yield defer.ensureDeferred(
+ self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ USER_ID, "DEVICE", valid_until_ms=None
+ )
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
@@ -286,7 +297,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
@@ -301,7 +314,9 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@@ -310,22 +325,22 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_blocking_mau(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = False
+ self.auth_blocking._max_mau_value = 50
lots_of_users = 100
small_number_of_users = 1
# Ensure no error thrown
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -334,49 +349,54 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
- self.hs.config.max_mau_value = 50
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
- yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_reserved_threepid(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 1
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
- self.hs.config.mau_limits_reserved_threepids = [threepid]
+ self.auth_blocking._mau_limits_reserved_threepids = [threepid]
- yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(threepid=unknown_threepid)
+ )
- yield self.auth.check_auth_blocking(threepid=threepid)
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self):
- self.hs.config.hs_disabled = True
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._hs_disabled = True
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -388,20 +408,20 @@ class AuthTestCase(unittest.TestCase):
"""
# this should be the default, but we had a bug where the test was doing the wrong
# thing, so let's make it explicit
- self.hs.config.server_notices_mxid = None
+ self.auth_blocking._server_notices_mxid = None
- self.hs.config.hs_disabled = True
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._hs_disabled = True
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
- self.hs.config.hs_disabled = True
+ self.auth_blocking._hs_disabled = True
user = "@user:server"
- self.hs.config.server_notices_mxid = user
- self.hs.config.hs_disabled_message = "Reason for being disabled"
- yield self.auth.check_auth_blocking(user)
+ self.auth_blocking._server_notices_mxid = user
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6ba623de13..4e67503cf0 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,9 +22,10 @@ import jsonschema
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
from tests import unittest
from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
@@ -34,7 +38,7 @@ def MockEvent(**kwargs):
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
- return FrozenEvent(kwargs)
+ return make_event_from_dict(kwargs)
class FilteringTestCase(unittest.TestCase):
@@ -95,6 +99,8 @@ class FilteringTestCase(unittest.TestCase):
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"],
+ "org.matrix.labels": ["#fun"],
+ "org.matrix.not_labels": ["#work"],
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
@@ -320,6 +326,46 @@ class FilteringTestCase(unittest.TestCase):
)
self.assertFalse(Filter(definition).check(event))
+ def test_filter_labels(self):
+ definition = {"org.matrix.labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ def test_filter_not_labels(self):
+ definition = {"org.matrix.not_labels": ["#fun"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#fun"]},
+ )
+
+ self.assertFalse(Filter(definition).check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={EventContentFields.LABELS: ["#notfun"]},
+ )
+
+ self.assertTrue(Filter(definition).check(event))
+
@defer.inlineCallbacks
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index dbdd427cac..d580e729c5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,39 +1,97 @@
-from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from tests import unittest
class TestRatelimiter(unittest.TestCase):
- def test_allowed(self):
- limiter = Ratelimiter()
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
- )
+ def test_allowed_via_can_do_action(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
- )
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_do_action(
- key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
- )
+ allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
- def test_pruning(self):
- limiter = Ratelimiter()
+ def test_allowed_via_ratelimit(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ limiter.ratelimit(key="test_id", _time_now_s=0)
+
+ # Should raise
+ with self.assertRaises(LimitExceededError) as context:
+ limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.assertEqual(context.exception.retry_after_ms, 5000)
+
+ # Shouldn't raise
+ limiter.ratelimit(key="test_id", _time_now_s=10)
+
+ def test_allowed_via_can_do_action_and_overriding_parameters(self):
+ """Test that we can override options of can_do_action that would otherwise fail
+ an action
+ """
+ # Create a Ratelimiter with a very low allowed rate_hz and burst_count
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # First attempt should be allowed
+ allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,)
+ self.assertTrue(allowed)
+ self.assertEqual(10.0, time_allowed)
+
+ # Second attempt, 1s later, will fail
+ allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,)
+ self.assertFalse(allowed)
+ self.assertEqual(10.0, time_allowed)
+
+ # But, if we allow 10 actions/sec for this request, we should be allowed
+ # to continue.
allowed, time_allowed = limiter.can_do_action(
- key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
+ ("test_id",), _time_now_s=1, rate_hz=10.0
)
+ self.assertTrue(allowed)
+ self.assertEqual(1.1, time_allowed)
- self.assertIn("test_id_1", limiter.message_counts)
-
+ # Similarly if we allow a burst of 10 actions
allowed, time_allowed = limiter.can_do_action(
- key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
+ ("test_id",), _time_now_s=1, burst_count=10
)
+ self.assertTrue(allowed)
+ self.assertEqual(1.0, time_allowed)
+
+ def test_allowed_via_ratelimit_and_overriding_parameters(self):
+ """Test that we can override options of the ratelimit method that would otherwise
+ fail an action
+ """
+ # Create a Ratelimiter with a very low allowed rate_hz and burst_count
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+
+ # First attempt should be allowed
+ limiter.ratelimit(key=("test_id",), _time_now_s=0)
+
+ # Second attempt, 1s later, will fail
+ with self.assertRaises(LimitExceededError) as context:
+ limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.assertEqual(context.exception.retry_after_ms, 9000)
+
+ # But, if we allow 10 actions/sec for this request, we should be allowed
+ # to continue.
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+
+ # Similarly if we allow a burst of 10 actions
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+
+ def test_pruning(self):
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter.can_do_action(key="test_id_1", _time_now_s=0)
+
+ self.assertIn("test_id_1", limiter.actions)
+
+ limiter.can_do_action(key="test_id_2", _time_now_s=10)
- self.assertNotIn("test_id_1", limiter.message_counts)
+ self.assertNotIn("test_id_1", limiter.actions)
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 8bdbc608a9..be20a89682 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.app.frontend_proxy import FrontendProxyServer
+from synapse.app.generic_worker import GenericWorkerServer
from tests.unittest import HomeserverTestCase
@@ -22,11 +22,16 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=FrontendProxyServer
+ http_client=None, homeserverToUse=GenericWorkerServer
)
return hs
+ def default_config(self):
+ c = super().default_config()
+ c["worker_app"] = "synapse.app.frontend_proxy"
+ return c
+
def test_listen_http_with_presence_enabled(self):
"""
When presence is on, the stub servlet will not register.
@@ -46,9 +51,7 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = (
- site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
- )
+ self.resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
@@ -76,9 +79,7 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = (
- site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
- )
+ self.resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 48792d1480..7364f9f1ec 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -16,7 +16,7 @@ from mock import Mock, patch
from parameterized import parameterized
-from synapse.app.federation_reader import FederationReaderServer
+from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from tests.unittest import HomeserverTestCase
@@ -25,10 +25,18 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=FederationReaderServer
+ http_client=None, homeserverToUse=GenericWorkerServer
)
return hs
+ def default_config(self):
+ conf = super().default_config()
+ # we're using FederationReaderServer, which uses a SlavedStore, so we
+ # have to tell the FederationHandler not to try to access stuff that is only
+ # in the primary store.
+ conf["worker_app"] = "yes"
+ return conf
+
@parameterized.expand(
[
(["federation"], "auth_fail"),
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
new file mode 100644
index 0000000000..d3ec24c975
--- /dev/null
+++ b/tests/config/test_cache.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config, RootConfig
+from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.util.caches.lrucache import LruCache
+
+from tests.unittest import TestCase
+
+
+class FakeServer(Config):
+ section = "server"
+
+
+class TestConfig(RootConfig):
+ config_classes = [FakeServer, CacheConfig]
+
+
+class CacheConfigTests(TestCase):
+ def setUp(self):
+ # Reset caches before each test
+ TestConfig().caches.reset()
+
+ def test_individual_caches_from_environ(self):
+ """
+ Individual cache factors will be loaded from the environment.
+ """
+ config = {}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_NOT_CACHE": "BLAH",
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+
+ def test_config_overrides_environ(self):
+ """
+ Individual cache factors defined in the environment will take precedence
+ over those in the config.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_CACHE_FACTOR_FOO": 1,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(
+ dict(t.caches.cache_factors),
+ {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
+ )
+
+ def test_individual_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized once the config
+ is loaded.
+ """
+ cache = LruCache(100)
+
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"per_cache_factors": {"foo": 3}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 300)
+
+ def test_individual_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the per_cache_factor if
+ there is one.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 200)
+
+ def test_global_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized to the new
+ default cache size once the config is loaded.
+ """
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"global_factor": 4}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 400)
+
+ def test_global_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the global factor if there
+ is no per-cache factor.
+ """
+ config = {"caches": {"global_factor": 1.5}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 150)
+
+ def test_cache_with_asterisk_in_name(self):
+ """Some caches have asterisks in their name, test that they are set correctly.
+ """
+
+ config = {
+ "caches": {
+ "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
+ }
+ }
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
+ "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache_a = LruCache(100)
+ add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
+ self.assertEqual(cache_a.max_size, 200)
+
+ cache_b = LruCache(100)
+ add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
+ self.assertEqual(cache_b.max_size, 300)
+
+ cache_c = LruCache(100)
+ add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
+ self.assertEqual(cache_c.max_size, 200)
+
+ def test_apply_cache_factor_from_config(self):
+ """Caches can disable applying cache factor updates, mainly used by
+ event cache size.
+ """
+
+ config = {"caches": {"event_cache_size": "10k"}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(
+ max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+ )
+ add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
+
+ self.assertEqual(cache.max_size, 10240)
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index 151d3006ac..f675bde68e 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -21,9 +21,9 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
- def test_database_configured_correctly_no_database_conf_param(self):
+ def test_database_configured_correctly(self):
conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", None)
+ DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
)
expected_database_conf = {
@@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase):
}
self.assertEqual(conf["database"], expected_database_conf)
-
- def test_database_configured_correctly_database_conf_param(self):
-
- database_conf = {
- "name": "my super fast datastore",
- "args": {
- "user": "matrix",
- "password": "synapse_database_password",
- "host": "synapse_database_host",
- "database": "matrix",
- },
- }
-
- conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
- )
-
- self.assertEqual(conf["database"], database_conf)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 2684e662de..463855ecc8 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
)
self.assertSetEqual(
- set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
+ {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"},
set(os.listdir(self.dir)),
)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index b3e557bd6a..734a9983e8 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
with open(self.file, "r") as f:
contents = f.readlines()
- contents = [l for l in contents if needle not in l]
+ contents = [line for line in contents if needle not in line]
with open(self.file, "w") as f:
f.write("".join(contents))
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b02780772a..ec32d4b1ca 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -21,17 +21,24 @@ import yaml
from OpenSSL import SSL
+from synapse.config._base import Config, RootConfig
from synapse.config.tls import ConfigError, TlsConfig
-from synapse.crypto.context_factory import ClientTLSOptionsFactory
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from tests.unittest import TestCase
-class TestConfig(TlsConfig):
+class FakeServer(Config):
+ section = "server"
+
def has_tls_listener(self):
return False
+class TestConfig(RootConfig):
+ config_classes = [FakeServer, TlsConfig]
+
+
class TLSConfigTests(TestCase):
def test_warn_self_signed(self):
"""
@@ -173,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="")
- cf = ClientTLSOptionsFactory(t)
+ cf = FederationPolicyForHTTPS(t)
+ options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
- self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
- self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
- self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
+ self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0)
+ self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
+ self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_tls_client_minimum_set_passed_through_1_0(self):
"""
@@ -188,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="")
- cf = ClientTLSOptionsFactory(t)
+ cf = FederationPolicyForHTTPS(t)
+ options = _get_ssl_context_options(cf._verify_ssl_context)
# The context has not had any of the NO_TLS set.
- self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
- self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
- self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
+ self.assertEqual(options & SSL.OP_NO_TLSv1, 0)
+ self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
+ self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
"""
@@ -202,13 +211,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
conf = TestConfig()
conf.read_config(
yaml.safe_load(
- TestConfig().generate_config_section(
+ TestConfig().generate_config(
"/config_dir_path",
"my_super_secure_server",
"/data_dir_path",
- "/tls_cert_path",
- "tls_private_key",
- None, # This is the acme_domain
+ tls_certificate_path="/tls_cert_path",
+ tls_private_key_path="tls_private_key",
+ acme_domain=None, # This is the acme_domain
)
),
"/config_dir_path",
@@ -223,13 +232,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
conf = TestConfig()
conf.read_config(
yaml.safe_load(
- TestConfig().generate_config_section(
+ TestConfig().generate_config(
"/config_dir_path",
"my_super_secure_server",
"/data_dir_path",
- "/tls_cert_path",
- "tls_private_key",
- "my_supe_secure_server", # This is the acme_domain
+ tls_certificate_path="/tls_cert_path",
+ tls_private_key_path="tls_private_key",
+ acme_domain="my_supe_secure_server", # This is the acme_domain
)
),
"/config_dir_path",
@@ -266,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
t = TestConfig()
t.read_config(config, config_dir_path="", data_dir_path="")
- cf = ClientTLSOptionsFactory(t)
+ cf = FederationPolicyForHTTPS(t)
# Not in the whitelist
opts = cf.get_options(b"notexample.com")
@@ -275,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
# Caught by the wildcard
opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
self.assertFalse(opts._verifier._verify_certs)
+
+
+def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
+ """get the options bits from an openssl context object"""
+ # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
+ # use the low-level interface
+ return SSL._lib.SSL_CTX_get_options(ssl_context._context)
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 126e176004..62f639a18d 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -17,8 +17,9 @@
import nacl.signing
from unpaddedbase64 import decode_base64
+from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import add_hashes_and_signatures
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
from tests import unittest
@@ -49,9 +50,11 @@ class EventSigningTestCase(unittest.TestCase):
"unsigned": {"age_ts": 1000000},
}
- add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
+ add_hashes_and_signatures(
+ RoomVersions.V1, event_dict, HOSTNAME, self.signing_key
+ )
- event = FrozenEvent(event_dict)
+ event = make_event_from_dict(event_dict)
self.assertTrue(hasattr(event, "hashes"))
self.assertIn("sha256", event.hashes)
@@ -81,9 +84,11 @@ class EventSigningTestCase(unittest.TestCase):
"unsigned": {"age_ts": 1000000},
}
- add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key)
+ add_hashes_and_signatures(
+ RoomVersions.V1, event_dict, HOSTNAME, self.signing_key
+ )
- event = FrozenEvent(event_dict)
+ event = make_event_from_dict(event_dict)
self.assertTrue(hasattr(event, "hashes"))
self.assertIn("sha256", event.hashes)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c4f0bbd3dd..70c8e72303 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
@@ -33,6 +34,7 @@ from synapse.crypto.keyring import (
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.storage.keys import FetchKeyResult
@@ -82,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
def check_context(self, _, expected):
- self.assertEquals(
- getattr(LoggingContext.current_context(), "request", None), expected
- )
+ self.assertEquals(getattr(current_context(), "request", None), expected)
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -104,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
- self.assertEquals(LoggingContext.current_context().request, "11")
+ self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
return persp_resp
@@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
@@ -412,34 +412,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
handlers=None, http_client=self.http_client, config=config
)
- def test_get_keys_from_perspectives(self):
- # arbitrarily advance the clock a bit
- self.reactor.advance(100)
-
- fetcher = PerspectivesKeyFetcher(self.hs)
-
- SERVER_NAME = "server2"
- testkey = signedjson.key.generate_signing_key("ver1")
- testverifykey = signedjson.key.get_verify_key(testkey)
- testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 200 * 1000
+ def build_perspectives_response(
+ self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+ ) -> dict:
+ """
+ Build a valid perspectives server response to a request for the given key
+ """
+ verify_key = signedjson.key.get_verify_key(signing_key)
+ verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version)
- # valid response
response = {
- "server_name": SERVER_NAME,
+ "server_name": server_name,
"old_verify_keys": {},
- "valid_until_ts": VALID_UNTIL_TS,
+ "valid_until_ts": valid_until_ts,
"verify_keys": {
- testverifykey_id: {
- "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ verifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(verify_key)
}
},
}
-
# the response must be signed by both the origin server and the perspectives
# server.
- signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ signedjson.sign.sign_json(response, server_name, signing_key)
self.mock_perspective_server.sign_response(response)
+ return response
+
+ def expect_outgoing_key_query(
+ self, expected_server_name: str, expected_key_id: str, response: dict
+ ) -> None:
+ """
+ Tell the mock http client to expect a perspectives-server key query
+ """
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -447,11 +450,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
- self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+ self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id])
return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
+ def test_get_keys_from_perspectives(self):
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ response = self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS,
+ )
+
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertIn(SERVER_NAME, keys)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
+
+ # check that the perspectives store is correctly updated
+ lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+ key_json = self.get_success(
+ self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+ )
+ res = key_json[lookup_triplet]
+ self.assertEqual(len(res), 1)
+ res = res[0]
+ self.assertEqual(res["key_id"], testverifykey_id)
+ self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+ self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+ self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+ self.assertEqual(
+ bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+ )
+
+ def test_get_perspectives_own_key(self):
+ """Check that we can get the perspectives server's own keys
+
+ This is slightly complicated by the fact that the perspectives server may
+ use different keys for signing notary responses.
+ """
+
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = self.mock_perspective_server.server_name
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ response = self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS
+ )
+
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
@@ -490,35 +561,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
VALID_UNTIL_TS = 200 * 1000
def build_response():
- # valid response
- response = {
- "server_name": SERVER_NAME,
- "old_verify_keys": {},
- "valid_until_ts": VALID_UNTIL_TS,
- "verify_keys": {
- testverifykey_id: {
- "key": signedjson.key.encode_verify_key_base64(testverifykey)
- }
- },
- }
-
- # the response must be signed by both the origin server and the perspectives
- # server.
- signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- self.mock_perspective_server.sign_response(response)
- return response
+ return self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS
+ )
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-
- def post_json(destination, path, data, **kwargs):
- self.assertEqual(destination, self.mock_perspective_server.server_name)
- self.assertEqual(path, "/_matrix/key/v2/query")
- return {"server_keys": [response]}
-
- self.http_client.post_json.side_effect = post_json
-
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
return self.get_success(fetcher.get_keys(keys_to_fetch))
# start with a valid response so we can check we are testing the right thing
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
new file mode 100644
index 0000000000..640f5f3bce
--- /dev/null
+++ b/tests/events/test_snapshot.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+from tests.test_utils.event_injection import create_event
+
+
+class TestEventContext(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+
+ self.user_id = self.register_user("u1", "pass")
+ self.user_tok = self.login("u1", "pass")
+ self.room_id = self.helper.create_room_as(tok=self.user_tok)
+
+ def test_serialize_deserialize_msg(self):
+ """Test that an EventContext for a message event is the same after
+ serialize/deserialize.
+ """
+
+ event, context = create_event(
+ self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def test_serialize_deserialize_state_no_prev(self):
+ """Test that an EventContext for a state event (with not previous entry)
+ is the same after serialize/deserialize.
+ """
+ event, context = create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ state_key="",
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def test_serialize_deserialize_state_prev(self):
+ """Test that an EventContext for a state event (which replaces a
+ previous entry) is the same after serialize/deserialize.
+ """
+ event, context = create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.user_id,
+ state_key=self.user_id,
+ content={"membership": "leave"},
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def _check_serialize_deserialize(self, event, context):
+ serialized = self.get_success(context.serialize(event, self.store))
+
+ d_context = EventContext.deserialize(self.storage, serialized)
+
+ self.assertEqual(context.state_group, d_context.state_group)
+ self.assertEqual(context.rejected, d_context.rejected)
+ self.assertEqual(
+ context.state_group_before_event, d_context.state_group_before_event
+ )
+ self.assertEqual(context.prev_group, d_context.prev_group)
+ self.assertEqual(context.delta_ids, d_context.delta_ids)
+ self.assertEqual(context.app_service, d_context.app_service)
+
+ self.assertEqual(
+ self.get_success(context.get_current_state_ids()),
+ self.get_success(d_context.get_current_state_ids()),
+ )
+ self.assertEqual(
+ self.get_success(context.get_prev_state_ids()),
+ self.get_success(d_context.get_prev_state_ids()),
+ )
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 9e3d4d0f47..c1274c14af 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -13,11 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.room_versions import RoomVersions
+from synapse.events import make_event_from_dict
+from synapse.events.utils import (
+ copy_power_levels_contents,
+ prune_event,
+ serialize_event,
+)
+from synapse.util.frozenutils import freeze
-from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event, serialize_event
-
-from .. import unittest
+from tests import unittest
def MockEvent(**kwargs):
@@ -25,15 +30,17 @@ def MockEvent(**kwargs):
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
- return FrozenEvent(kwargs)
+ return make_event_from_dict(kwargs)
class PruneEventTestCase(unittest.TestCase):
""" Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted. """
- def run_test(self, evdict, matchdict):
- self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict)
+ def run_test(self, evdict, matchdict, **kwargs):
+ self.assertEquals(
+ prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
+ )
def test_minimal(self):
self.run_test(
@@ -122,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase):
},
)
+ def test_alias_event(self):
+ """Alias events have special behavior up through room version 6."""
+ self.run_test(
+ {
+ "type": "m.room.aliases",
+ "event_id": "$test:domain",
+ "content": {"aliases": ["test"]},
+ },
+ {
+ "type": "m.room.aliases",
+ "event_id": "$test:domain",
+ "content": {"aliases": ["test"]},
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ def test_msc2432_alias_event(self):
+ """After MSC2432, alias events have no special behavior."""
+ self.run_test(
+ {"type": "m.room.aliases", "content": {"aliases": ["test"]}},
+ {
+ "type": "m.room.aliases",
+ "content": {},
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.V6,
+ )
+
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
@@ -241,3 +278,39 @@ class SerializeEventTestCase(unittest.TestCase):
self.serialize(
MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
)
+
+
+class CopyPowerLevelsContentTestCase(unittest.TestCase):
+ def setUp(self) -> None:
+ self.test_content = {
+ "ban": 50,
+ "events": {"m.room.name": 100, "m.room.power_levels": 100},
+ "events_default": 0,
+ "invite": 50,
+ "kick": 50,
+ "notifications": {"room": 20},
+ "redact": 50,
+ "state_default": 50,
+ "users": {"@example:localhost": 100},
+ "users_default": 0,
+ }
+
+ def _test(self, input):
+ a = copy_power_levels_contents(input)
+
+ self.assertEqual(a["ban"], 50)
+ self.assertEqual(a["events"]["m.room.name"], 100)
+
+ # make sure that changing the copy changes the copy and not the orig
+ a["ban"] = 10
+ a["events"]["m.room.power_levels"] = 20
+
+ self.assertEqual(input["ban"], 50)
+ self.assertEqual(input["events"]["m.room.power_levels"], 100)
+
+ def test_unfrozen(self):
+ self._test(self.test_content)
+
+ def test_frozen(self):
+ input = freeze(self.test_content)
+ self._test(input)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 51714a2b06..0c9987be54 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -18,17 +18,14 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
-from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
-class RoomComplexityTests(unittest.HomeserverTestCase):
+class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -36,30 +33,11 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
login.register_servlets,
]
- def default_config(self, name="test"):
- config = super().default_config(name=name)
+ def default_config(self):
+ config = super().default_config()
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
- def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
- def authenticate_request(self, request, content):
- return defer.succeed("otherserver.nottld")
-
- ratelimiter = FederationRateLimiter(
- clock,
- FederationRateLimitConfig(
- window_size=1,
- sleep_limit=1,
- sleep_msec=1,
- reject_limit=1000,
- concurrent_requests=1000,
- ),
- )
- server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
-
def test_complexity_simple(self):
u1 = self.register_user("u1", "pass")
@@ -101,11 +79,13 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=defer.succeed(("", 1))
+ )
d = handler._remote_join(
None,
- ["otherserver.example"],
+ ["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
@@ -137,7 +117,9 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
- handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=defer.succeed(("", 1))
+ )
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
@@ -146,7 +128,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
d = handler._remote_join(
None,
- ["otherserver.example"],
+ ["other.example.com"],
room_1,
UserID.from_string(u1),
{"membership": "join"},
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index cce8d8c6de..ff12539041 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,23 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from mock import Mock
+from signedjson import key, sign
+from signedjson.types import BaseKey, SigningKey
+
from twisted.internet import defer
-from synapse.types import ReadReceipt
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.types import JsonDict, ReadReceipt
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
-class FederationSenderTestCases(HomeserverTestCase):
+class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- return super(FederationSenderTestCases, self).setup_test_homeserver(
+ return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ @override_config({"send_federation": True})
def test_send_receipts(self):
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
@@ -69,6 +76,7 @@ class FederationSenderTestCases(HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
@@ -145,3 +153,392 @@ class FederationSenderTestCases(HomeserverTestCase):
}
],
)
+
+
+class FederationSenderDevicesTestCases(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def default_config(self):
+ c = super().default_config()
+ c["send_federation"] = True
+ return c
+
+ def prepare(self, reactor, clock, hs):
+ # stub out get_current_hosts_in_room
+ mock_state_handler = hs.get_state_handler()
+ mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
+
+ # stub out get_users_who_share_room_with_user so that it claims that
+ # `@user2:host2` is in the room
+ def get_users_who_share_room_with_user(user_id):
+ return defer.succeed({"@user2:host2"})
+
+ hs.get_datastore().get_users_who_share_room_with_user = (
+ get_users_who_share_room_with_user
+ )
+
+ # whenever send_transaction is called, record the edu data
+ self.edus = []
+ self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.record_transaction
+ )
+
+ def record_transaction(self, txn, json_cb):
+ data = json_cb()
+ self.edus.extend(data["edus"])
+ return defer.succeed({})
+
+ def test_send_device_updates(self):
+ """Basic case: each device update should result in an EDU"""
+ # create a device
+ u1 = self.register_user("user", "pass")
+ self.login(u1, "pass", device_id="D1")
+
+ # expect one edu
+ self.assertEqual(len(self.edus), 1)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # a second call should produce no new device EDUs
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+ self.assertEqual(self.edus, [])
+
+ # a second device
+ self.login("user", "pass", device_id="D2")
+
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ def test_upload_signatures(self):
+ """Uploading signatures on some devices should produce updates for that user"""
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # register two devices
+ u1 = self.register_user("user", "pass")
+ self.login(u1, "pass", device_id="D1")
+ self.login(u1, "pass", device_id="D2")
+
+ # expect two edus
+ self.assertEqual(len(self.edus), 2)
+ stream_id = None
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ # upload signing keys for each device
+ device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
+ device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
+
+ # expect two more edus
+ self.assertEqual(len(self.edus), 2)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+
+ # upload master key and self-signing key
+ master_signing_key = generate_self_id_key()
+ master_key = {
+ "user_id": u1,
+ "usage": ["master"],
+ "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)},
+ }
+
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ selfsigning_signing_key = generate_self_id_key()
+ selfsigning_key = {
+ "user_id": u1,
+ "usage": ["self_signing"],
+ "keys": {
+ key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key)
+ },
+ }
+ sign.sign_json(selfsigning_key, u1, master_signing_key)
+
+ cross_signing_keys = {
+ "master_key": master_key,
+ "self_signing_key": selfsigning_key,
+ }
+
+ self.get_success(
+ e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
+ )
+
+ # expect signing key update edu
+ self.assertEqual(len(self.edus), 1)
+ self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
+
+ # sign the devices
+ d1_json = build_device_dict(u1, "D1", device1_signing_key)
+ sign.sign_json(d1_json, u1, selfsigning_signing_key)
+ d2_json = build_device_dict(u1, "D2", device2_signing_key)
+ sign.sign_json(d2_json, u1, selfsigning_signing_key)
+
+ ret = self.get_success(
+ e2e_handler.upload_signatures_for_device_keys(
+ u1, {u1: {"D1": d1_json, "D2": d2_json}},
+ )
+ )
+ self.assertEqual(ret["failures"], {})
+
+ # expect two edus, in one or two transactions. We don't know what order the
+ # devices will be updated.
+ self.assertEqual(len(self.edus), 2)
+ stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ if stream_id is not None:
+ self.assertEqual(c["prev_id"], [stream_id])
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2"}, devices)
+
+ def test_delete_devices(self):
+ """If devices are deleted, that should result in EDUs too"""
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # expect three edus
+ self.assertEqual(len(self.edus), 3)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
+ stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ # expect three edus, in an unknown order
+ self.assertEqual(len(self.edus), 3)
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ self.assertGreaterEqual(
+ c.items(),
+ {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(),
+ )
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2", "D3"}, devices)
+
+ def test_unreachable_server(self):
+ """If the destination server is unreachable, all the updates should get sent on
+ recovery
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # for each device, there should be a single update
+ self.assertEqual(len(self.edus), 3)
+ stream_id = None
+ for edu in self.edus:
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+ self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
+ if stream_id is not None:
+ self.assertGreaterEqual(c["stream_id"], stream_id)
+ stream_id = c["stream_id"]
+ devices = {edu["content"]["device_id"] for edu in self.edus}
+ self.assertEqual({"D1", "D2", "D3"}, devices)
+
+ def test_prune_outbound_device_pokes1(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server has never been reachable.
+ """
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ # create devices
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 4)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # there should be a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def test_prune_outbound_device_pokes2(self):
+ """If a destination is unreachable, and the updates are pruned, we should get
+ a single update.
+
+ This case tests the behaviour when the server was reachable, but then goes
+ offline.
+ """
+
+ # create first device
+ u1 = self.register_user("user", "pass")
+ self.login("user", "pass", device_id="D1")
+
+ # expect the update EDU
+ self.assertEqual(len(self.edus), 1)
+ self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
+
+ # now the server goes offline
+ mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
+
+ self.login("user", "pass", device_id="D2")
+ self.login("user", "pass", device_id="D3")
+
+ # delete them again
+ self.get_success(
+ self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
+ )
+
+ self.assertGreaterEqual(mock_send_txn.call_count, 3)
+
+ # run the prune job
+ self.reactor.advance(10)
+ self.get_success(
+ self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+ )
+
+ # recover the server
+ mock_send_txn.side_effect = self.record_transaction
+ self.hs.get_federation_sender().send_device_messages("host2")
+ self.pump()
+
+ # ... and we should get a single update for this user.
+ self.assertEqual(len(self.edus), 1)
+ edu = self.edus.pop(0)
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ c = edu["content"]
+
+ # synapse uses an empty prev_id list to indicate "needs a full resync".
+ self.assertEqual(c["prev_id"], [])
+
+ def check_device_update_edu(
+ self,
+ edu: JsonDict,
+ user_id: str,
+ device_id: str,
+ prev_stream_id: Optional[int],
+ ) -> int:
+ """Check that the given EDU is an update for the given device
+ Returns the stream_id.
+ """
+ self.assertEqual(edu["edu_type"], "m.device_list_update")
+ content = edu["content"]
+
+ expected = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_stream_id] if prev_stream_id is not None else [],
+ }
+
+ self.assertLessEqual(expected.items(), content.items())
+ if prev_stream_id is not None:
+ self.assertGreaterEqual(content["stream_id"], prev_stream_id)
+ return content["stream_id"]
+
+ def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
+ """Check that the txn has an EDU with a signing key update.
+ """
+ edus = txn["edus"]
+ self.assertEqual(len(edus), 1)
+
+ def generate_and_upload_device_signing_key(
+ self, user_id: str, device_id: str
+ ) -> SigningKey:
+ """Generate a signing keypair for the given device, and upload it"""
+ sk = key.generate_signing_key(device_id)
+
+ device_dict = build_device_dict(user_id, device_id, sk)
+
+ self.get_success(
+ self.hs.get_e2e_keys_handler().upload_keys_for_user(
+ user_id, device_id, {"device_keys": device_dict},
+ )
+ )
+ return sk
+
+
+def generate_self_id_key() -> SigningKey:
+ """generate a signing key whose version is its public key
+
+ ... as used by the cross-signing-keys.
+ """
+ k = key.generate_signing_key("x")
+ k.version = encode_pubkey(k)
+ return k
+
+
+def key_id(k: BaseKey) -> str:
+ return "%s:%s" % (k.alg, k.version)
+
+
+def encode_pubkey(sk: SigningKey) -> str:
+ """Encode the public key corresponding to the given signing key as base64"""
+ return key.encode_verify_key_base64(key.get_verify_key(sk))
+
+
+def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
+ """Build a dict representing the given device"""
+ return {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "keys": {
+ "curve25519:" + device_id: "curve25519+key",
+ key_id(sk): encode_pubkey(sk),
+ },
+ }
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b08be451aa..296dc887be 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,8 +15,10 @@
# limitations under the License.
import logging
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from tests import unittest
@@ -41,8 +44,68 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+class StateQueryTests(unittest.FederatingHomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def test_without_event_id(self):
+ """
+ Querying v1/state/<room_id> without an event ID will return the current
+ known state.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.inject_room_member(room_1, "@user:other.example.com", "join")
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ self.assertEqual(
+ channel.json_body["room_version"],
+ self.hs.config.default_room_version.identifier,
+ )
+
+ members = set(
+ map(
+ lambda x: x["state_key"],
+ filter(
+ lambda x: x["type"] == "m.room.member", channel.json_body["pdus"]
+ ),
+ )
+ )
+
+ self.assertEqual(members, {"@user:other.example.com", u1})
+ self.assertEqual(len(channel.json_body["pdus"]), 6)
+
+ def test_needs_to_be_in_room(self):
+ """
+ Querying v1/state/<room_id> requires the server
+ be in the room to provide data.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+
+
def _create_acl_event(content):
- return FrozenEvent(
+ return make_event_from_dict(
{
"room_id": "!a:b",
"event_id": "$a:b",
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
new file mode 100644
index 0000000000..27d83bb7d9
--- /dev/null
+++ b/tests/federation/transport/test_server.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return defer.succeed("otherserver.nottld")
+
+ ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
+ server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ @override_config({"allow_public_rooms_over_federation": False})
+ def test_blocked_public_room_list_over_federation(self):
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/publicRooms"
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code)
+
+ @override_config({"allow_public_rooms_over_federation": True})
+ def test_open_public_room_list_over_federation(self):
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/publicRooms"
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b03103d96f..c01b04e1dc 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
+
# MAU tests
- self.hs.config.max_mau_value = 50
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth()._auth_blocking
+ self.auth_blocking._max_mau_value = 50
+
self.small_number_of_users = 1
self.large_number_of_users = 100
@@ -82,16 +87,16 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@defer.inlineCallbacks
@@ -99,8 +104,10 @@ class AuthTestCase(unittest.TestCase):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
self.assertEqual("a_user", user_id)
@@ -109,99 +116,121 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
- self.hs.config.limit_usage_by_mau = False
+ self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_parity(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
def _get_macaroon(self):
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index a3aa0a1cf2..62b47f6574 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -160,6 +160,24 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
+ def test_update_device_too_long_display_name(self):
+ """Update a device with a display name that is invalid (too long)."""
+ self._record_users()
+
+ # Request to update a device display name with a new value that is longer than allowed.
+ update = {
+ "display_name": "a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+ }
+ self.get_failure(
+ self.handler.update_device(user1, "abc", update),
+ synapse.api.errors.SynapseError,
+ )
+
+ # Ensure the display name was not updated.
+ res = self.get_success(self.handler.get_device(user1, "abc"))
+ self.assertEqual(res["display_name"], "display 2")
+
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
res = self.handler.update_device("user_id", "unknown_device_id", update)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 91c7a17070..00bb776271 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -18,25 +18,20 @@ from mock import Mock
from twisted.internet import defer
+import synapse
+import synapse.api.errors
+from synapse.api.constants import EventTypes
from synapse.config.room_directory import RoomDirectoryConfig
-from synapse.handlers.directory import DirectoryHandler
-from synapse.rest.client.v1 import directory, room
-from synapse.types import RoomAlias
+from synapse.rest.client.v1 import directory, login, room
+from synapse.types import RoomAlias, create_requester
from tests import unittest
-from tests.utils import setup_test_homeserver
-class DirectoryHandlers(object):
- def __init__(self, hs):
- self.directory_handler = DirectoryHandler(hs)
-
-
-class DirectoryTestCase(unittest.TestCase):
+class DirectoryTestCase(unittest.HomeserverTestCase):
""" Tests the directory service. """
- @defer.inlineCallbacks
- def setUp(self):
+ def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
self.mock_registry = Mock()
@@ -47,14 +42,12 @@ class DirectoryTestCase(unittest.TestCase):
self.mock_registry.register_query_handler = register_query_handler
- hs = yield setup_test_homeserver(
- self.addCleanup,
+ hs = self.setup_test_homeserver(
http_client=None,
resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
- hs.handlers = DirectoryHandlers(hs)
self.handler = hs.get_handlers().directory_handler
@@ -64,23 +57,25 @@ class DirectoryTestCase(unittest.TestCase):
self.your_room = RoomAlias.from_string("#your-room:test")
self.remote_room = RoomAlias.from_string("#another:remote")
- @defer.inlineCallbacks
+ return hs
+
def test_get_local_association(self):
- yield self.store.create_room_alias_association(
- self.my_room, "!8765qwer:test", ["test"]
+ self.get_success(
+ self.store.create_room_alias_association(
+ self.my_room, "!8765qwer:test", ["test"]
+ )
)
- result = yield self.handler.get_association(self.my_room)
+ result = self.get_success(self.handler.get_association(self.my_room))
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
- @defer.inlineCallbacks
def test_get_remote_association(self):
self.mock_federation.make_query.return_value = defer.succeed(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
- result = yield self.handler.get_association(self.remote_room)
+ result = self.get_success(self.handler.get_association(self.remote_room))
self.assertEquals(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
@@ -93,19 +88,303 @@ class DirectoryTestCase(unittest.TestCase):
ignore_backoff=True,
)
- @defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_room_alias_association(
- self.your_room, "!8765asdf:test", ["test"]
+ self.get_success(
+ self.store.create_room_alias_association(
+ self.your_room, "!8765asdf:test", ["test"]
+ )
)
- response = yield self.query_handlers["directory"](
- {"room_alias": "#your-room:test"}
+ response = self.get_success(
+ self.handler.on_directory_query({"room_alias": "#your-room:test"})
)
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+class TestCreateAlias(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_handlers().directory_handler
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = RoomAlias.from_string(self.test_alias)
+
+ # Create a test user.
+ self.test_user = self.register_user("user", "pass", admin=False)
+ self.test_user_tok = self.login("user", "pass")
+ self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
+
+ def test_create_alias_joined_room(self):
+ """A user can create an alias for a room they're in."""
+ self.get_success(
+ self.handler.create_association(
+ create_requester(self.test_user), self.room_alias, self.room_id,
+ )
+ )
+
+ def test_create_alias_other_room(self):
+ """A user cannot create an alias for a room they're NOT in."""
+ other_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.get_failure(
+ self.handler.create_association(
+ create_requester(self.test_user), self.room_alias, other_room_id,
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_create_alias_admin(self):
+ """An admin can create an alias for a room they're NOT in."""
+ other_room_id = self.helper.create_room_as(
+ self.test_user, tok=self.test_user_tok
+ )
+
+ self.get_success(
+ self.handler.create_association(
+ create_requester(self.admin_user), self.room_alias, other_room_id,
+ )
+ )
+
+
+class TestDeleteAlias(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_handlers().directory_handler
+ self.state_handler = hs.get_state_handler()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = RoomAlias.from_string(self.test_alias)
+
+ # Create a test user.
+ self.test_user = self.register_user("user", "pass", admin=False)
+ self.test_user_tok = self.login("user", "pass")
+ self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
+
+ def _create_alias(self, user):
+ # Create a new alias to this room.
+ self.get_success(
+ self.store.create_room_alias_association(
+ self.room_alias, self.room_id, ["test"], user
+ )
+ )
+
+ def test_delete_alias_not_allowed(self):
+ """A user that doesn't meet the expected guidelines cannot delete an alias."""
+ self._create_alias(self.admin_user)
+ self.get_failure(
+ self.handler.delete_association(
+ create_requester(self.test_user), self.room_alias
+ ),
+ synapse.api.errors.AuthError,
+ )
+
+ def test_delete_alias_creator(self):
+ """An alias creator can delete their own alias."""
+ # Create an alias from a different user.
+ self._create_alias(self.test_user)
+
+ # Delete the user's alias.
+ result = self.get_success(
+ self.handler.delete_association(
+ create_requester(self.test_user), self.room_alias
+ )
+ )
+ self.assertEquals(self.room_id, result)
+
+ # Confirm the alias is gone.
+ self.get_failure(
+ self.handler.get_association(self.room_alias),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_delete_alias_admin(self):
+ """A server admin can delete an alias created by another user."""
+ # Create an alias from a different user.
+ self._create_alias(self.test_user)
+
+ # Delete the user's alias as the admin.
+ result = self.get_success(
+ self.handler.delete_association(
+ create_requester(self.admin_user), self.room_alias
+ )
+ )
+ self.assertEquals(self.room_id, result)
+
+ # Confirm the alias is gone.
+ self.get_failure(
+ self.handler.get_association(self.room_alias),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_delete_alias_sufficient_power(self):
+ """A user with a sufficient power level should be able to delete an alias."""
+ self._create_alias(self.admin_user)
+
+ # Increase the user's power level.
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ {"users": {self.test_user: 100}},
+ tok=self.admin_user_tok,
+ )
+
+ # They can now delete the alias.
+ result = self.get_success(
+ self.handler.delete_association(
+ create_requester(self.test_user), self.room_alias
+ )
+ )
+ self.assertEquals(self.room_id, result)
+
+ # Confirm the alias is gone.
+ self.get_failure(
+ self.handler.get_association(self.room_alias),
+ synapse.api.errors.SynapseError,
+ )
+
+
+class CanonicalAliasTestCase(unittest.HomeserverTestCase):
+ """Test modifications of the canonical alias when delete aliases.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_handlers().directory_handler
+ self.state_handler = hs.get_state_handler()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = self._add_alias(self.test_alias)
+
+ def _add_alias(self, alias: str) -> RoomAlias:
+ """Add an alias to the test room."""
+ room_alias = RoomAlias.from_string(alias)
+
+ # Create a new alias to this room.
+ self.get_success(
+ self.store.create_room_alias_association(
+ room_alias, self.room_id, ["test"], self.admin_user
+ )
+ )
+ return room_alias
+
+ def _set_canonical_alias(self, content):
+ """Configure the canonical alias state on the room."""
+ self.helper.send_state(
+ self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
+ )
+
+ def _get_canonical_alias(self):
+ """Get the canonical alias state of the room."""
+ return self.get_success(
+ self.state_handler.get_current_state(
+ self.room_id, EventTypes.CanonicalAlias, ""
+ )
+ )
+
+ def test_remove_alias(self):
+ """Removing an alias that is the canonical alias should remove it there too."""
+ # Set this new alias as the canonical alias for this room
+ self._set_canonical_alias(
+ {"alias": self.test_alias, "alt_aliases": [self.test_alias]}
+ )
+
+ data = self._get_canonical_alias()
+ self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+
+ # Finally, delete the alias.
+ self.get_success(
+ self.handler.delete_association(
+ create_requester(self.admin_user), self.room_alias
+ )
+ )
+
+ data = self._get_canonical_alias()
+ self.assertNotIn("alias", data["content"])
+ self.assertNotIn("alt_aliases", data["content"])
+
+ def test_remove_other_alias(self):
+ """Removing an alias listed as in alt_aliases should remove it there too."""
+ # Create a second alias.
+ other_test_alias = "#test2:test"
+ other_room_alias = self._add_alias(other_test_alias)
+
+ # Set the alias as the canonical alias for this room.
+ self._set_canonical_alias(
+ {
+ "alias": self.test_alias,
+ "alt_aliases": [self.test_alias, other_test_alias],
+ }
+ )
+
+ data = self._get_canonical_alias()
+ self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(
+ data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
+ )
+
+ # Delete the second alias.
+ self.get_success(
+ self.handler.delete_association(
+ create_requester(self.admin_user), other_room_alias
+ )
+ )
+
+ data = self._get_canonical_alias()
+ self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+
+
class TestCreateAliasACL(unittest.HomeserverTestCase):
user_id = "@test:test"
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8dccc6826e..e1e144b2e7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,9 +17,11 @@
import mock
+import signedjson.key as key
+import signedjson.sign as sign
+
from twisted.internet import defer
-import synapse.api.errors
import synapse.handlers.e2e_keys
import synapse.storage
from synapse.api import errors
@@ -145,3 +149,357 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
+
+ @defer.inlineCallbacks
+ def test_replace_master_key(self):
+ """uploading a new signing key should make the old signing key unavailable"""
+ local_user = "@boris:" + self.hs.hostname
+ keys1 = {
+ "master_key": {
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ },
+ }
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+
+ keys2 = {
+ "master_key": {
+ # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw": "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw"
+ },
+ }
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+
+ devices = yield self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user
+ )
+ self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
+
+ @defer.inlineCallbacks
+ def test_reupload_signatures(self):
+ """re-uploading a signature should not fail"""
+ local_user = "@boris:" + self.hs.hostname
+ keys1 = {
+ "master_key": {
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ": "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ"
+ },
+ },
+ "self_signing_key": {
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ "user_id": local_user,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ },
+ },
+ }
+ master_signing_key = key.decode_signing_key_base64(
+ "ed25519",
+ "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ",
+ "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8",
+ )
+ sign.sign_json(keys1["self_signing_key"], local_user, master_signing_key)
+ signing_key = key.decode_signing_key_base64(
+ "ed25519",
+ "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
+ )
+ yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+
+ # upload two device keys, which will be signed later by the self-signing key
+ device_key_1 = {
+ "user_id": local_user,
+ "device_id": "abc",
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "keys": {
+ "ed25519:abc": "base64+ed25519+key",
+ "curve25519:abc": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+ }
+ device_key_2 = {
+ "user_id": local_user,
+ "device_id": "def",
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "keys": {
+ "ed25519:def": "base64+ed25519+key",
+ "curve25519:def": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+ }
+
+ yield self.handler.upload_keys_for_user(
+ local_user, "abc", {"device_keys": device_key_1}
+ )
+ yield self.handler.upload_keys_for_user(
+ local_user, "def", {"device_keys": device_key_2}
+ )
+
+ # sign the first device key and upload it
+ del device_key_1["signatures"]
+ sign.sign_json(device_key_1, local_user, signing_key)
+ yield self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1}}
+ )
+
+ # sign the second device key and upload both device keys. The server
+ # should ignore the first device key since it already has a valid
+ # signature for it
+ del device_key_2["signatures"]
+ sign.sign_json(device_key_2, local_user, signing_key)
+ yield self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ )
+
+ device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
+ device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
+ devices = yield self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user
+ )
+ del devices["device_keys"][local_user]["abc"]["unsigned"]
+ del devices["device_keys"][local_user]["def"]["unsigned"]
+ self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
+ self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
+
+ @defer.inlineCallbacks
+ def test_self_signing_key_doesnt_show_up_as_device(self):
+ """signing keys should be hidden when fetching a user's devices"""
+ local_user = "@boris:" + self.hs.hostname
+ keys1 = {
+ "master_key": {
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ },
+ }
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+
+ res = None
+ try:
+ yield self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 400)
+
+ res = yield self.handler.query_local_devices({local_user: None})
+ self.assertDictEqual(res, {local_user: {}})
+
+ @defer.inlineCallbacks
+ def test_upload_signatures(self):
+ """should check signatures that are uploaded"""
+ # set up a user with cross-signing keys and a device. This user will
+ # try uploading signatures
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
+ device_key = {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+ "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
+ "signatures": {local_user: {"ed25519:xyz": "something"}},
+ }
+ device_signing_key = key.decode_signing_key_base64(
+ "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
+ )
+
+ yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"device_keys": device_key}
+ )
+
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ master_key = {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ }
+ master_signing_key = key.decode_signing_key_base64(
+ "ed25519", master_pubkey, "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0"
+ )
+ usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw"
+ usersigning_key = {
+ # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs
+ "user_id": local_user,
+ "usage": ["user_signing"],
+ "keys": {"ed25519:" + usersigning_pubkey: usersigning_pubkey},
+ }
+ usersigning_signing_key = key.decode_signing_key_base64(
+ "ed25519", usersigning_pubkey, "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs"
+ )
+ sign.sign_json(usersigning_key, local_user, master_signing_key)
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ selfsigning_pubkey = "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ"
+ selfsigning_key = {
+ "user_id": local_user,
+ "usage": ["self_signing"],
+ "keys": {"ed25519:" + selfsigning_pubkey: selfsigning_pubkey},
+ }
+ selfsigning_signing_key = key.decode_signing_key_base64(
+ "ed25519", selfsigning_pubkey, "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8"
+ )
+ sign.sign_json(selfsigning_key, local_user, master_signing_key)
+ cross_signing_keys = {
+ "master_key": master_key,
+ "user_signing_key": usersigning_key,
+ "self_signing_key": selfsigning_key,
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+
+ # set up another user with a master key. This user will be signed by
+ # the first user
+ other_user = "@otherboris:" + self.hs.hostname
+ other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
+ other_master_key = {
+ # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
+ }
+ yield self.handler.upload_signing_keys_for_user(
+ other_user, {"master_key": other_master_key}
+ )
+
+ # test various signature failures (see below)
+ ret = yield self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ device_id: {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ "m.megolm.v1.aes-sha2",
+ ],
+ "keys": {
+ "curve25519:xyz": "curve25519+key",
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ "ed25519:xyz": device_pubkey,
+ },
+ "signatures": {
+ local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ },
+ },
+ # fails because device is unknown
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": local_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ },
+ },
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ master_pubkey: {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ "signatures": {
+ local_user: {"ed25519:" + device_pubkey: "something"}
+ },
+ },
+ },
+ other_user: {
+ # fails because the device is not the user's master-signing key
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": other_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ },
+ },
+ other_master_pubkey: {
+ # fails because the key doesn't match what the server has
+ # should fail with UNKNOWN
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
+ "something": "random",
+ "signatures": {
+ local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ },
+ },
+ },
+ },
+ )
+
+ user_failures = ret["failures"][local_user]
+ self.assertEqual(
+ user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+ )
+ self.assertEqual(
+ user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
+ )
+ self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+
+ other_user_failures = ret["failures"][other_user]
+ self.assertEqual(
+ other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
+ )
+ self.assertEqual(
+ other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+ )
+
+ # test successful signatures
+ del device_key["signatures"]
+ sign.sign_json(device_key, local_user, selfsigning_signing_key)
+ sign.sign_json(master_key, local_user, device_signing_key)
+ sign.sign_json(other_master_key, local_user, usersigning_signing_key)
+ ret = yield self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {device_id: device_key, master_pubkey: master_key},
+ other_user: {other_master_pubkey: other_master_key},
+ },
+ )
+
+ self.assertEqual(ret["failures"], {})
+
+ # fetch the signed keys/devices and make sure that the signatures are there
+ ret = yield self.handler.query_devices(
+ {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ )
+
+ self.assertEqual(
+ ret["device_keys"][local_user]["xyz"]["signatures"][local_user][
+ "ed25519:" + selfsigning_pubkey
+ ],
+ device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey],
+ )
+ self.assertEqual(
+ ret["master_keys"][local_user]["signatures"][local_user][
+ "ed25519:" + device_id
+ ],
+ master_key["signatures"][local_user]["ed25519:" + device_id],
+ )
+ self.assertEqual(
+ ret["master_keys"][other_user]["signatures"][local_user][
+ "ed25519:" + usersigning_pubkey
+ ],
+ other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
+ )
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index c4503c1611..70f172eb02 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ version_etag = res["etag"]
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
# check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1")
+ self.assertEqual(res["etag"], version_etag)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
@@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "2",
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
+ "count": 0,
},
)
@@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
+ "count": 0,
},
)
@@ -187,9 +198,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, 404)
@defer.inlineCallbacks
- def test_update_bad_version(self):
- """Check that we get a 400 if the version in the body is missing or
- doesn't match
+ def test_update_omitted_version(self):
+ """Check that the update succeeds if the version is missing from the body
"""
version = yield self.handler.create_version(
self.local_user,
@@ -197,19 +207,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 400)
+ yield self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"] # etag is opaque, so don't test its contents
+ self.assertDictEqual(
+ res,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ "count": 0,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def test_update_bad_version(self):
+ """Check that we get a 400 if the version in the body doesn't match
+ """
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
+ self.assertEqual(version, "1")
res = None
try:
@@ -394,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ # get the etag to compare to future versions
+ res = yield self.handler.get_version_info(self.local_user)
+ backup_etag = res["etag"]
+ self.assertEqual(res["count"], 1)
+
new_room_keys = copy.deepcopy(room_keys)
new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]
@@ -408,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"SSBBTSBBIEZJU0gK",
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
@@ -417,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should NOT be equal now, since the key changed
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertNotEqual(res["etag"], backup_etag)
+ backup_etag = res["etag"]
+
# test that a session with a higher forwarded_count doesn't replace one
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
@@ -428,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# TODO: check edge cases as well as the common variations here
@defer.inlineCallbacks
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
new file mode 100644
index 0000000000..96fea58673
--- /dev/null
+++ b/tests/handlers/test_federation.py
@@ -0,0 +1,274 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from unittest import TestCase
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+class FederationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(http_client=None)
+ self.handler = hs.get_handlers().federation_handler
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_exchange_revoked_invite(self):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # Send a 3PID invite event with an empty body so it's considered as a revoked one.
+ invite_token = "sometoken"
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=tok,
+ )
+
+ d = self.handler.on_exchange_third_party_invite_request(
+ room_id=room_id,
+ event_dict={
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "sender": user_id,
+ "state_key": "@someone:example.org",
+ "content": {
+ "membership": "invite",
+ "third_party_invite": {
+ "display_name": "alice",
+ "signed": {
+ "mxid": "@alice:localhost",
+ "token": invite_token,
+ "signatures": {
+ "magic.forest": {
+ "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
+ }
+ },
+ },
+ },
+ },
+ },
+ )
+
+ failure = self.get_failure(d, AuthError).value
+
+ self.assertEqual(failure.code, 403, failure)
+ self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
+ self.assertEqual(failure.msg, "You are not invited to this room.")
+
+ def test_rejected_message_event_state(self):
+ """
+ Check that we store the state group correctly for rejected non-state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def test_rejected_state_event_state(self):
+ """
+ Check that we store the state group correctly for rejected state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": "org.matrix.test",
+ "state_key": "test_key",
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def _build_and_send_join_event(self, other_server, other_user, room_id):
+ join_event = self.get_success(
+ self.handler.on_make_join_request(other_server, room_id, other_user)
+ )
+ # the auth code requires that a signature exists, but doesn't check that
+ # signature... go figure.
+ join_event.signatures[other_server] = {"x": "y"}
+ with LoggingContext(request="send_join"):
+ d = run_in_background(
+ self.handler.on_send_join_request, other_server, join_event
+ )
+ self.get_success(d)
+
+ # sanity-check: the room should show that the new user is a member
+ r = self.get_success(self.store.get_current_state_ids(room_id))
+ self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+ return join_event
+
+
+class EventFromPduTestCase(TestCase):
+ def test_valid_json(self):
+ """Valid JSON should be turned into an event."""
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"bool": True, "null": None, "int": 1, "str": "foobar"},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
+
+ self.assertIsInstance(ev, EventBase)
+
+ def test_invalid_numbers(self):
+ """Invalid values for an integer should be rejected, all floats should be rejected."""
+ for value in [
+ -(2 ** 53),
+ 2 ** 53,
+ 1.0,
+ float("inf"),
+ float("-inf"),
+ float("nan"),
+ ]:
+ with self.assertRaises(SynapseError):
+ event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"foo": value},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
+
+ def test_invalid_nested(self):
+ """List and dictionaries are recursively searched."""
+ with self.assertRaises(SynapseError):
+ event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"foo": [{"bar": 2 ** 56}]},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
new file mode 100644
index 0000000000..1bb25ab684
--- /dev/null
+++ b/tests/handlers/test_oidc.py
@@ -0,0 +1,570 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from urllib.parse import parse_qs, urlparse
+
+from mock import Mock, patch
+
+import attr
+import pymacaroons
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+
+from synapse.handlers.oidc_handler import (
+ MappingException,
+ OidcError,
+ OidcHandler,
+ OidcMappingProvider,
+)
+from synapse.types import UserID
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+@attr.s
+class FakeResponse:
+ code = attr.ib()
+ body = attr.ib()
+ phrase = attr.ib()
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
+
+
+# These are a few constants that are used as config parameters in the tests.
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+
+# config for common cases
+COMMON_CONFIG = {
+ "discover": False,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+}
+
+
+# The cookie name and path don't really matter, just that it has to be coherent
+# between the callback & redirect handlers.
+COOKIE_NAME = b"oidc_session"
+COOKIE_PATH = "/_synapse/oidc"
+
+MockedMappingProvider = Mock(OidcMappingProvider)
+
+
+def simple_async_mock(return_value=None, raises=None):
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+async def get_json(url):
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+
+class OidcHandlerTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ config = self.default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = config.get("oidc_config", {})
+ oidc_config["enabled"] = True
+ oidc_config["client_id"] = CLIENT_ID
+ oidc_config["client_secret"] = CLIENT_SECRET
+ oidc_config["issuer"] = ISSUER
+ oidc_config["scopes"] = SCOPES
+ oidc_config["user_mapping_provider"] = {
+ "module": __name__ + ".MockedMappingProvider"
+ }
+ config["oidc_config"] = oidc_config
+
+ hs = self.setup_test_homeserver(
+ http_client=self.http_client,
+ proxied_http_client=self.http_client,
+ config=config,
+ )
+
+ self.handler = OidcHandler(hs)
+
+ return hs
+
+ def metadata_edit(self, values):
+ return patch.dict(self.handler._provider_metadata, values)
+
+ def assertRenderedError(self, error, error_description=None):
+ args = self.handler._render_error.call_args[0]
+ self.assertEqual(args[1], error)
+ if error_description is not None:
+ self.assertEqual(args[2], error_description)
+ # Reset the render_error mock
+ self.handler._render_error.reset_mock()
+
+ def test_config(self):
+ """Basic config correctly sets up the callback URL and client auth correctly."""
+ self.assertEqual(self.handler._callback_url, CALLBACK_URL)
+ self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+
+ @override_config({"oidc_config": {"discover": True}})
+ @defer.inlineCallbacks
+ def test_discovery(self):
+ """The handler should discover the endpoints from OIDC discovery document."""
+ # This would throw if some metadata were invalid
+ metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+
+ self.assertEqual(metadata.issuer, ISSUER)
+ self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
+ self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
+ self.assertEqual(metadata.jwks_uri, JWKS_URI)
+ # FIXME: it seems like authlib does not have that defined in its metadata models
+ # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+
+ # subsequent calls should be cached
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_no_discovery(self):
+ """When discovery is disabled, it should not try to load from discovery document."""
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_load_jwks(self):
+ """JWKS loading is done once (then cached) if used."""
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.assertEqual(jwks, {"keys": []})
+
+ # subsequent calls should be cached…
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_not_called()
+
+ # …unless forced
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+
+ # Throw if the JWKS uri is missing
+ with self.metadata_edit({"jwks_uri": None}):
+ with self.assertRaises(RuntimeError):
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+
+ # Return empty key set if JWKS are not used
+ self.handler._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ def test_validate_config(self):
+ """Provider metadatas are extensively validated."""
+ h = self.handler
+
+ # Default test config does not throw
+ h._validate_metadata()
+
+ with self.metadata_edit({"issuer": None}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "http://insecure/"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"authorization_endpoint": None}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"token_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": None}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"response_types_supported": ["id_token"]}):
+ self.assertRaisesRegex(
+ ValueError, "response_types_supported", h._validate_metadata
+ )
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
+ ):
+ # should not throw, as client_secret_basic is the default auth method
+ h._validate_metadata()
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
+ ):
+ self.assertRaisesRegex(
+ ValueError,
+ "token_endpoint_auth_methods_supported",
+ h._validate_metadata,
+ )
+
+ # Tests for configs that the userinfo endpoint
+ self.assertFalse(h._uses_userinfo)
+ h._scopes = [] # do not request the openid scope
+ self.assertTrue(h._uses_userinfo)
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+
+ with self.metadata_edit(
+ {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
+ ):
+ # Shouldn't raise with a valid userinfo, even without
+ h._validate_metadata()
+
+ @override_config({"oidc_config": {"skip_verification": True}})
+ def test_skip_verification(self):
+ """Provider metadata validation can be disabled by config."""
+ with self.metadata_edit({"issuer": "http://insecure"}):
+ # This should not throw
+ self.handler._validate_metadata()
+
+ @defer.inlineCallbacks
+ def test_redirect_request(self):
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["addCookie"])
+ url = yield defer.ensureDeferred(
+ self.handler.handle_redirect_request(req, b"http://client/redirect")
+ )
+ url = urlparse(url)
+ auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+
+ self.assertEqual(url.scheme, auth_endpoint.scheme)
+ self.assertEqual(url.netloc, auth_endpoint.netloc)
+ self.assertEqual(url.path, auth_endpoint.path)
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(params["response_type"], ["code"])
+ self.assertEqual(params["scope"], [" ".join(SCOPES)])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(len(params["state"]), 1)
+ self.assertEqual(len(params["nonce"]), 1)
+
+ # Check what is in the cookie
+ # note: python3.5 mock does not have the .called_once() method
+ calls = req.addCookie.call_args_list
+ self.assertEqual(len(calls), 1) # called once
+ # For some reason, call.args does not work with python3.5
+ args = calls[0][0]
+ kwargs = calls[0][1]
+ self.assertEqual(args[0], COOKIE_NAME)
+ self.assertEqual(kwargs["path"], COOKIE_PATH)
+ cookie = args[1]
+
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ state = self.handler._get_value_from_macaroon(macaroon, "state")
+ nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
+ redirect = self.handler._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ self.assertEqual(params["state"], [state])
+ self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(redirect, "http://client/redirect")
+
+ @defer.inlineCallbacks
+ def test_callback_error(self):
+ """Errors from the provider returned in the callback are displayed."""
+ self.handler._render_error = Mock()
+ request = Mock(args={})
+ request.args[b"error"] = [b"invalid_client"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "")
+
+ request.args[b"error_description"] = [b"some description"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "some description")
+
+ @defer.inlineCallbacks
+ def test_callback(self):
+ """Code callback works and display errors if something went wrong.
+
+ A lot of scenarios are tested here:
+ - when the callback works, with userinfo from ID token
+ - when the user mapping fails
+ - when ID token verification fails
+ - when the callback works, with userinfo fetched from the userinfo endpoint
+ - when the userinfo fetching fails
+ - when the code exchange fails
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "preferred_username": "bar",
+ }
+ user_id = UserID("foo", "domain.org")
+ self.handler._render_error = Mock(return_value=None)
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ code = "code"
+ state = "state"
+ nonce = "nonce"
+ client_redirect_url = "http://client/redirect"
+ session = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+ request.getCookie.return_value = session
+
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_not_called()
+ self.handler._render_error.assert_not_called()
+
+ # Handle mapping errors
+ self.handler._map_userinfo_to_user = simple_async_mock(
+ raises=MappingException()
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+
+ # Handle ID token errors
+ self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_token")
+
+ self.handler._auth_handler.complete_sso_login.reset_mock()
+ self.handler._exchange_code.reset_mock()
+ self.handler._parse_id_token.reset_mock()
+ self.handler._map_userinfo_to_user.reset_mock()
+ self.handler._fetch_userinfo.reset_mock()
+
+ # With userinfo fetching
+ self.handler._scopes = [] # do not ask the "openid" scope
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_not_called()
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.handler._render_error.assert_not_called()
+
+ # Handle userinfo fetching error
+ self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("fetch_error")
+
+ # Handle code exchange failure
+ self.handler._exchange_code = simple_async_mock(
+ raises=OidcError("invalid_request")
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @defer.inlineCallbacks
+ def test_callback_session(self):
+ """The callback verifies the session presence and validity"""
+ self.handler._render_error = Mock(return_value=None)
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ # Missing cookie
+ request.args = {}
+ request.getCookie.return_value = None
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("missing_session", "No session cookie found")
+
+ # Missing session parameter
+ request.args = {}
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request", "State parameter is missing")
+
+ # Invalid cookie
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_session")
+
+ # Mismatching session
+ session = self.handler._generate_oidc_session_token(
+ state="state",
+ nonce="nonce",
+ client_redirect_url="http://client/redirect",
+ ui_auth_session_id=None,
+ )
+ request.args = {}
+ request.args[b"state"] = [b"mismatching state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mismatching_session")
+
+ # Valid session
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @defer.inlineCallbacks
+ def test_exchange_code(self):
+ """Code exchange behaves correctly and handles various error scenarios."""
+ token = {"type": "bearer"}
+ token_json = json.dumps(token).encode("utf-8")
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ )
+ code = "code"
+ ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ kwargs = self.http_client.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ # Test error handling
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad Request",
+ body=b'{"error": "foo", "error_description": "bar"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "foo")
+ self.assertEqual(exc.exception.error_description, "bar")
+
+ # Internal server error with no JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # Internal server error with JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b'{"error": "internal_server_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "internal_server_error")
+
+ # 4xx error without "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # 2xx error with "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "some_error")
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index f70c6e7d65..05ea40a7de 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,9 +19,10 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
-from synapse.events import room_version_to_event_format
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
+ EXTERNAL_PROCESS_EXPIRY,
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
IDLE_TIMER,
@@ -337,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, syncing_user_ids=set([user_id]), now=now
+ state, is_mine=True, syncing_user_ids={user_id}, now=now
)
self.assertIsNotNone(new_state)
@@ -413,6 +414,44 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEquals(state, new_state)
+class PresenceHandlerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.presence_handler = hs.get_presence_handler()
+ self.clock = hs.get_clock()
+
+ def test_external_process_timeout(self):
+ """Test that if an external process doesn't update the records for a while
+ we time out their syncing users presence.
+ """
+ process_id = 1
+ user_id = "@test:server"
+
+ # Notify handler that a user is now syncing.
+ self.get_success(
+ self.presence_handler.update_external_syncs_row(
+ process_id, user_id, True, self.clock.time_msec()
+ )
+ )
+
+ # Check that if we wait a while without telling the handler the user has
+ # stopped syncing that their presence state doesn't get timed out.
+ self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2)
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
+ # Check that if the external process timeout fires, then the syncing
+ # user gets timed out
+ self.reactor.advance(EXTERNAL_PROCESS_EXPIRY)
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
+
class PresenceJoinTestCase(unittest.HomeserverTestCase):
"""Tests remote servers get told about presence of users in the room when
they join and when new local users join.
@@ -455,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, "@test2:server")
# Mark test2 as online, test will be offline with a last_active of 0
- self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ )
)
self.reactor.pump([0]) # Wait for presence updates to be handled
@@ -504,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id)
# Mark test as online
- self.presence_handler.set_state(
- UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ )
)
# Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet
- self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ )
)
# Add servers to the room
@@ -540,7 +585,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations=set(("server2", "server3")), states=[expected_state]
+ destinations={"server2", "server3"}, states=[expected_state]
)
def _add_new_user(self, room_id, user_id):
@@ -549,7 +594,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
hostname = get_domain_from_id(user_id)
- room_version = self.get_success(self.store.get_room_version(room_id))
+ room_version = self.get_success(self.store.get_room_version_id(room_id))
builder = EventBuilder(
state=self.state,
@@ -558,7 +603,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
clock=self.clock,
hostname=hostname,
signing_key=self.random_signing_key,
- format_version=room_version_to_event_format(room_version),
+ room_version=KNOWN_ROOM_VERSIONS[room_version],
room_id=room_id,
type=EventTypes.Member,
sender=user_id,
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..29dd7d9c6e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from mock import Mock, NonCallableMock
+from mock import Mock
from twisted.internet import defer
import synapse.types
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
@@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
-
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test")
@@ -70,6 +66,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_profile_handler()
+ self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self):
@@ -81,19 +78,58 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_name(self):
- yield self.handler.set_displayname(
- self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ yield defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ )
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
"Frank Jr.",
)
+ # Set displayname again
+ yield defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank"
+ )
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ )
+
+ @defer.inlineCallbacks
+ def test_set_my_name_if_disabled(self):
+ self.hs.config.enable_set_displayname = False
+
+ # Setting displayname for the first time is allowed
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ )
+
+ # Setting displayname a second time is forbidden
+ d = defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
+ )
+ )
+
+ yield self.assertFailure(d, SynapseError)
+
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
- d = self.handler.set_displayname(
- self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+ d = defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
+ )
)
yield self.assertFailure(d, AuthError)
@@ -137,13 +173,54 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
- yield self.handler.set_avatar_url(
- self.frank,
- synapse.types.create_requester(self.frank),
- "http://my.server/pic.gif",
+ yield defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif",
+ )
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
"http://my.server/pic.gif",
)
+
+ # Set avatar again
+ yield defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/me.png",
+ )
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/me.png",
+ )
+
+ @defer.inlineCallbacks
+ def test_set_my_avatar_if_disabled(self):
+ self.hs.config.enable_set_avatar_url = False
+
+ # Setting displayname for the first time is allowed
+ yield self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
+
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/me.png",
+ )
+
+ # Set avatar a second time is forbidden
+ d = defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank,
+ synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif",
+ )
+ )
+
+ yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1e9ba3a201..ca32f993a3 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
def make_homeserver(self, reactor, clock):
- hs_config = self.default_config("test")
+ hs_config = self.default_config()
# some of the tests rely on us having a user consent version
hs_config["user_consent"] = {
@@ -135,6 +135,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
+ def test_auto_join_rooms_for_guests(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ self.hs.config.auto_join_rooms_for_guests = False
+ user_id = self.get_success(
+ self.handler.register_user(localpart="jeff", make_guest=True),
+ )
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 0)
+
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
@@ -175,7 +185,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_real_user = Mock(return_value=False)
+ self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -187,8 +197,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=1)
- self.store.is_real_user = Mock(return_value=True)
+ self.store.count_real_users = Mock(return_value=defer.succeed(1))
+ self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -202,8 +212,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=2)
- self.store.is_real_user = Mock(return_value=True)
+ self.store.count_real_users = Mock(return_value=defer.succeed(2))
+ self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -256,8 +266,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
- @defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ async def get_or_create_user(
+ self, requester, localpart, displayname, password_hash=None
+ ):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -269,16 +280,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
- Raises:
- RegistrationError if there was a problem registering.
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
- yield self.hs.get_auth().check_auth_blocking()
+ await self.hs.get_auth().check_auth_blocking()
need_register = True
try:
- yield self.handler.check_username(localpart)
+ await self.handler.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
@@ -290,21 +299,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id)
if need_register:
- yield self.handler.register_with_store(
+ await self.handler.register_with_store(
user_id=user_id,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
- yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
- yield self.store.add_access_token_to_user(
+ await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)
- yield self.hs.get_profile_handler().set_displayname(
+ await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True
)
diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py
deleted file mode 100644
index 61eebb6985..0000000000
--- a/tests/handlers/test_roomlist.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.handlers.room_list import RoomListNextBatch
-
-import tests.unittest
-import tests.utils
-
-
-class RoomListTestCase(tests.unittest.TestCase):
- """ Tests RoomList's RoomListNextBatch. """
-
- def setUp(self):
- pass
-
- def test_check_read_batch_tokens(self):
- batch_token = RoomListNextBatch(
- stream_ordering="abcdef",
- public_room_stream_id="123",
- current_limit=20,
- direction_is_forward=True,
- ).to_token()
- next_batch = RoomListNextBatch.from_token(batch_token)
- self.assertEquals(next_batch.stream_ordering, "abcdef")
- self.assertEquals(next_batch.public_room_stream_id, "123")
- self.assertEquals(next_batch.current_limit, 20)
- self.assertEquals(next_batch.direction_is_forward, True)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 7569b6fab5..d9d312f0fb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse import storage
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.storage.data_stores.main import stats
from tests import unittest
@@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -82,21 +82,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
def get_all_room_state(self):
- return self.store._simple_select_list(
+ return self.store.db.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
def _get_current_stats(self, stats_type, stat_id):
- table, id_col = storage.stats.TYPE_TO_TABLE[stats_type]
+ table, id_col = stats.TYPE_TO_TABLE[stats_type]
- cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
- storage.stats.PER_SLICE_FIELDS[stats_type]
+ cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
+ stats.PER_SLICE_FIELDS[stats_type]
)
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
- self.store._simple_select_one(
+ self.store.db.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the stats via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
def test_initial_room(self):
"""
@@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
r = self.get_success(self.get_all_room_state())
@@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_update_one(
+ self.store.db.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Now, before the table is actually ingested, add some more events.
self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
@@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- self.store._all_done = False
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ self.store.db.updates._all_done = False
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
self.reactor.advance(86401)
@@ -607,6 +623,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
"""
self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
u1 = self.register_user("u1", "pass")
u1token = self.login("u1", "pass")
@@ -618,6 +635,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertIsNone(self._get_current_stats("user", u1))
self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
self._perform_background_initial_update()
@@ -651,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_delete(
+ self.store.db.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
- self.store._simple_delete(
+ self.store.db.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@@ -671,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
@@ -683,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
@@ -693,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -703,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
r1stats_complete = self._get_current_stats("room", r1)
u1stats_complete = self._get_current_stats("user", u1)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..e178d7765b 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
import tests.unittest
import tests.utils
-from tests.utils import setup_test_homeserver
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.sync_handler = SyncHandler(self.hs)
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+ self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
- def test_wait_for_sync_for_user_auth_blocking(self):
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth()._auth_blocking
- user_id1 = "@user1:server"
- user_id2 = "@user2:server"
+ def test_wait_for_sync_for_user_auth_blocking(self):
+ user_id1 = "@user1:test"
+ user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 1
+ self.reactor.advance(100) # So we get not 0 time
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 1
# Check that the happy case does not throw errors
- yield self.store.upsert_monthly_active_user(user_id1)
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
- self.hs.config.hs_disabled = True
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ self.auth_blocking._hs_disabled = True
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.hs.config.hs_disabled = False
+ self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1f2ef5d01f..2fa8d4739b 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
+from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with
@@ -63,34 +64,39 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ datastores = Mock()
+ datastores.main = Mock(
+ spec=[
+ # Bits that Federation needs
+ "prep_send_transaction",
+ "delivered_txn",
+ "get_received_txn_response",
+ "set_received_txn_response",
+ "get_destination_retry_timings",
+ "get_devices_by_remote",
+ "maybe_store_room_on_invite",
+ # Bits that user_directory needs
+ "get_user_directory_stream_pos",
+ "get_current_state_deltas",
+ "get_device_updates_by_remote",
+ ]
+ )
+
+ # the tests assume that we are starting at unix time 1000
+ reactor.pump((1000,))
+
hs = self.setup_test_homeserver(
- datastore=(
- Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_devices_by_remote",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- ]
- )
- ),
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
+ replication_streams={},
)
+ hs.datastores = datastores
+
return hs
def prepare(self, reactor, clock, hs):
- # the tests assume that we are starting at unix time 1000
- reactor.pump((1000,))
-
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -109,7 +115,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_devices_by_remote.return_value = (0, [])
+ self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+ (0, [])
+ )
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -118,19 +126,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_joined_room(room_id, user_id):
+ def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
- hs.get_auth().check_joined_room = check_joined_room
+ hs.get_auth().check_user_in_room = check_user_in_room
def get_joined_hosts_for_room(room_id):
- return set(member.domain for member in self.room_members)
+ return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_users_in_room(room_id):
- return set(str(u) for u in self.room_members)
+ return {str(u) for u in self.room_members}
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
@@ -139,11 +147,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
defer.succeed(1)
)
- self.datastore.get_current_state_deltas.return_value = None
+ self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
+ self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+ ([], 0)
+ )
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ None
+ )
def test_started_typing_local(self):
self.room_members = [U_APPLE, U_BANANA]
@@ -159,7 +172,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -171,6 +186,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
@@ -222,7 +238,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -234,6 +252,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_stopped_typing(self):
self.room_members = [U_APPLE, U_BANANA, U_ONION]
@@ -242,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
member = RoomMember(ROOM_ID, U_APPLE.to_string())
self.handler._member_typing_until[member] = 1002000
- self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
+ self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()}
self.assertEquals(self.event_source.get_current_key(), 0)
@@ -273,7 +292,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -294,7 +315,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -311,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 2)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -329,7 +354,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c5e91a8c41..c15bce5bef 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -14,6 +14,8 @@
# limitations under the License.
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.rest.client.v1 import login, room
@@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- self.store.remove_from_user_dir = Mock()
- self.store.remove_from_user_in_public_room = Mock()
+ self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
self.get_success(self.handler.handle_user_deactivated(s_user_id))
self.store.remove_from_user_dir.not_called()
- self.store.remove_from_user_in_public_room.not_called()
def test_handle_user_deactivated_regular_user(self):
r_user_id = "@regular:test"
self.get_success(
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- self.store.remove_from_user_dir = Mock()
+ self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
self.get_success(self.handler.handle_user_deactivated(r_user_id))
self.store.remove_from_user_dir.called_once_with(r_user_id)
@@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
public_users = self.get_users_in_public_rooms()
self.assertEqual(
- self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
)
self.assertEqual(public_users, [])
@@ -147,6 +147,98 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
+ def test_spam_checker(self):
+ """
+ A user which fails to the spam checks will not appear in search results.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ # We do not add users to the directory until they join a room.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Check we have populated the database correctly.
+ shares_private = self.get_users_who_share_private_rooms()
+ public_users = self.get_users_in_public_rooms()
+
+ self.assertEqual(
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ )
+ self.assertEqual(public_users, [])
+
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that does not filter any users.
+ spam_checker = self.hs.get_spam_checker()
+
+ class AllowAll(object):
+ def check_username_for_spam(self, user_profile):
+ # Allow all users.
+ return False
+
+ spam_checker.spam_checkers = [AllowAll()]
+
+ # The results do not change:
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that filters all users.
+ class BlockAll(object):
+ def check_username_for_spam(self, user_profile):
+ # All users are spammy.
+ return True
+
+ spam_checker.spam_checkers = [BlockAll()]
+
+ # User1 now gets no search results for any of the other users.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
+ def test_legacy_spam_checker(self):
+ """
+ A spam checker without the expected method should be ignored.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ # We do not add users to the directory until they join a room.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Check we have populated the database correctly.
+ shares_private = self.get_users_who_share_private_rooms()
+ public_users = self.get_users_in_public_rooms()
+
+ self.assertEqual(
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ )
+ self.assertEqual(public_users, [])
+
+ # Configure a spam checker.
+ spam_checker = self.hs.get_spam_checker()
+ # The spam checker doesn't need any methods, so create a bare object.
+ spam_checker.spam_checker = object()
+
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
def _compress_shared(self, shared):
"""
Compress a list of users who share rooms dicts to a list of tuples.
@@ -158,7 +250,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -169,7 +261,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -181,10 +273,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -193,7 +285,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -203,7 +295,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -213,7 +305,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -255,19 +347,23 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
# User 1 and User 2 are in the same public room
- self.assertEqual(set(public_users), set([(u1, room), (u2, room)]))
+ self.assertEqual(set(public_users), {(u1, room), (u2, room)})
# User 1 and User 3 share private rooms
self.assertEqual(
self._compress_shared(shares_private),
- set([(u1, u3, private_room), (u3, u1, private_room)]),
+ {(u1, u3, private_room), (u3, u1, private_room)},
)
def test_initial_share_all_users(self):
@@ -290,15 +386,19 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
# No users share rooms
self.assertEqual(public_users, [])
- self.assertEqual(self._compress_shared(shares_private), set([]))
+ self.assertEqual(self._compress_shared(shares_private), set())
# Despite not sharing a room, search_all_users means we get a search
# result.
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d7025264..562397cdda 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -31,14 +31,14 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
from synapse.config.homeserver import HomeServerConfig
-from synapse.crypto.context_factory import ClientTLSOptionsFactory
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.well_known_resolver import (
WellKnownResolver,
_cache_period_from_headers,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
@@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self._config = config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
- self.tls_factory = ClientTLSOptionsFactory(config)
+ self.tls_factory = FederationPolicyForHTTPS(config)
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
@@ -150,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
- _check_logcontext(LoggingContext.sentinel)
+ _check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
@@ -710,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
config = default_config("test", parse=True)
# Build a new agent and WellKnownResolver with a different tls factory
- tls_factory = ClientTLSOptionsFactory(config)
+ tls_factory = FederationPolicyForHTTPS(config)
agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=tls_factory,
@@ -1192,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
def _check_logcontext(context):
- current = LoggingContext.current_context()
+ current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index df034ab237..babc201643 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
result = yield resolve_d
# should have restored our context
- self.assertIs(LoggingContext.current_context(), ctx)
+ self.assertIs(current_context(), ctx)
return result
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 2b01f40a42..fff4f0cbf4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
def check_logcontext(context):
- current = LoggingContext.current_context()
+ current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
@@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertNoResult(fetch_d)
# should have reset logcontext to the sentinel
- check_logcontext(LoggingContext.sentinel)
+ check_logcontext(SENTINEL_CONTEXT)
try:
fetch_res = yield fetch_d
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+ client_factory (interfaces.IProtocolFactory): the the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
deleted file mode 100644
index 220884311c..0000000000
--- a/tests/patch_inline_callbacks.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import functools
-import sys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-
-
-def do_patch():
- """
- Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
- """
-
- from synapse.logging.context import LoggingContext
-
- orig_inline_callbacks = defer.inlineCallbacks
-
- def new_inline_callbacks(f):
-
- orig = orig_inline_callbacks(f)
-
- @functools.wraps(f)
- def wrapped(*args, **kwargs):
- start_context = LoggingContext.current_context()
-
- try:
- res = orig(*args, **kwargs)
- except Exception:
- if LoggingContext.current_context() != start_context:
- err = "%s changed context from %s to %s on exception" % (
- f,
- start_context,
- LoggingContext.current_context(),
- )
- print(err, file=sys.stderr)
- raise Exception(err)
- raise
-
- if not isinstance(res, Deferred) or res.called:
- if LoggingContext.current_context() != start_context:
- err = "%s changed context from %s to %s" % (
- f,
- start_context,
- LoggingContext.current_context(),
- )
- # print the error to stderr because otherwise all we
- # see in travis-ci is the 500 error
- print(err, file=sys.stderr)
- raise Exception(err)
- return res
-
- if LoggingContext.current_context() != LoggingContext.sentinel:
- err = (
- "%s returned incomplete deferred in non-sentinel context "
- "%s (start was %s)"
- ) % (f, LoggingContext.current_context(), start_context)
- print(err, file=sys.stderr)
- raise Exception(err)
-
- def check_ctx(r):
- if LoggingContext.current_context() != start_context:
- err = "%s completion of %s changed context from %s to %s" % (
- "Failure" if isinstance(r, Failure) else "Success",
- f,
- start_context,
- LoggingContext.current_context(),
- )
- print(err, file=sys.stderr)
- raise Exception(err)
- return r
-
- res.addBoth(check_ctx)
- return res
-
- return wrapped
-
- defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 358b593cd4..83032cc9ea 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -163,8 +163,9 @@ class EmailPusherTests(HomeserverTestCase):
# Get the stream ordering before it gets sent
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -173,8 +174,9 @@ class EmailPusherTests(HomeserverTestCase):
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
@@ -190,7 +192,8 @@ class EmailPusherTests(HomeserverTestCase):
# The stream ordering has increased
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..baf9c785f4 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
@@ -102,8 +102,9 @@ class HTTPPusherTests(HomeserverTestCase):
# Get the stream ordering before it gets sent
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -112,8 +113,9 @@ class HTTPPusherTests(HomeserverTestCase):
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
@@ -130,8 +132,9 @@ class HTTPPusherTests(HomeserverTestCase):
# The stream ordering has increased
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -149,7 +152,8 @@ class HTTPPusherTests(HomeserverTestCase):
# The stream ordering has increased, again
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by({"user_name": user_id})
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
new file mode 100644
index 0000000000..9ae6a87d7b
--- /dev/null
+++ b/tests/push/test_push_rule_evaluator.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+
+from tests import unittest
+
+
+class PushRuleEvaluatorTestCase(unittest.TestCase):
+ def setUp(self):
+ event = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": "m.room.history_visibility",
+ "sender": "@user:test",
+ "state_key": "",
+ "room_id": "@room:test",
+ "content": {"body": "foo bar baz"},
+ },
+ RoomVersions.V1,
+ )
+ room_member_count = 0
+ sender_power_level = 0
+ power_levels = {}
+ self.evaluator = PushRuleEvaluatorForEvent(
+ event, room_member_count, sender_power_level, power_levels
+ )
+
+ def test_display_name(self):
+ """Check for a matching display name in the body of the event."""
+ condition = {
+ "kind": "contains_display_name",
+ }
+
+ # Blank names are skipped.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+
+ # Check a display name that doesn't match.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+
+ # Check a display name which matches.
+ self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+
+ # A display name that matches, but not a full word does not result in a match.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+
+ # A display name should not be interpreted as a regular expression.
+ self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+
+ # A display name with spaces should work fine.
+ self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
new file mode 100644
index 0000000000..9d4f0bbe44
--- /dev/null
+++ b/tests/replication/_base.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, List, Optional, Tuple
+
+import attr
+
+from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.task import LoopingCall
+from twisted.web.http import HTTPChannel
+
+from synapse.app.generic_worker import (
+ GenericWorkerReplicationHandler,
+ GenericWorkerServer,
+)
+from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
+from synapse.replication.tcp.handler import ReplicationCommandHandler
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import FakeTransport
+
+logger = logging.getLogger(__name__)
+
+
+class BaseStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests of the replication streams"""
+
+ servlets = [
+ streams.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # build a replication server
+ server_factory = ReplicationStreamProtocolFactory(hs)
+ self.streamer = hs.get_replication_streamer()
+ self.server = server_factory.buildProtocol(None)
+
+ # Make a new HomeServer object for the worker
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.worker_hs = self.setup_test_homeserver(
+ http_client=None,
+ homeserverToUse=GenericWorkerServer,
+ config=self._get_worker_hs_config(),
+ reactor=self.reactor,
+ )
+
+ # Since we use sqlite in memory databases we need to make sure the
+ # databases objects are the same.
+ self.worker_hs.get_datastore().db = hs.get_datastore().db
+
+ self.test_handler = self._build_replication_data_handler()
+ self.worker_hs.replication_data_handler = self.test_handler
+
+ repl_handler = ReplicationCommandHandler(self.worker_hs)
+ self.client = ClientReplicationStreamProtocol(
+ self.worker_hs, "client", "test", clock, repl_handler,
+ )
+
+ self._client_transport = None
+ self._server_transport = None
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def _build_replication_data_handler(self):
+ return TestReplicationDataHandler(self.worker_hs)
+
+ def reconnect(self):
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump(0.1)
+
+ def handle_http_replication_attempt(self) -> SynapseRequest:
+ """Asserts that a connection attempt was made to the master HS on the
+ HTTP replication port, then proxies it to the master HS object to be
+ handled.
+
+ Returns:
+ The request object received by master HS.
+ """
+
+ # We should have an outbound connection attempt.
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # The request will now be processed by `self.site` and the response
+ # streamed back.
+ self.reactor.advance(0)
+
+ # We tear down the connection so it doesn't get reused without our
+ # knowledge.
+ server_to_client_transport.loseConnection()
+ client_to_server_transport.loseConnection()
+
+ return request_factory.request
+
+ def assert_request_is_get_repl_stream_updates(
+ self, request: SynapseRequest, stream_name: str
+ ):
+ """Asserts that the given request is a HTTP replication request for
+ fetching updates for given stream.
+ """
+
+ self.assertRegex(
+ request.path,
+ br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
+ % (stream_name.encode("ascii"),),
+ )
+
+ self.assertEqual(request.method, b"GET")
+
+
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
+ """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
+
+ def __init__(self, hs: HomeServer):
+ super().__init__(hs)
+
+ # list of received (stream_name, token, row) tuples
+ self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
+
+ async def on_rdata(self, stream_name, instance_name, token, rows):
+ await super().on_rdata(stream_name, instance_name, token, rows)
+ for r in rows:
+ self.received_rdata_rows.append((stream_name, token, r))
+
+
+@attr.s()
+class OneShotRequestFactory:
+ """A simple request factory that generates a single `SynapseRequest` and
+ stores it for future use. Can only be used once.
+ """
+
+ request = attr.ib(default=None)
+
+ def __call__(self, *args, **kwargs):
+ assert self.request is None
+
+ self.request = SynapseRequest(*args, **kwargs)
+ return self.request
+
+
+class _PushHTTPChannel(HTTPChannel):
+ """A HTTPChannel that wraps pull producers to push producers.
+
+ This is a hack to get around the fact that HTTPChannel transparently wraps a
+ pull producer (which is what Synapse uses to reply to requests) with
+ `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
+ uses the standard reactor rather than letting us use our test reactor, which
+ makes it very hard to test.
+ """
+
+ def __init__(self, reactor: IReactorTime):
+ super().__init__()
+ self.reactor = reactor
+
+ self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
+
+ def registerProducer(self, producer, streaming):
+ # Convert pull producers to push producer.
+ if not streaming:
+ self._pull_to_push_producer = _PullToPushProducer(
+ self.reactor, producer, self
+ )
+ producer = self._pull_to_push_producer
+
+ super().registerProducer(producer, True)
+
+ def unregisterProducer(self):
+ if self._pull_to_push_producer:
+ # We need to manually stop the _PullToPushProducer.
+ self._pull_to_push_producer.stop()
+
+
+class _PullToPushProducer:
+ """A push producer that wraps a pull producer.
+ """
+
+ def __init__(
+ self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
+ ):
+ self._clock = Clock(reactor)
+ self._producer = producer
+ self._consumer = consumer
+
+ # While running we use a looping call with a zero delay to call
+ # resumeProducing on given producer.
+ self._looping_call = None # type: Optional[LoopingCall]
+
+ # We start writing next reactor tick.
+ self._start_loop()
+
+ def _start_loop(self):
+ """Start the looping call to
+ """
+
+ if not self._looping_call:
+ # Start a looping call which runs every tick.
+ self._looping_call = self._clock.looping_call(self._run_once, 0)
+
+ def stop(self):
+ """Stops calling resumeProducing.
+ """
+ if self._looping_call:
+ self._looping_call.stop()
+ self._looping_call = None
+
+ def pauseProducing(self):
+ """Implements IPushProducer
+ """
+ self.stop()
+
+ def resumeProducing(self):
+ """Implements IPushProducer
+ """
+ self._start_loop()
+
+ def stopProducing(self):
+ """Implements IPushProducer
+ """
+ self.stop()
+ self._producer.stopProducing()
+
+ def _run_once(self):
+ """Calls resumeProducing on producer once.
+ """
+
+ try:
+ self._producer.resumeProducing()
+ except Exception:
+ logger.exception("Failed to call resumeProducing")
+ try:
+ self._consumer.unregisterProducer()
+ except Exception:
+ pass
+
+ self.stopProducing()
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 104349cdbd..56497b8476 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,52 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, NonCallableMock
+from mock import Mock
-from synapse.replication.tcp.client import (
- ReplicationClientFactory,
- ReplicationClientHandler,
-)
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from tests.replication._base import BaseStreamTestCase
-from tests import unittest
-from tests.server import FakeTransport
-
-class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(
- "blue",
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
- )
-
- hs.get_ratelimiter().can_do_action.return_value = (True, 0)
+ hs = self.setup_test_homeserver(federation_client=Mock())
return hs
def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
- self.master_store = self.hs.get_datastore()
- self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
- self.event_id = 0
-
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
-
- self.replication_handler = ReplicationClientHandler(self.slaved_store)
- client_factory = ReplicationClientFactory(
- self.hs, "client_name", self.replication_handler
- )
-
- server = server_factory.buildProtocol(None)
- client = client_factory.buildProtocol(None)
-
- client.makeConnection(FakeTransport(server, reactor))
+ self.reconnect()
- self.server_to_client_transport = FakeTransport(client, reactor)
- server.makeConnection(self.server_to_client_transport)
+ self.master_store = hs.get_datastore()
+ self.slaved_store = self.worker_hs.get_datastore()
+ self.storage = hs.get_storage()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index a368117b43..1a88c7fb80 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -15,18 +15,20 @@ import logging
from canonicaljson import encode_canonical_json
-from synapse.events import FrozenEvent, _EventInternalMetadata
-from synapse.events.snapshot import EventContext
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from tests.server import FakeTransport
+
from ._base import BaseSlavedStoreTestCase
-USER_ID = "@feeling:blue"
-USER_ID_2 = "@bright:blue"
+USER_ID = "@feeling:test"
+USER_ID_2 = "@bright:test"
OUTLIER = {"outlier": True}
-ROOM_ID = "!room:blue"
+ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
@@ -58,6 +60,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
return super(SlavedEventStoreTestCase, self).setUp()
+ def prepare(self, *args, **kwargs):
+ super().prepare(*args, **kwargs)
+
+ self.get_success(
+ self.master_store.store_room(
+ ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
+ )
+ )
+
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@@ -90,7 +101,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
- redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+ redacted = make_event_from_dict(
+ msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
+ )
self.check("get_event", [msg.event_id], redacted)
def test_backfilled_redactions(self):
@@ -110,18 +123,20 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
- redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+ redacted = make_event_from_dict(
+ msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
+ )
self.check("get_event", [msg.event_id], redacted)
def test_invites(self):
self.persist(type="m.room.create", key="", creator=USER_ID)
- self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
- "get_invited_rooms_for_user",
+ "get_invited_rooms_for_local_user",
[USER_ID_2],
[
RoomsForUser(
@@ -225,7 +240,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
# limit the replication rate
- repl_transport = self.server_to_client_transport
+ repl_transport = self._server_transport
+ assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
@@ -234,7 +250,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
+ self.get_success(
+ self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ )
self.replicate()
event_source = RoomEventSource(self.hs)
@@ -290,10 +308,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
+ self.storage.persistence.persist_events(
+ [(event, context)], backfilled=True
+ )
)
else:
- self.get_success(self.master_store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -304,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.message",
key=None,
internal={},
- state=None,
depth=None,
prev_events=[],
auth_events=[],
@@ -341,18 +360,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if redacts is not None:
event_dict["redacts"] = redacts
- event = FrozenEvent(event_dict, internal_metadata_dict=internal)
+ event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
self.event_id += 1
-
- if state is not None:
- state_ids = {key: e.event_id for key, e in state.items()}
- context = EventContext.with_state(
- state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
- )
- else:
- state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
deleted file mode 100644
index ce3835ae6a..0000000000
--- a/tests/replication/tcp/streams/_base.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from synapse.replication.tcp.commands import ReplicateCommand
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-
-from tests import unittest
-from tests.server import FakeTransport
-
-
-class BaseStreamTestCase(unittest.HomeserverTestCase):
- """Base class for tests of the replication streams"""
-
- def prepare(self, reactor, clock, hs):
- # build a replication server
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
-
- # build a replication client, with a dummy handler
- self.test_handler = TestReplicationClientHandler()
- self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
- )
-
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
-
- def replicate(self):
- """Tell the master side of replication that something has happened, and then
- wait for the replication to occur.
- """
- self.streamer.on_notifier_poke()
- self.pump(0.1)
-
- def replicate_stream(self, stream, token="NOW"):
- """Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand(stream, token))
-
-
-class TestReplicationClientHandler(object):
- """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
-
- def __init__(self):
- self.received_rdata_rows = []
-
- def get_streams_to_replicate(self):
- return {}
-
- def get_currently_syncing_users(self):
- return []
-
- def update_connection(self, connection):
- pass
-
- def finished_connecting(self):
- pass
-
- def on_rdata(self, stream_name, token, rows):
- for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 0000000000..6a5116dd2a
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+ _STREAM_UPDATE_TARGET_ROW_COUNT,
+ AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+ def test_update_function_room_account_data_limit(self):
+ """Test replication with many room account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", update, {})
+ )
+ updates.append(update)
+
+ # also one global update
+ self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertEqual(row.room_id, "test_room")
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.global")
+ self.assertIsNone(row.room_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_global_account_data_limit(self):
+ """Test replication with many global account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(store.add_account_data_for_user("test_user", update, {}))
+ updates.append(update)
+
+ # also one per-room update
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+ )
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertIsNone(row.room_id)
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.per_room")
+ self.assertEqual(row.room_id, "test_room")
+
+ self.assertEqual([], received_rows)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
new file mode 100644
index 0000000000..51bf0ef4e9
--- /dev/null
+++ b/tests/replication/tcp/streams/test_events.py
@@ -0,0 +1,425 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
+from synapse.replication.tcp.streams.events import (
+ EventsStreamCurrentStateRow,
+ EventsStreamEventRow,
+ EventsStreamRow,
+)
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseStreamTestCase
+from tests.test_utils.event_injection import inject_event, inject_member_event
+
+
+class EventsStreamTestCase(BaseStreamTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
+ self.user_id = self.register_user("u1", "pass")
+ self.user_tok = self.login("u1", "pass")
+
+ self.reconnect()
+
+ self.room_id = self.helper.create_room_as(tok=self.user_tok)
+ self.test_handler.received_rdata_rows.clear()
+
+ def test_update_function_event_row_limit(self):
+ """Test replication with many non-state events
+
+ Checks that all events are correctly replicated when there are lots of
+ event rows to be replicated.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # generate lots of non-state events. We inject them using inject_event
+ # so that they are not send out over replication until we call self.replicate().
+ events = [
+ self._inject_test_event()
+ for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
+ ]
+
+ # also one state event
+ state_event = self._inject_state_event()
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ for event in events:
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, state_event.event_id)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ self.assertEqual(row.data.event_id, state_event.event_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_huge_state_change(self):
+ """Test replication with many state events
+
+ Ensures that all events are correctly replicated when there are lots of
+ state change rows to be replicated.
+ """
+
+ # we want to generate lots of state changes at a single stream ID.
+ #
+ # We do this by having two branches in the DAG. On one, we have a moderator
+ # which that generates lots of state; on the other, we de-op the moderator,
+ # thus invalidating all the state.
+
+ OTHER_USER = "@other_user:localhost"
+
+ # have the user join
+ inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+
+ # Update existing power levels with mod at PL50
+ pls = self.helper.get_state(
+ self.room_id, EventTypes.PowerLevels, tok=self.user_tok
+ )
+ pls["users"][OTHER_USER] = 50
+ self.helper.send_state(
+ self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ )
+
+ # this is the point in the DAG where we make a fork
+ fork_point = self.get_success(
+ self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
+ ) # type: List[str]
+
+ events = [
+ self._inject_state_event(sender=OTHER_USER)
+ for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
+ ]
+
+ self.replicate()
+ # all those events and state changes should have landed
+ self.assertGreaterEqual(
+ len(self.test_handler.received_rdata_rows), 2 * len(events)
+ )
+
+ # disconnect, so that we can stack up the changes
+ self.disconnect()
+ self.test_handler.received_rdata_rows.clear()
+
+ # a state event which doesn't get rolled back, to check that the state
+ # before the huge update comes through ok
+ state1 = self._inject_state_event()
+
+ # roll back all the state by de-modding the user
+ prev_events = fork_point
+ pls["users"][OTHER_USER] = 0
+ pl_event = inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
+
+ # one more bit of state that doesn't get rolled back
+ state2 = self._inject_state_event()
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ #
+ # we expect:
+ #
+ # - two rows for state1
+ # - the PL event row, plus state rows for the PL event and each
+ # of the states that got reverted.
+ # - two rows for state2
+
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # first check the first two rows, which should be state1
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, state1.event_id)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ self.assertEqual(row.data.event_id, state1.event_id)
+
+ # now the last two rows, which should be state2
+ stream_name, token, row = received_rows.pop(-2)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, state2.event_id)
+
+ stream_name, token, row = received_rows.pop(-1)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ self.assertEqual(row.data.event_id, state2.event_id)
+
+ # that should leave us with the rows for the PL event
+ self.assertEqual(len(received_rows), len(events) + 2)
+
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_event.event_id)
+
+ # the state rows are unsorted
+ state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ for stream_name, token, row in received_rows:
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ state_rows.append(row.data)
+
+ state_rows.sort(key=lambda r: r.state_key)
+
+ sr = state_rows.pop(0)
+ self.assertEqual(sr.type, EventTypes.PowerLevels)
+ self.assertEqual(sr.event_id, pl_event.event_id)
+ for sr in state_rows:
+ self.assertEqual(sr.type, "test_state_event")
+ # "None" indicates the state has been deleted
+ self.assertIsNone(sr.event_id)
+
+ def test_update_function_state_row_limit(self):
+ """Test replication with many state events over several stream ids.
+ """
+
+ # we want to generate lots of state changes, but for this test, we want to
+ # spread out the state changes over a few stream IDs.
+ #
+ # We do this by having two branches in the DAG. On one, we have four moderators,
+ # each of which that generates lots of state; on the other, we de-op the users,
+ # thus invalidating all the state.
+
+ NUM_USERS = 4
+ STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
+
+ user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
+
+ # have the users join
+ for u in user_ids:
+ inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+
+ # Update existing power levels with mod at PL50
+ pls = self.helper.get_state(
+ self.room_id, EventTypes.PowerLevels, tok=self.user_tok
+ )
+ pls["users"].update({u: 50 for u in user_ids})
+ self.helper.send_state(
+ self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+ )
+
+ # this is the point in the DAG where we make a fork
+ fork_point = self.get_success(
+ self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
+ ) # type: List[str]
+
+ events = [] # type: List[EventBase]
+ for user in user_ids:
+ events.extend(
+ self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
+ )
+
+ self.replicate()
+
+ # all those events and state changes should have landed
+ self.assertGreaterEqual(
+ len(self.test_handler.received_rdata_rows), 2 * len(events)
+ )
+
+ # disconnect, so that we can stack up the changes
+ self.disconnect()
+ self.test_handler.received_rdata_rows.clear()
+
+ # now roll back all that state by de-modding the users
+ prev_events = fork_point
+ pl_events = []
+ for u in user_ids:
+ pls["users"][u] = 0
+ e = inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
+ prev_events = [e.event_id]
+ pl_events.append(e)
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order (as
+ # well as various cache invalidation updates which we ignore)
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertGreaterEqual(len(received_rows), len(events))
+ for i in range(NUM_USERS):
+ # for each user, we expect the PL event row, followed by state rows for
+ # the PL event and each of the states that got reverted.
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_events[i].event_id)
+
+ # the state rows are unsorted
+ state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ for j in range(STATES_PER_USER + 1):
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ state_rows.append(row.data)
+
+ state_rows.sort(key=lambda r: r.state_key)
+
+ sr = state_rows.pop(0)
+ self.assertEqual(sr.type, EventTypes.PowerLevels)
+ self.assertEqual(sr.event_id, pl_events[i].event_id)
+ for sr in state_rows:
+ self.assertEqual(sr.type, "test_state_event")
+ # "None" indicates the state has been deleted
+ self.assertIsNone(sr.event_id)
+
+ self.assertEqual([], received_rows)
+
+ event_count = 0
+
+ def _inject_test_event(
+ self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+ ) -> EventBase:
+ if sender is None:
+ sender = self.user_id
+
+ if body is None:
+ body = "event %i" % (self.event_count,)
+ self.event_count += 1
+
+ return inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_event",
+ content={"body": body},
+ **kwargs
+ )
+
+ def _inject_state_event(
+ self,
+ body: Optional[str] = None,
+ state_key: Optional[str] = None,
+ sender: Optional[str] = None,
+ ) -> EventBase:
+ if sender is None:
+ sender = self.user_id
+
+ if state_key is None:
+ state_key = "state_%i" % (self.event_count,)
+ self.event_count += 1
+
+ if body is None:
+ body = "state event %s" % (state_key,)
+
+ return inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_state_event",
+ state_key=state_key,
+ content={"body": body},
+ )
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..2babea4e3e
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index d5a99f6caa..56b062ecc1 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -12,35 +12,73 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams._base import ReceiptsStreamRow
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+# type: ignore
+
+from mock import Mock
+
+from synapse.replication.tcp.streams._base import ReceiptsStream
+
+from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
+ def _build_replication_data_handler(self):
+ return Mock(wraps=super()._build_replication_data_handler())
+
def test_receipt(self):
- # make the client subscribe to the receipts stream
- self.replicate_stream("receipts", "NOW")
+ self.reconnect()
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
- rdata_rows = self.test_handler.received_rdata_rows
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- self.assertEqual(rdata_rows[0][0], "receipts")
- row = rdata_rows[0][2] # type: ReceiptsStreamRow
- self.assertEqual(ROOM_ID, row.room_id)
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
- self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ )
+ )
+ self.replicate()
+
+ # Nothing should have happened as we are disconnected
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now have caught up and get the missing data
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
+ self.assertEqual(token, 3)
+ self.assertEqual(1, len(rdata_rows))
+
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room2:blue", row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual("$event2:foo", row.event_id)
+ self.assertEqual({"a": 2}, row.data)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
new file mode 100644
index 0000000000..fd62b26356
--- /dev/null
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.handlers.typing import RoomMember
+from synapse.replication.tcp.streams import TypingStream
+
+from tests.replication._base import BaseStreamTestCase
+
+USER_ID = "@feeling:blue"
+
+
+class TypingStreamTestCase(BaseStreamTestCase):
+ def _build_replication_data_handler(self):
+ return Mock(wraps=super()._build_replication_data_handler())
+
+ def test_typing(self):
+ typing = self.hs.get_typing_handler()
+
+ room_id = "!bar:blue"
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(room_id, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # The from token should be the token from the last RDATA we got.
+ self.assertEqual(int(request.args[b"from_token"][0]), token)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(room_id, row.room_id)
+ self.assertEqual([], row.user_ids)
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
new file mode 100644
index 0000000000..60c10a441a
--- /dev/null
+++ b/tests/replication/tcp/test_commands.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.replication.tcp.commands import (
+ RdataCommand,
+ ReplicateCommand,
+ parse_command_from_line,
+)
+
+from tests.unittest import TestCase
+
+
+class ParseCommandTestCase(TestCase):
+ def test_parse_one_word_command(self):
+ line = "REPLICATE"
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, ReplicateCommand)
+
+ def test_parse_rdata(self):
+ line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
+ cmd = parse_command_from_line(line)
+ assert isinstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "events")
+ self.assertEqual(cmd.instance_name, "master")
+ self.assertEqual(cmd.token, 6287863)
+
+ def test_parse_rdata_batch(self):
+ line = 'RDATA presence master batch ["@foo:example.com", "online"]'
+ cmd = parse_command_from_line(line)
+ assert isinstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "presence")
+ self.assertEqual(cmd.instance_name, "master")
+ self.assertIsNone(cmd.token)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
new file mode 100644
index 0000000000..d1c15caeb0
--- /dev/null
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+from twisted.internet.interfaces import IProtocol
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from tests.unittest import HomeserverTestCase
+
+
+class RemoteServerUpTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.factory = ReplicationStreamProtocolFactory(hs)
+
+ def _make_client(self) -> Tuple[IProtocol, StringTransport]:
+ """Create a new direct TCP replication connection
+ """
+
+ proto = self.factory.buildProtocol(("127.0.0.1", 0))
+ transport = StringTransport()
+ proto.makeConnection(transport)
+
+ # We can safely ignore the commands received during connection.
+ self.pump()
+ transport.clear()
+
+ return proto, transport
+
+ def test_relay(self):
+ """Test that Synapse will relay REMOTE_SERVER_UP commands to all
+ other connections, but not the one that sent it.
+ """
+
+ proto1, transport1 = self._make_client()
+
+ # We shouldn't receive an echo.
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+ self.pump()
+ self.assertEqual(transport1.value(), b"")
+
+ # But we should see an echo if we connect another client
+ proto2, transport2 = self._make_client()
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
+
+ self.pump()
+ self.assertEqual(transport1.value(), b"")
+ self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
new file mode 100644
index 0000000000..5448d9f0dc
--- /dev/null
+++ b/tests/replication/test_federation_ack.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.replication.tcp.commands import FederationAckCommand
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.unittest import HomeserverTestCase
+
+
+class FederationAckTestCase(HomeserverTestCase):
+ def default_config(self) -> dict:
+ config = super().default_config()
+ config["worker_app"] = "synapse.app.federation_sender"
+ config["send_federation"] = True
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+ return hs
+
+ def test_federation_ack_sent(self):
+ """A FEDERATION_ACK should be sent back after each RDATA federation
+
+ This test checks that the federation sender is correctly sending back
+ FEDERATION_ACK messages. The test works by spinning up a federation_sender
+ worker server, and then fishing out its ReplicationCommandHandler. We wire
+ the RCH up to a mock connection (so that we can observe the command being sent)
+ and then poke in an RDATA row.
+
+ XXX: it might be nice to do this by pretending to be a synapse master worker
+ (or a redis server), and having the worker connect to us via a mocked-up TCP
+ transport, rather than assuming that the implementation has a
+ ReplicationCommandHandler.
+ """
+ rch = self.hs.get_tcp_replication()
+
+ # wire up the ReplicationCommandHandler to a mock connection
+ mock_connection = mock.Mock(spec=AbstractConnection)
+ rch.new_connection(mock_connection)
+
+ # tell it it received an RDATA row
+ self.get_success(
+ rch.on_rdata(
+ "federation",
+ "master",
+ token=10,
+ rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])],
+ )
+ )
+
+ # now check that the FEDERATION_ACK was sent
+ mock_connection.send_command.assert_called_once()
+ cmd = mock_connection.send_command.call_args[0][0]
+ assert isinstance(cmd, FederationAckCommand)
+ self.assertEqual(cmd.token, 10)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5877bb2133..977615ebef 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
import json
+import os
+import urllib.parse
+from binascii import unhexlify
from mock import Mock
+from twisted.internet.defer import Deferred
+
import synapse.rest.admin
-from synapse.api.constants import UserTypes
from synapse.http.server import JsonResource
+from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
-from synapse.rest.client.v1 import events, login, room
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
@@ -47,517 +50,440 @@ class VersionTestCase(unittest.HomeserverTestCase):
)
-class UserRegisterTestCase(unittest.HomeserverTestCase):
-
- servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
-
- def make_homeserver(self, reactor, clock):
-
- self.url = "/_matrix/client/r0/admin/register"
+class DeleteGroupTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ groups.register_servlets,
+ ]
- self.registration_handler = Mock()
- self.identity_handler = Mock()
- self.login_handler = Mock()
- self.device_handler = Mock()
- self.device_handler.check_device_registered = Mock(return_value="FAKE")
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
- self.datastore = Mock(return_value=Mock())
- self.datastore.get_current_state_deltas = Mock(return_value=[])
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
- self.secrets = Mock()
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
- self.hs = self.setup_test_homeserver()
+ def test_delete_group(self):
+ # Create a new group
+ request, channel = self.make_request(
+ "POST",
+ "/create_group".encode("ascii"),
+ access_token=self.admin_user_tok,
+ content={"localpart": "test"},
+ )
- self.hs.config.registration_shared_secret = "shared"
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.hs.get_media_repository = Mock()
- self.hs.get_deactivate_account_handler = Mock()
+ group_id = channel.json_body["group_id"]
- return self.hs
+ self._check_group(group_id, expect_code=200)
- def test_disabled(self):
- """
- If there is no shared secret, registration through this method will be
- prevented.
- """
- self.hs.config.registration_shared_secret = None
+ # Invite/join another user
- request, channel = self.make_request("POST", self.url, b"{}")
+ url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
+ request, channel = self.make_request(
+ "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
+ )
self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "Shared secret registration is not enabled", channel.json_body["error"]
+ url = "/groups/%s/self/accept_invite" % (group_id,)
+ request, channel = self.make_request(
+ "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
-
- def test_get_nonce(self):
- """
- Calling GET on the endpoint will return a randomised nonce, using the
- homeserver's secrets provider.
- """
- secrets = Mock()
- secrets.token_hex = Mock(return_value="abcd")
-
- self.hs.get_secrets = Mock(return_value=secrets)
-
- request, channel = self.make_request("GET", self.url)
self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(channel.json_body, {"nonce": "abcd"})
-
- def test_expired_nonce(self):
- """
- Calling GET on the endpoint will return a randomised nonce, which will
- only last for SALT_TIMEOUT (60s).
- """
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
+ # Check other user knows they're in the group
+ self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
+ self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
- # 59 seconds
- self.reactor.advance(59)
+ # Now delete the group
+ url = "/admin/delete_group/" + group_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ content={"localpart": "test"},
+ )
- body = json.dumps({"nonce": nonce})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("username must be specified", channel.json_body["error"])
-
- # 61 seconds
- self.reactor.advance(2)
-
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ # Check group returns 404
+ self._check_group(group_id, expect_code=404)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("unrecognised nonce", channel.json_body["error"])
+ # Check users don't think they're in the group
+ self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
+ self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
- def test_register_incorrect_nonce(self):
- """
- Only the provided nonce can be used, as it's checked in the MAC.
+ def _check_group(self, group_id, expect_code):
+ """Assert that trying to fetch the given group results in the given
+ HTTP status code
"""
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
-
- want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
- want_mac = want_mac.hexdigest()
-
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("HMAC incorrect", channel.json_body["error"])
+ url = "/groups/%s/profile" % (group_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
- def test_register_correct_nonce(self):
- """
- When the correct nonce is provided, and the right key is provided, the
- user is registered.
- """
- request, channel = self.make_request("GET", self.url)
self.render(request)
- nonce = channel.json_body["nonce"]
-
- want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(
- nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
- want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "user_type": UserTypes.SUPPORT,
- "mac": want_mac,
- }
+ def _get_groups_user_is_in(self, access_token):
+ """Returns the list of groups the user is in (given their access token)
+ """
+ request, channel = self.make_request(
+ "GET", "/joined_groups".encode("ascii"), access_token=access_token
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("@bob:test", channel.json_body["user_id"])
-
- def test_nonce_reuse(self):
- """
- A valid unrecognised nonce.
- """
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- nonce = channel.json_body["nonce"]
-
- want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
- want_mac = want_mac.hexdigest()
-
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
self.render(request)
-
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("@bob:test", channel.json_body["user_id"])
-
- # Now, try and reuse it
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("unrecognised nonce", channel.json_body["error"])
- def test_missing_parts(self):
- """
- Synapse will complain if you don't give nonce, username, password, and
- mac. Admin and user_types are optional. Additional checks are done for length
- and type.
- """
-
- def nonce():
- request, channel = self.make_request("GET", self.url)
- self.render(request)
- return channel.json_body["nonce"]
-
- #
- # Nonce check
- #
-
- # Must be present
- body = json.dumps({})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("nonce must be specified", channel.json_body["error"])
+ return channel.json_body["groups"]
- #
- # Username checks
- #
- # Must be present
- body = json.dumps({"nonce": nonce()})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+class QuarantineMediaTestCase(unittest.HomeserverTestCase):
+ """Test /quarantine_media admin API.
+ """
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("username must be specified", channel.json_body["error"])
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ room.register_servlets,
+ ]
- # Must be a string
- body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+ self.image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid username", channel.json_body["error"])
+ def make_homeserver(self, reactor, clock):
- # Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ self.fetches = []
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid username", channel.json_body["error"])
+ def get_file(destination, path, output_stream, args=None, max_size=None):
+ """
+ Returns tuple[int,dict,str,int] of file length, response headers,
+ absolute URI, and response code.
+ """
- # Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ def write_to(r):
+ data, response = r
+ output_stream.write(data)
+ return response
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid username", channel.json_body["error"])
+ d = Deferred()
+ d.addCallback(write_to)
+ self.fetches.append((d, destination, path, args))
+ return make_deferred_yieldable(d)
- #
- # Password checks
- #
+ client = Mock()
+ client.get_file = get_file
- # Must be present
- body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("password must be specified", channel.json_body["error"])
+ config = self.default_config()
+ config["media_store_path"] = self.media_store_path
+ config["thumbnail_requirements"] = {}
+ config["max_image_pixels"] = 2000000
- # Must be a string
- body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ provider_config = {
+ "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+ config["media_storage_providers"] = [provider_config]
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid password", channel.json_body["error"])
+ hs = self.setup_test_homeserver(config=config, http_client=client)
- # Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
+ return hs
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid password", channel.json_body["error"])
+ def test_quarantine_media_requires_admin(self):
+ self.register_user("nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("nonadmin", "pass")
- # Super long
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ # Attempt quarantine media APIs as non-admin
+ url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ )
self.render(request)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid password", channel.json_body["error"])
-
- #
- # user_type check
- #
+ # Expect a forbidden error
+ self.assertEqual(
+ 403,
+ int(channel.result["code"]),
+ msg="Expected forbidden on quarantining media as a non-admin",
+ )
- # Invalid user_type
- body = json.dumps(
- {
- "nonce": nonce(),
- "username": "a",
- "password": "1234",
- "user_type": "invalid",
- }
+ # And the roomID/userID endpoint
+ url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
- request, channel = self.make_request("POST", self.url, body.encode("utf8"))
self.render(request)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("Invalid user type", channel.json_body["error"])
-
-
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
+ # Expect a forbidden error
+ self.assertEqual(
+ 403,
+ int(channel.result["code"]),
+ msg="Expected forbidden on quarantining media as a non-admin",
+ )
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
+ def test_quarantine_media_by_id(self):
+ self.register_user("id_admin", "pass", admin=True)
+ admin_user_tok = self.login("id_admin", "pass")
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+ self.register_user("id_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("id_nonadmin", "pass")
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=admin_user_tok
+ )
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+ # Extract media ID from the response
+ server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
+ server_name, media_id = server_name_and_media_id.split("/")
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
+ # Attempt to access the media
+ request, channel = self.make_request(
+ "GET",
+ server_name_and_media_id,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
+ # Should be successful
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
+ # Quarantine the media
+ url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
+ urllib.parse.quote(server_name),
+ urllib.parse.quote(media_id),
)
+ request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
+ # Attempt to access the media
request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
+ "GET",
+ server_name_and_media_id,
+ shorthand=False,
+ access_token=admin_user_tok,
)
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ request.render(self.download_resource)
+ self.pump(1.0)
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_name_and_media_id
+ ),
+ )
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
+ def test_quarantine_all_media_in_room(self, override_url_template=None):
+ self.register_user("room_admin", "pass", admin=True)
+ admin_user_tok = self.login("room_admin", "pass")
- self.event_creation_handler._block_events_without_consent_error = None
+ non_admin_user = self.register_user("room_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("room_nonadmin", "pass")
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+ room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok)
+ self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ # Extract mxcs
+ mxc_1 = response_1["content_uri"]
+ mxc_2 = response_2["content_uri"]
+
+ # Send it into the room
+ self.helper.send_event(
+ room_id,
+ "m.room.message",
+ content={"body": "image-1", "msgtype": "m.image", "url": mxc_1},
+ txn_id="111",
+ tok=non_admin_user_tok,
+ )
+ self.helper.send_event(
+ room_id,
+ "m.room.message",
+ content={"body": "image-2", "msgtype": "m.image", "url": mxc_2},
+ txn_id="222",
+ tok=non_admin_user_tok,
+ )
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
+ # Quarantine all media in the room
+ if override_url_template:
+ url = override_url_template % urllib.parse.quote(room_id)
+ else:
+ url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
+ room_id
+ )
+ request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
+ )
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
+ # Convert mxc URLs to server/media_id strings
+ server_and_media_id_1 = mxc_1[6:]
+ server_and_media_id_2 = mxc_2[6:]
- url = "rooms/%s/initialSync" % (room_id,)
+ # Test that we cannot download any of the media anymore
request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ "GET",
+ server_and_media_id_1,
+ shorthand=False,
+ access_token=non_admin_user_tok,
)
- self.render(request)
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_1
+ ),
)
- url = "events?timeout=0&room_id=" + room_id
request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ "GET",
+ server_and_media_id_2,
+ shorthand=False,
+ access_token=non_admin_user_tok,
)
- self.render(request)
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_2
+ ),
)
+ def test_quaraantine_all_media_in_room_deprecated_api_path(self):
+ # Perform the above test with the deprecated API path
+ self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
-class DeleteGroupTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- groups.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
+ def test_quarantine_all_media_by_user(self):
+ self.register_user("user_admin", "pass", admin=True)
+ admin_user_tok = self.login("user_admin", "pass")
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
+ non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("user_nonadmin", "pass")
- def test_delete_group(self):
- # Create a new group
- request, channel = self.make_request(
- "POST",
- "/create_group".encode("ascii"),
- access_token=self.admin_user_tok,
- content={"localpart": "test"},
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- group_id = channel.json_body["group_id"]
-
- self._check_group(group_id, expect_code=200)
-
- # Invite/join another user
+ # Extract media IDs
+ server_and_media_id_1 = response_1["content_uri"][6:]
+ server_and_media_id_2 = response_2["content_uri"][6:]
- url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
- request, channel = self.make_request(
- "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
+ # Quarantine all media by this user
+ url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+ non_admin_user
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- url = "/groups/%s/self/accept_invite" % (group_id,)
request, channel = self.make_request(
- "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
+ "POST", url.encode("ascii"), access_token=admin_user_tok,
)
self.render(request)
+ self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check other user knows they're in the group
- self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
- self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
-
- # Now delete the group
- url = "/admin/delete_group/" + group_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- access_token=self.admin_user_tok,
- content={"localpart": "test"},
+ self.assertEqual(
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check group returns 404
- self._check_group(group_id, expect_code=404)
-
- # Check users don't think they're in the group
- self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
- self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
-
- def _check_group(self, group_id, expect_code):
- """Assert that trying to fetch the given group results in the given
- HTTP status code
- """
-
- url = "/groups/%s/profile" % (group_id,)
+ # Attempt to access each piece of media
request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ "GET",
+ server_and_media_id_1,
+ shorthand=False,
+ access_token=non_admin_user_tok,
)
+ request.render(self.download_resource)
+ self.pump(1.0)
- self.render(request)
+ # Should be quarantined
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_1,
+ ),
)
- def _get_groups_user_is_in(self, access_token):
- """Returns the list of groups the user is in (given their access token)
- """
+ # Attempt to access each piece of media
request, channel = self.make_request(
- "GET", "/joined_groups".encode("ascii"), access_token=access_token
+ "GET",
+ server_and_media_id_2,
+ shorthand=False,
+ access_token=non_admin_user_tok,
)
+ request.render(self.download_resource)
+ self.pump(1.0)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- return channel.json_body["groups"]
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_2
+ ),
+ )
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
new file mode 100644
index 0000000000..faa7f381a9
--- /dev/null
+++ b/tests/rest/admin/test_device.py
@@ -0,0 +1,541 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login
+
+from tests import unittest
+
+
+class DeviceRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.other_user_device_id = res[0]["device_id"]
+
+ self.url = "/_synapse/admin/v2/users/%s/devices/%s" % (
+ urllib.parse.quote(self.other_user),
+ self.other_user_device_id,
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get a device of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("PUT", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("DELETE", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "DELETE", self.url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = (
+ "/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
+ % self.other_user_device_id
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = (
+ "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
+ % self.other_user_device_id
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_unknown_device(self):
+ """
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
+ """
+ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
+ self.other_user
+ )
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ request, channel = self.make_request(
+ "DELETE", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_update_device_too_long_display_name(self):
+ """
+ Update a device with a display name that is invalid (too long).
+ """
+ # Set iniital display name.
+ update = {"display_name": "new display"}
+ self.get_success(
+ self.handler.update_device(
+ self.other_user, self.other_user_device_id, update
+ )
+ )
+
+ # Request to update a device display name with a new value that is longer than allowed.
+ update = {
+ "display_name": "a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+ }
+
+ body = json.dumps(update)
+ request, channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # Ensure the display name was not updated.
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new display", channel.json_body["display_name"])
+
+ def test_update_no_display_name(self):
+ """
+ Tests that a update for a device without JSON returns a 200
+ """
+ # Set iniital display name.
+ update = {"display_name": "new display"}
+ self.get_success(
+ self.handler.update_device(
+ self.other_user, self.other_user_device_id, update
+ )
+ )
+
+ request, channel = self.make_request(
+ "PUT", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Ensure the display name was not updated.
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new display", channel.json_body["display_name"])
+
+ def test_update_display_name(self):
+ """
+ Tests a normal successful update of display name
+ """
+ # Set new display_name
+ body = json.dumps({"display_name": "new displayname"})
+ request, channel = self.make_request(
+ "PUT",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check new display_name
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("new displayname", channel.json_body["display_name"])
+
+ def test_get_device(self):
+ """
+ Tests that a normal lookup for a device is successfully
+ """
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ # Check that all fields are available
+ self.assertIn("user_id", channel.json_body)
+ self.assertIn("device_id", channel.json_body)
+ self.assertIn("display_name", channel.json_body)
+ self.assertIn("last_seen_ip", channel.json_body)
+ self.assertIn("last_seen_ts", channel.json_body)
+
+ def test_delete_device(self):
+ """
+ Tests that a remove of a device is successfully
+ """
+ # Count number of devies of an user.
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ number_devices = len(res)
+ self.assertEqual(1, number_devices)
+
+ # Delete device
+ request, channel = self.make_request(
+ "DELETE", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Ensure that the number of devices is decreased
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(number_devices - 1, len(res))
+
+
+class DevicesRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v2/users/%s/devices" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list devices of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_devices(self):
+ """
+ Tests that a normal lookup for devices is successfully
+ """
+ # Create devices
+ number_devices = 5
+ for n in range(number_devices):
+ self.login("user", "pass")
+
+ # Get devices
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_devices, len(channel.json_body["devices"]))
+ self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
+ # Check that all fields are available
+ for d in channel.json_body["devices"]:
+ self.assertIn("user_id", d)
+ self.assertIn("device_id", d)
+ self.assertIn("display_name", d)
+ self.assertIn("last_seen_ip", d)
+ self.assertIn("last_seen_ts", d)
+
+
+class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = hs.get_device_handler()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v2/users/%s/delete_devices" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to delete devices of an user without authentication.
+ """
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "POST", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
+ request, channel = self.make_request(
+ "POST", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
+
+ request, channel = self.make_request(
+ "POST", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_unknown_devices(self):
+ """
+ Tests that a remove of a device that does not exist returns 200.
+ """
+ body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_delete_devices(self):
+ """
+ Tests that a remove of devices is successfully
+ """
+
+ # Create devices
+ number_devices = 5
+ for n in range(number_devices):
+ self.login("user", "pass")
+
+ # Get devices
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(number_devices, len(res))
+
+ # Create list of device IDs
+ device_ids = []
+ for d in res:
+ device_ids.append(str(d["device_id"]))
+
+ # Delete devices
+ body = json.dumps({"devices": device_ids})
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+ self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
new file mode 100644
index 0000000000..54cd24bf64
--- /dev/null
+++ b/tests/rest/admin/test_room.py
@@ -0,0 +1,1007 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert there is now no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there's no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms each page, so should be 2 * last run_count
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in alphabetical order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set room canonical room aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # Different sort order of SQlite and PostreSQL
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a parameter is missing, return an error
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local user can join rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names, return an error 400.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when server admin is not member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Validate if server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is owner of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
new file mode 100644
index 0000000000..cca5f548e6
--- /dev/null
+++ b/tests/rest/admin/test_user.py
@@ -0,0 +1,947 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import hmac
+import json
+import urllib.parse
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.constants import UserTypes
+from synapse.api.errors import HttpResponseException, ResourceLimitError
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import sync
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class UserRegisterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = "/_matrix/client/r0/admin/register"
+
+ self.registration_handler = Mock()
+ self.identity_handler = Mock()
+ self.login_handler = Mock()
+ self.device_handler = Mock()
+ self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+ self.datastore = Mock(return_value=Mock())
+ self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
+
+ self.secrets = Mock()
+
+ self.hs = self.setup_test_homeserver()
+
+ self.hs.config.registration_shared_secret = "shared"
+
+ self.hs.get_media_repository = Mock()
+ self.hs.get_deactivate_account_handler = Mock()
+
+ return self.hs
+
+ def test_disabled(self):
+ """
+ If there is no shared secret, registration through this method will be
+ prevented.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ request, channel = self.make_request("POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "Shared secret registration is not enabled", channel.json_body["error"]
+ )
+
+ def test_get_nonce(self):
+ """
+ Calling GET on the endpoint will return a randomised nonce, using the
+ homeserver's secrets provider.
+ """
+ secrets = Mock()
+ secrets.token_hex = Mock(return_value="abcd")
+
+ self.hs.get_secrets = Mock(return_value=secrets)
+
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+
+ self.assertEqual(channel.json_body, {"nonce": "abcd"})
+
+ def test_expired_nonce(self):
+ """
+ Calling GET on the endpoint will return a randomised nonce, which will
+ only last for SALT_TIMEOUT (60s).
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ # 59 seconds
+ self.reactor.advance(59)
+
+ body = json.dumps({"nonce": nonce})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("username must be specified", channel.json_body["error"])
+
+ # 61 seconds
+ self.reactor.advance(2)
+
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("unrecognised nonce", channel.json_body["error"])
+
+ def test_register_incorrect_nonce(self):
+ """
+ Only the provided nonce can be used, as it's checked in the MAC.
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("HMAC incorrect", channel.json_body["error"])
+
+ def test_register_correct_nonce(self):
+ """
+ When the correct nonce is provided, and the right key is provided, the
+ user is registered.
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
+ )
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+ def test_nonce_reuse(self):
+ """
+ A valid unrecognised nonce.
+ """
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+ # Now, try and reuse it
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("unrecognised nonce", channel.json_body["error"])
+
+ def test_missing_parts(self):
+ """
+ Synapse will complain if you don't give nonce, username, password, and
+ mac. Admin and user_types are optional. Additional checks are done for length
+ and type.
+ """
+
+ def nonce():
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ return channel.json_body["nonce"]
+
+ #
+ # Nonce check
+ #
+
+ # Must be present
+ body = json.dumps({})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("nonce must be specified", channel.json_body["error"])
+
+ #
+ # Username checks
+ #
+
+ # Must be present
+ body = json.dumps({"nonce": nonce()})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("username must be specified", channel.json_body["error"])
+
+ # Must be a string
+ body = json.dumps({"nonce": nonce(), "username": 1234})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid username", channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid username", channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid username", channel.json_body["error"])
+
+ #
+ # Password checks
+ #
+
+ # Must be present
+ body = json.dumps({"nonce": nonce(), "username": "a"})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("password must be specified", channel.json_body["error"])
+
+ # Must be a string
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid password", channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid password", channel.json_body["error"])
+
+ # Super long
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid password", channel.json_body["error"])
+
+ #
+ # user_type check
+ #
+
+ # Invalid user_type
+ body = json.dumps(
+ {
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid",
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("Invalid user type", channel.json_body["error"])
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_register_mau_limit_reached(self):
+ """
+ Check we can register a user via the shared secret registration API
+ even if the MAU limit is reached.
+ """
+ handler = self.hs.get_registration_handler()
+ store = self.hs.get_datastore()
+
+ # Set monthly active users to the limit
+ store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
+ )
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+
+class UsersListTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+ url = "/_synapse/admin/v2/users"
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.register_user("user1", "pass1", admin=False)
+ self.register_user("user2", "pass2", admin=False)
+
+ def test_no_auth(self):
+ """
+ Try to list users without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+
+ def test_all_users(self):
+ """
+ List all users, including deactivated users.
+ """
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=true",
+ b"{}",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(3, len(channel.json_body["users"]))
+ self.assertEqual(3, channel.json_body["total"])
+
+
+class UserRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+ self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "PUT", url, access_token=self.other_user_token, content=b"{}",
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v2/users/@unknown_person:test",
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
+
+ def test_create_server_admin(self):
+ """
+ Check that a new admin user is created successfully.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user (server admin)
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "admin": True,
+ "displayname": "Bob's name",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": None,
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(True, channel.json_body["admin"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(True, channel.json_body["admin"])
+ self.assertEqual(False, channel.json_body["is_guest"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
+ def test_create_user(self):
+ """
+ Check that a new regular user is created successfully.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "admin": False,
+ "displayname": "Bob's name",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("Bob's name", channel.json_body["displayname"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(False, channel.json_body["admin"])
+ self.assertEqual(False, channel.json_body["is_guest"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_create_user_mau_limit_reached_active_admin(self):
+ """
+ Check that an admin can register a new user via the admin API
+ even if the MAU limit is reached.
+ Admin user was active before creating user.
+ """
+
+ handler = self.hs.get_registration_handler()
+
+ # Sync to set admin user to active
+ # before limit of monthly active users is reached
+ request, channel = self.make_request(
+ "GET", "/sync", access_token=self.admin_user_tok
+ )
+ self.render(request)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"]
+ )
+
+ # Set monthly active users to the limit
+ self.store.get_monthly_active_count = Mock(
+ return_value=self.hs.config.max_mau_value
+ )
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123", "admin": False})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
+ )
+ def test_create_user_mau_limit_reached_passive_admin(self):
+ """
+ Check that an admin can register a new user via the admin API
+ even if the MAU limit is reached.
+ Admin user was not active before creating user.
+ """
+
+ handler = self.hs.get_registration_handler()
+
+ # Set monthly active users to the limit
+ self.store.get_monthly_active_count = Mock(
+ return_value=self.hs.config.max_mau_value
+ )
+ # Check that the blocking of monthly active users is working as expected
+ # The registration of a new user fails due to the limit
+ self.get_failure(
+ handler.register_user(localpart="local_part"), ResourceLimitError
+ )
+
+ # Register new user with admin API
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123", "admin": False})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ # Admin user is not blocked by mau anymore
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["admin"])
+
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": True,
+ "notif_for_new_users": True,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_notif_for_new_users(self):
+ """
+ Check that a new regular user is created successfully and
+ got an email pusher.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ pushers = self.get_success(
+ self.store.get_pushers_by({"user_name": "@bob:test"})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual("@bob:test", pushers[0]["user_name"])
+
+ @override_config(
+ {
+ "email": {
+ "enable_notifs": False,
+ "notif_for_new_users": False,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "https://example.com",
+ }
+ )
+ def test_create_user_email_no_notif_for_new_users(self):
+ """
+ Check that a new regular user is created successfully and
+ got not an email pusher.
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps(
+ {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ pushers = self.get_success(
+ self.store.get_pushers_by({"user_name": "@bob:test"})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
+ def test_set_password(self):
+ """
+ Test setting a new password for another user.
+ """
+
+ # Change password
+ body = json.dumps({"password": "hahaha"})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_set_displayname(self):
+ """
+ Test setting the displayname of another user.
+ """
+
+ # Modify user
+ body = json.dumps({"displayname": "foobar"})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("foobar", channel.json_body["displayname"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("foobar", channel.json_body["displayname"])
+
+ def test_set_threepid(self):
+ """
+ Test setting threepid for an other user.
+ """
+
+ # Delete old and add new threepid to user
+ body = json.dumps(
+ {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+
+ def test_deactivate_user(self):
+ """
+ Test deactivating another user.
+ """
+
+ # Deactivate user
+ body = json.dumps({"deactivated": True})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ # the user is deactivated, the threepid will be deleted
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+
+ def test_set_user_as_admin(self):
+ """
+ Test setting the admin flag on a user.
+ """
+
+ # Set a user as an admin
+ body = json.dumps({"admin": True})
+
+ request, channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["admin"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_other_user, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["admin"])
+
+ def test_accidental_deactivation_prevention(self):
+ """
+ Ensure an account can't accidentally be deactivated by using a str value
+ for the deactivated body parameter
+ """
+ url = "/_synapse/admin/v2/users/@bob:test"
+
+ # Create user
+ body = json.dumps({"password": "abc123"})
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("bob", channel.json_body["displayname"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("bob", channel.json_body["displayname"])
+ self.assertEqual(0, channel.json_body["deactivated"])
+
+ # Change password (and use a str for deactivate instead of a bool)
+ body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
+
+ request, channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content=body.encode(encoding="utf_8"),
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check user is not deactivated
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual("bob", channel.json_body["displayname"])
+
+ # Ensure they're still alive
+ self.assertEqual(0, channel.json_body["deactivated"])
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
new file mode 100644
index 0000000000..5e9c07ebf3
--- /dev/null
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import room
+
+from tests import unittest
+
+
+class EphemeralMessageTestCase(unittest.HomeserverTestCase):
+
+ user_id = "@user:test"
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["enable_ephemeral_messages"] = True
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ def test_message_expiry_no_delay(self):
+ """Tests that sending a message sent with a m.self_destruct_after field set to the
+ past results in that event being deleted right away.
+ """
+ # Send a message in the room that has expired. From here, the reactor clock is
+ # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
+ # is at 0ms the code path is the same if the event's expiry timestamp is the
+ # current timestamp.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "hello",
+ EventContentFields.SELF_DESTRUCT_AFTER: 0,
+ },
+ )
+ event_id = res["event_id"]
+
+ # Check that we can't retrieve the content of the event.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertFalse(bool(event_content), event_content)
+
+ def test_message_expiry_delay(self):
+ """Tests that sending a message with a m.self_destruct_after field set to the
+ future results in that event not being deleted right away, but advancing the
+ clock to after that expiry timestamp causes the event to be deleted.
+ """
+ # Send a message in the room that'll expire in 1s.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "hello",
+ EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
+ },
+ )
+ event_id = res["event_id"]
+
+ # Check that we can retrieve the content of the event before it has expired.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertTrue(bool(event_content), event_content)
+
+ # Advance the clock to after the deletion.
+ self.reactor.advance(1)
+
+ # Check that we can't retrieve the content of the event anymore.
+ event_content = self.get_event(self.room_id, event_id)["content"]
+ self.assertFalse(bool(event_content), event_content)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
new file mode 100644
index 0000000000..913ea3c98e
--- /dev/null
+++ b/tests/rest/client/test_power_levels.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests.unittest import HomeserverTestCase
+
+
+class PowerLevelsTestCase(HomeserverTestCase):
+ """Tests that power levels are enforced in various situations"""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ # register a room admin, moderator and regular user
+ self.admin_user_id = self.register_user("admin", "pass")
+ self.admin_access_token = self.login("admin", "pass")
+ self.mod_user_id = self.register_user("mod", "pass")
+ self.mod_access_token = self.login("mod", "pass")
+ self.user_user_id = self.register_user("user", "pass")
+ self.user_access_token = self.login("user", "pass")
+
+ # Create a room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # Invite the other users
+ self.helper.invite(
+ room=self.room_id,
+ src=self.admin_user_id,
+ tok=self.admin_access_token,
+ targ=self.mod_user_id,
+ )
+ self.helper.invite(
+ room=self.room_id,
+ src=self.admin_user_id,
+ tok=self.admin_access_token,
+ targ=self.user_user_id,
+ )
+
+ # Make the other users join the room
+ self.helper.join(
+ room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token
+ )
+ self.helper.join(
+ room=self.room_id, user=self.user_user_id, tok=self.user_access_token
+ )
+
+ # Mod the mod
+ room_power_levels = self.helper.get_state(
+ self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with mod at PL50
+ room_power_levels["users"].update({self.mod_user_id: 50})
+
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ )
+
+ def test_non_admins_cannot_enable_room_encryption(self):
+ # have the mod try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_non_admins_cannot_send_server_acl(self):
+ # have the mod try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_non_admins_cannot_tombstone_room(self):
+ # Create another room that will serve as our "upgraded room"
+ self.upgraded_room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # have the mod try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.mod_access_token,
+ expect_code=403, # expect failure
+ )
+
+ # have the user try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.user_access_token,
+ expect_code=403, # expect failure
+ )
+
+ def test_admins_can_enable_room_encryption(self):
+ # have the admin try to enable room encryption
+ self.helper.send_state(
+ self.room_id,
+ "m.room.encryption",
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
+
+ def test_admins_can_send_server_acl(self):
+ # have the admin try to send a server ACL
+ self.helper.send_state(
+ self.room_id,
+ "m.room.server_acl",
+ {
+ "allow": ["*"],
+ "allow_ip_literals": False,
+ "deny": ["*.evil.com", "evil.com"],
+ },
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
+
+ def test_admins_can_tombstone_room(self):
+ # Create another room that will serve as our "upgraded room"
+ self.upgraded_room_id = self.helper.create_room_as(
+ self.admin_user_id, tok=self.admin_access_token
+ )
+
+ # have the admin try to send a tombstone event
+ self.helper.send_state(
+ self.room_id,
+ "m.room.tombstone",
+ {
+ "body": "This room has been replaced",
+ "replacement_room": self.upgraded_room_id,
+ },
+ tok=self.admin_access_token,
+ expect_code=200, # expect success
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
new file mode 100644
index 0000000000..95475bb651
--- /dev/null
+++ b/tests/rest/client/test_retention.py
@@ -0,0 +1,293 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.api.constants import EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.visibility import filter_events_for_client
+
+from tests import unittest
+
+one_hour_ms = 3600000
+one_day_ms = one_hour_ms * 24
+
+
+class RetentionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["retention"] = {
+ "enabled": True,
+ "default_policy": {
+ "min_lifetime": one_day_ms,
+ "max_lifetime": one_day_ms * 3,
+ },
+ "allowed_lifetime_min": one_day_ms,
+ "allowed_lifetime_max": one_day_ms * 3,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_retention_state_event(self):
+ """Tests that the server configuration can limit the values a user can set to the
+ room's retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 4},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_hour_ms},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": lifetime},
+ tok=self.token,
+ )
+
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_without_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by the server's configuration's default retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention_event_purged(room_id, one_day_ms * 2)
+
+ def test_visibility(self):
+ """Tests that synapse.visibility.filter_events_for_client correctly filters out
+ outdated events
+ """
+ store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ events = []
+
+ # Send a first event, which should be filtered out at the end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ # Get the event from the store so that we end up with a FrozenEvent that we can
+ # give to filter_events_for_client. We need to do this now because the event won't
+ # be in the database anymore after it has expired.
+ events.append(self.get_success(store.get_event(resp.get("event_id"))))
+
+ # Advance the time by 2 days. We're using the default retention policy, therefore
+ # after this the first event will still be valid.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Send another event, which shouldn't get filtered out.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ events.append(self.get_success(store.get_event(valid_event_id)))
+
+ # Advance the time by anothe 2 days. After this, the first event should be
+ # outdated but not the second one.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Run filter_events_for_client with our list of FrozenEvents.
+ filtered_events = self.get_success(
+ filter_events_for_client(storage, self.user_id, events)
+ )
+
+ # We should only get one event back.
+ self.assertEqual(len(filtered_events), 1, filtered_events)
+ # That event should be the second, not outdated event.
+ self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
+
+ def _test_retention_event_purged(self, room_id, increment):
+ # Get the create event to, later, check that we can still access it.
+ message_handler = self.hs.get_message_handler()
+ create_event = self.get_success(
+ message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ )
+
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ expired_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, expired_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time.
+ self.reactor.advance(increment / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ # Advance the time again. Now our first event should have expired but our second
+ # one should still be kept.
+ self.reactor.advance(increment / 1000)
+
+ # Check that the event has been purged from the database.
+ self.get_event(room_id, expired_event_id, expected_code=404)
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ valid_event = self.get_event(room_id, valid_event_id)
+ self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
+
+ # Check that we can still access state events that were sent before the event that
+ # has been purged.
+ self.get_event(room_id, create_event.event_id)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
+
+
+class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["retention"] = {
+ "enabled": True,
+ }
+
+ mock_federation_client = Mock(spec=["backfill"])
+
+ self.hs = self.setup_test_homeserver(
+ config=config, federation_client=mock_federation_client,
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_no_default_policy(self):
+ """Tests that an event doesn't get expired if there is neither a default retention
+ policy nor a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention(room_id)
+
+ def test_state_policy(self):
+ """Tests that an event gets correctly expired if there is no default retention
+ policy but there's a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the maximum lifetime to 35 days so that the first event gets expired but not
+ # the second one.
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 35},
+ tok=self.token,
+ )
+
+ self._test_retention(room_id, expected_code_for_first_event=404)
+
+ def _test_retention(self, room_id, expected_code_for_first_event=200):
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ first_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, first_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time by a month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ second_event_id = resp.get("event_id")
+
+ # Advance the time by another month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Check if the event has been purged from the database.
+ first_event = self.get_event(
+ room_id, first_event_id, expected_code=expected_code_for_first_event
+ )
+
+ if expected_code_for_first_event == 200:
+ self.assertEqual(
+ first_event.get("content", {}).get("body"), "1", first_event
+ )
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ second_event = self.get_event(room_id, second_event_id)
+ self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index a3d7e3c046..171632e195 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -2,7 +2,7 @@ from mock import Mock, call
from twisted.internet import defer, reactor
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.util import Clock
@@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test():
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
- self.assertIs(LoggingContext.current_context(), c1)
+ self.assertIs(current_context(), c1)
self.assertEqual(res, "yay")
# run the test twice in parallel
d = defer.gatherResults([test(), test()])
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
yield d
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
@defer.inlineCallbacks
def test_does_not_cache_exceptions(self):
@@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_does_not_cache_failures(self):
@@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute(self.mock_key, cb)
except Exception as e:
self.assertEqual(e.args[0], "boo")
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
self.assertEqual(res, self.mock_http_response)
- self.assertIs(LoggingContext.current_context(), test_context)
+ self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self):
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f340b7e851..f75520877f 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -15,7 +15,7 @@
""" Tests REST events for /events paths."""
-from mock import Mock, NonCallableMock
+from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room
@@ -40,17 +40,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config["enable_registration"] = True
config["auto_join_rooms"] = []
- hs = self.setup_test_homeserver(
- config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
+ hs = self.setup_test_homeserver(config=config)
hs.get_handlers().federation_handler = Mock()
return hs
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor, clock, hs):
# register an account
self.user_id = self.register_user("sid1", "pass")
@@ -134,3 +130,30 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# someone else set topic, expect 6 (join,send,topic,join,send,topic)
pass
+
+
+class GetEventsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ events.register_servlets,
+ room.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, hs, reactor, clock):
+
+ # register an account
+ self.user_id = self.register_user("sid1", "pass")
+ self.token = self.login(self.user_id, "pass")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ def test_get_event_via_events(self):
+ resp = self.helper.send(self.room_id, tok=self.token)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "GET", "/events/" + event_id, access_token=self.token,
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index eae5411325..9033f09fd2 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,7 +1,13 @@
import json
+import time
+import urllib.parse
+
+from mock import Mock
+
+import jwt
import synapse.rest.admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
@@ -17,12 +23,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
def make_homeserver(self, reactor, clock):
-
self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
@@ -31,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
+ @override_config(
+ {
+ "rc_login": {
+ "address": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the account login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "account": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_address(self):
- self.hs.config.rc_login_address.burst_count = 5
- self.hs.config.rc_login_address.per_second = 0.17
-
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
for i in range(0, 6):
@@ -73,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ "account": {"per_second": 0.17, "burst_count": 5},
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account(self):
- self.hs.config.rc_login_account.burst_count = 5
- self.hs.config.rc_login_account.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -112,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config(
+ {
+ "rc_login": {
+ # Prevent the address login ratelimiter from raising first
+ #
+ # This is normally covered by the default test homeserver config
+ # which sets these values to 10000, but as we're overriding the entire
+ # rc_login dict here, we need to set this manually as well
+ "address": {"per_second": 10000, "burst_count": 10000},
+ "failed_attempts": {"per_second": 0.17, "burst_count": 5},
+ }
+ }
+ )
def test_POST_ratelimiting_per_account_failed_attempts(self):
- self.hs.config.rc_login_failed_attempts.burst_count = 5
- self.hs.config.rc_login_failed_attempts.per_second = 0.17
-
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -252,3 +288,370 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard logout this session
+ request, channel = self.make_request(
+ b"POST", "/logout", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard log out all of the user's sessions
+ request, channel = self.make_request(
+ b"POST", "/logout/all", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+
+class CASTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.base_url = "https://matrix.goodserver.com/"
+ self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
+
+ config = self.default_config()
+ config["cas_config"] = {
+ "enabled": True,
+ "server_url": "https://fake.test",
+ "service_url": "https://matrix.goodserver.com:8448",
+ }
+
+ cas_user_id = "username"
+ self.user_id = "@%s:test" % cas_user_id
+
+ async def get_raw(uri, args):
+ """Return an example response payload from a call to the `/proxyValidate`
+ endpoint of a CAS server, copied from
+ https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
+
+ This needs to be returned by an async function (as opposed to set as the
+ mock's return value) because the corresponding Synapse code awaits on it.
+ """
+ return (
+ """
+ <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
+ <cas:authenticationSuccess>
+ <cas:user>%s</cas:user>
+ <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
+ <cas:proxies>
+ <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
+ <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
+ </cas:proxies>
+ </cas:authenticationSuccess>
+ </cas:serviceResponse>
+ """
+ % cas_user_id
+ )
+
+ mocked_http_client = Mock(spec=["get_raw"])
+ mocked_http_client.get_raw.side_effect = get_raw
+
+ self.hs = self.setup_test_homeserver(
+ config=config, proxied_http_client=mocked_http_client,
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.deactivate_account_handler = hs.get_deactivate_account_handler()
+
+ def test_cas_redirect_confirm(self):
+ """Tests that the SSO login flow serves a confirmation page before redirecting a
+ user to the redirect URL.
+ """
+ base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
+ redirect_url = "https://dodgy-site.com/"
+
+ url_parts = list(urllib.parse.urlparse(base_url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"redirectUrl": redirect_url})
+ query.update({"ticket": "ticket"})
+ url_parts[4] = urllib.parse.urlencode(query)
+ cas_ticket_url = urllib.parse.urlunparse(url_parts)
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ # Test that the response is HTML.
+ self.assertEqual(channel.code, 200)
+ content_type_header_value = ""
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type_header_value = header[1].decode("utf8")
+
+ self.assertTrue(content_type_header_value.startswith("text/html"))
+
+ # Test that the body isn't empty.
+ self.assertTrue(len(channel.result["body"]) > 0)
+
+ # And that it contains our redirect link
+ self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
+
+ @override_config(
+ {
+ "sso": {
+ "client_whitelist": [
+ "https://legit-site.com/",
+ "https://other-site.com/",
+ ]
+ }
+ }
+ )
+ def test_cas_redirect_whitelisted(self):
+ """Tests that the SSO login flow serves a redirect to a whitelisted url
+ """
+ self._test_redirect("https://legit-site.com/")
+
+ @override_config({"public_baseurl": "https://example.com"})
+ def test_cas_redirect_login_fallback(self):
+ self._test_redirect("https://example.com/_matrix/static/client/login")
+
+ def _test_redirect(self, redirect_url):
+ """Tests that the SSO login flow serves a redirect for the given redirect URL."""
+ cas_ticket_url = (
+ "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
+ % (urllib.parse.quote(redirect_url))
+ )
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ self.assertEqual(channel.code, 302)
+ location_headers = channel.headers.getRawHeaders("Location")
+ self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
+
+ @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
+ def test_deactivated_user(self):
+ """Logging in as a deactivated account should error."""
+ redirect_url = "https://legit-site.com/"
+
+ # First login (to create the user).
+ self._test_redirect(redirect_url)
+
+ # Deactivate the account.
+ self.get_success(
+ self.deactivate_account_handler.deactivate_account(self.user_id, False)
+ )
+
+ # Request the CAS ticket.
+ cas_ticket_url = (
+ "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
+ % (urllib.parse.quote(redirect_url))
+ )
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ # Because the user is deactivated they are served an error template.
+ self.assertEqual(channel.code, 403)
+ self.assertIn(b"SSO account deactivated", channel.result["body"])
+
+
+class JWTTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ jwt_secret = "secret"
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_secret
+ self.hs.config.jwt_algorithm = "HS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_secret):
+ return jwt.encode(token, secret, "HS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid_registered(self):
+ self.register_user("kermit", "monkey")
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_valid_unregistered(self):
+ channel = self.jwt_login({"sub": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, "notsecret")
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_jwt_expired(self):
+ channel = self.jwt_login({"sub": "frog", "exp": 864000})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "JWT expired")
+
+ def test_login_jwt_not_before(self):
+ now = int(time.time())
+ channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_sub(self):
+ channel = self.jwt_login({"username": "root"})
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+ def test_login_no_token(self):
+ params = json.dumps({"type": "m.login.jwt"})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
+
+
+# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
+# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
+# signed by the private key.
+class JWTPubKeyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ ]
+
+ # This key's pubkey is used as the jwt_secret setting of synapse. Valid
+ # tokens are signed by this and validated using the pubkey. It is generated
+ # with `openssl genrsa 512` (not a secure way to generate real keys, but
+ # good enough for tests!)
+ jwt_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
+ "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
+ "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
+ "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
+ "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
+ "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
+ "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ # Generated with `openssl rsa -in foo.key -pubout`, with the the above
+ # private key placed in foo.key (jwt_privatekey).
+ jwt_pubkey = "\n".join(
+ [
+ "-----BEGIN PUBLIC KEY-----",
+ "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
+ "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
+ "-----END PUBLIC KEY-----",
+ ]
+ )
+
+ # This key is used to sign tokens that shouldn't be accepted by synapse.
+ # Generated just like jwt_privatekey.
+ bad_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
+ "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
+ "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
+ "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
+ "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
+ "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
+ "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt_algorithm = "RS256"
+ return self.hs
+
+ def jwt_encode(self, token, secret=jwt_privatekey):
+ return jwt.encode(token, secret, "RS256").decode("ascii")
+
+ def jwt_login(self, *args):
+ params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+ return channel
+
+ def test_login_jwt_valid(self):
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ def test_login_jwt_invalid_signature(self):
+ channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
+ self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.json_body["error"], "Invalid JWT")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 66c2b68707..0fdff79aa7 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -15,6 +15,8 @@
from mock import Mock
+from twisted.internet import defer
+
from synapse.rest.client.v1 import presence
from synapse.types import UserID
@@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
hs.presence_handler = Mock()
+ hs.presence_handler.set_state.return_value = defer.succeed(None)
return hs
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 140d8b3772..8df58b4a63 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase):
]
)
+ self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
+ self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
+ self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
+ self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
+ self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
+ Mock()
+ )
+
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
@@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
+ return defer.succeed(synapse.types.create_requester(myid))
hs.get_auth().get_user_by_req = _get_user_by_req
@@ -229,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@@ -301,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index fe741637f5..4886bbb401 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,14 +20,18 @@
import json
-from mock import Mock, NonCallableMock
+from mock import Mock
from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import Membership
-from synapse.rest.client.v1 import login, profile, room
+from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.handlers.pagination import PurgeStatus
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import account
+from synapse.types import JsonDict, RoomAlias
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -40,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
+ "red", http_client=None, federation_client=Mock(),
)
- self.ratelimiter = self.hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
self.hs.get_federation_handler = Mock(return_value=Mock())
@@ -484,6 +485,15 @@ class RoomsCreateTestCase(RoomBase):
self.render(request)
self.assertEquals(400, channel.code)
+ def test_post_room_invitees_invalid_mxid(self):
+ # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
+ # Note the trailing space in the MXID here!
+ request, channel = self.make_request(
+ "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code)
+
class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
@@ -802,6 +812,78 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+ def test_room_messages_purge(self):
+ store = self.hs.get_datastore()
+ pagination_handler = self.hs.get_pagination_handler()
+
+ # Send a first message in the room, which will be removed by the purge.
+ first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
+ first_token = self.get_success(
+ store.get_topological_token_for_event(first_event_id)
+ )
+
+ # Send a second message in the room, which won't be removed, and which we'll
+ # use as the marker to purge events before.
+ second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
+ second_token = self.get_success(
+ store.get_topological_token_for_event(second_event_id)
+ )
+
+ # Send a third event in the room to ensure we don't fall under any edge case
+ # due to our marker being the latest forward extremity in the room.
+ self.helper.send(self.room_id, "message 3")
+
+ # Check that we get the first and second message when querying /messages.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+ # Purge every event before the second event.
+ purge_id = random_string(16)
+ pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+ self.get_success(
+ pagination_handler._purge_history(
+ purge_id=purge_id,
+ room_id=self.room_id,
+ token=second_token,
+ delete_local_events=True,
+ )
+ )
+
+ # Check that we only get the second message through /message now that the first
+ # has been purged.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
+
+ # Check that we get no event, but also no error, when querying /messages with
+ # the token that was pointing at the first event, because we don't have it
+ # anymore.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+ % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -998,3 +1080,899 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
+
+
+class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
+ """Tests that clients can add a "reason" field to membership events and
+ that they get correctly added to the generated events and propagated.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
+
+ def test_join_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_leave_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_kick_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_ban_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_unban_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_invite_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_reject_invite_reason(self):
+ self.helper.invite(
+ self.room_id,
+ src=self.creator,
+ targ=self.second_user_id,
+ tok=self.creator_tok,
+ )
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def _check_for_reason(self, reason):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
+ self.room_id, self.second_user_id
+ ),
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ event_content = channel.json_body
+
+ self.assertEqual(event_content.get("reason"), reason, channel.result)
+
+
+class LabelsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ # Filter that should only catch messages with the label "#fun".
+ FILTER_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ # Filter that should only catch messages without the label "#fun".
+ FILTER_NOT_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ # Filter that should only catch messages with the label "#work" but without the label
+ # "#notfun".
+ FILTER_LABELS_NOT_LABELS = {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_context_filter_labels(self):
+ """Test that we can filter by a label on a /context request."""
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 1, [event["content"] for event in events_before]
+ )
+ self.assertEqual(
+ events_before[0]["content"]["body"], "with right label", events_before[0]
+ )
+
+ events_after = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_after), 1, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with right label", events_after[0]
+ )
+
+ def test_context_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /context request."""
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 1, [event["content"] for event in events_before]
+ )
+ self.assertEqual(
+ events_before[0]["content"]["body"], "without label", events_before[0]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(
+ len(events_after), 2, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with wrong label", events_after[0]
+ )
+ self.assertEqual(
+ events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
+ )
+
+ def test_context_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /context request.
+ """
+ event_id = self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/context/%s?filter=%s"
+ % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(
+ len(events_before), 0, [event["content"] for event in events_before]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(
+ len(events_after), 1, [event["content"] for event in events_after]
+ )
+ self.assertEqual(
+ events_after[0]["content"]["body"], "with wrong label", events_after[0]
+ )
+
+ def test_messages_filter_labels(self):
+ """Test that we can filter by a label on a /messages request."""
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_messages_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /messages request."""
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 4, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "without label", events[1])
+ self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2])
+ self.assertEqual(
+ events[3]["content"]["body"], "with two wrong labels", events[3]
+ )
+
+ def test_messages_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /messages request.
+ """
+ self._send_labelled_messages_in_room()
+
+ token = "s0_0_0_0_0_0_0_0_0"
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
+ % (
+ self.room_id,
+ self.tok,
+ token,
+ json.dumps(self.FILTER_LABELS_NOT_LABELS),
+ ),
+ )
+ self.render(request)
+
+ events = channel.json_body["chunk"]
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def test_search_filter_labels(self):
+ """Test that we can filter by a label on a /search request."""
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 2, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "with right label",
+ results[0]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[1]["result"]["content"]["body"],
+ "with right label",
+ results[1]["result"]["content"]["body"],
+ )
+
+ def test_search_filter_not_labels(self):
+ """Test that we can filter by the absence of a label on a /search request."""
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 4, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "without label",
+ results[0]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[1]["result"]["content"]["body"],
+ "without label",
+ results[1]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[2]["result"]["content"]["body"],
+ "with wrong label",
+ results[2]["result"]["content"]["body"],
+ )
+ self.assertEqual(
+ results[3]["result"]["content"]["body"],
+ "with two wrong labels",
+ results[3]["result"]["content"]["body"],
+ )
+
+ def test_search_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label on a
+ /search request.
+ """
+ request_data = json.dumps(
+ {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
+ }
+ }
+ }
+ )
+
+ self._send_labelled_messages_in_room()
+
+ request, channel = self.make_request(
+ "POST", "/search?access_token=%s" % self.tok, request_data
+ )
+ self.render(request)
+
+ results = channel.json_body["search_categories"]["room_events"]["results"]
+
+ self.assertEqual(
+ len(results), 1, [result["result"]["content"] for result in results],
+ )
+ self.assertEqual(
+ results[0]["result"]["content"]["body"],
+ "with wrong label",
+ results[0]["result"]["content"]["body"],
+ )
+
+ def _send_labelled_messages_in_room(self):
+ """Sends several messages to a room with different labels (or without any) to test
+ filtering by label.
+ Returns:
+ The ID of the event to use if we're testing filtering on /context.
+ """
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=self.tok,
+ )
+
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=self.tok,
+ )
+ # Return this event's ID when we test filtering in /context requests.
+ event_id = res["event_id"]
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=self.tok,
+ )
+
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=self.tok,
+ )
+
+ return event_id
+
+
+class ContextTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+ self.room_id = self.helper.create_room_as(
+ self.user_id, tok=self.tok, is_public=False
+ )
+
+ self.other_user_id = self.register_user("user2", "password")
+ self.other_tok = self.login("user2", "password")
+
+ self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
+ self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
+
+ def test_erased_sender(self):
+ """Test that an erasure request results in the requester's events being hidden
+ from any new member of the room.
+ """
+
+ # Send a bunch of events in the room.
+
+ self.helper.send(self.room_id, "message 1", tok=self.tok)
+ self.helper.send(self.room_id, "message 2", tok=self.tok)
+ event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"]
+ self.helper.send(self.room_id, "message 4", tok=self.tok)
+ self.helper.send(self.room_id, "message 5", tok=self.tok)
+
+ # Check that we can still see the messages before the erasure request.
+
+ request, channel = self.make_request(
+ "GET",
+ '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
+ % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(len(events_before), 2, events_before)
+ self.assertEqual(
+ events_before[0].get("content", {}).get("body"),
+ "message 2",
+ events_before[0],
+ )
+ self.assertEqual(
+ events_before[1].get("content", {}).get("body"),
+ "message 1",
+ events_before[1],
+ )
+
+ self.assertEqual(
+ channel.json_body["event"].get("content", {}).get("body"),
+ "message 3",
+ channel.json_body["event"],
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(len(events_after), 2, events_after)
+ self.assertEqual(
+ events_after[0].get("content", {}).get("body"),
+ "message 4",
+ events_after[0],
+ )
+ self.assertEqual(
+ events_after[1].get("content", {}).get("body"),
+ "message 5",
+ events_after[1],
+ )
+
+ # Deactivate the first account and erase the user's data.
+
+ deactivate_account_handler = self.hs.get_deactivate_account_handler()
+ self.get_success(
+ deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+ )
+
+ # Invite another user in the room. This is needed because messages will be
+ # pruned only if the user wasn't a member of the room when the messages were
+ # sent.
+
+ invited_user_id = self.register_user("user3", "password")
+ invited_tok = self.login("user3", "password")
+
+ self.helper.invite(
+ self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok
+ )
+ self.helper.join(self.room_id, invited_user_id, tok=invited_tok)
+
+ # Check that a user that joined the room after the erasure request can't see
+ # the messages anymore.
+
+ request, channel = self.make_request(
+ "GET",
+ '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
+ % (self.room_id, event_id),
+ access_token=invited_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events_before = channel.json_body["events_before"]
+
+ self.assertEqual(len(events_before), 2, events_before)
+ self.assertDictEqual(events_before[0].get("content"), {}, events_before[0])
+ self.assertDictEqual(events_before[1].get("content"), {}, events_before[1])
+
+ self.assertDictEqual(
+ channel.json_body["event"].get("content"), {}, channel.json_body["event"]
+ )
+
+ events_after = channel.json_body["events_after"]
+
+ self.assertEqual(len(events_after), 2, events_after)
+ self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
+ self.assertEqual(events_after[1].get("content"), {}, events_after[1])
+
+
+class RoomAliasListTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ def test_no_aliases(self):
+ res = self._get_aliases(self.room_owner_tok)
+ self.assertEqual(res["aliases"], [])
+
+ def test_not_in_room(self):
+ self.register_user("user", "test")
+ user_tok = self.login("user", "test")
+ res = self._get_aliases(user_tok, expected_code=403)
+ self.assertEqual(res["errcode"], "M_FORBIDDEN")
+
+ def test_admin_user(self):
+ alias1 = self._random_alias()
+ self._set_alias_via_directory(alias1)
+
+ self.register_user("user", "test", admin=True)
+ user_tok = self.login("user", "test")
+
+ res = self._get_aliases(user_tok)
+ self.assertEqual(res["aliases"], [alias1])
+
+ def test_with_aliases(self):
+ alias1 = self._random_alias()
+ alias2 = self._random_alias()
+
+ self._set_alias_via_directory(alias1)
+ self._set_alias_via_directory(alias2)
+
+ res = self._get_aliases(self.room_owner_tok)
+ self.assertEqual(set(res["aliases"]), {alias1, alias2})
+
+ def test_peekable_room(self):
+ alias1 = self._random_alias()
+ self._set_alias_via_directory(alias1)
+
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": "world_readable"},
+ tok=self.room_owner_tok,
+ )
+
+ self.register_user("user", "test")
+ user_tok = self.login("user", "test")
+
+ res = self._get_aliases(user_tok)
+ self.assertEqual(res["aliases"], [alias1])
+
+ def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
+ % (self.room_id,),
+ access_token=access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ if expected_code == 200:
+ self.assertIsInstance(res["aliases"], list)
+ return res
+
+ def _random_alias(self) -> str:
+ return RoomAlias(random_string(5), self.hs.hostname).to_string()
+
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ url = "/_matrix/client/r0/directory/room/" + alias
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.room_owner_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+
+class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ self.alias = "#alias:test"
+ self._set_alias_via_directory(self.alias)
+
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ url = "/_matrix/client/r0/directory/room/" + alias
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.room_owner_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "GET",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "PUT",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ json.dumps(content),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def test_canonical_alias(self):
+ """Test a basic alias message."""
+ # There is no canonical alias to start with.
+ self._get_canonical_alias(expected_code=404)
+
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ # Now remove the alias.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alt_aliases(self):
+ """Test a canonical alias message with alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alias_alt_aliases(self):
+ """Test a canonical alias message with an alias and alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alias and alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_partial_modify(self):
+ """Test removing only the alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ def test_add_alias(self):
+ """Test removing only the alt_aliases."""
+ # Create an additional alias.
+ second_alias = "#second:test"
+ self._set_alias_via_directory(second_alias)
+
+ # Add the canonical alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Then add the second alias.
+ self._set_canonical_alias(
+ {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(
+ res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ def test_bad_data(self):
+ """Invalid data for alt_aliases should cause errors."""
+ self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
+
+ def test_bad_alias(self):
+ """An alias which does not point to the room raises a SynapseError."""
+ self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..18260bb90e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -16,7 +16,7 @@
"""Tests REST events for /rooms paths."""
-from mock import Mock, NonCallableMock
+from mock import Mock
from twisted.internet import defer
@@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
+ "red", http_client=None, federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.can_do_action.return_value = (True, 0)
-
hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False):
@@ -109,7 +103,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ events = self.get_success(
+ self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index cdded88b7f..22d734e763 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,9 +18,12 @@
import json
import time
+from typing import Any, Dict, Optional
import attr
+from twisted.web.resource import Resource
+
from synapse.api.constants import Membership
from tests.server import make_request, render
@@ -33,7 +39,7 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator, is_public=True, tok=None):
+ def create_room_as(self, room_creator=None, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -106,13 +112,22 @@ class RestHelper(object):
self.auth_user_id = temp_id
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
- path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = {"msgtype": "m.text", "body": body}
+
+ return self.send_event(
+ room_id, "m.room.message", content, txn_id, tok, expect_code
+ )
+
+ def send_event(
+ self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ ):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
if tok:
path = path + "?access_token=%s" % tok
@@ -128,7 +143,34 @@ class RestHelper(object):
return channel.json_body
- def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+ def _read_write_state(
+ self,
+ room_id: str,
+ event_type: str,
+ body: Optional[Dict[str, Any]],
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ method: str = "GET",
+ ) -> Dict:
+ """Read or write some state from a given room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ body: Body that is sent when making the request. The content of the state event.
+ If None, the request to the server will have an empty body
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+ method: "GET" or "PUT" for reading or writing state, respectively
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
room_id,
event_type,
@@ -137,9 +179,13 @@ class RestHelper(object):
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8")
- )
+ # Set request body if provided
+ content = b""
+ if body is not None:
+ content = json.dumps(body).encode("utf8")
+
+ request, channel = make_request(self.hs.get_reactor(), method, path, content)
+
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
@@ -148,3 +194,94 @@ class RestHelper(object):
)
return channel.json_body
+
+ def get_state(
+ self,
+ room_id: str,
+ event_type: str,
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ ):
+ """Gets some state from a room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
+ return self._read_write_state(
+ room_id, event_type, None, tok, expect_code, state_key, method="GET"
+ )
+
+ def send_state(
+ self,
+ room_id: str,
+ event_type: str,
+ body: Dict[str, Any],
+ tok: str,
+ expect_code: int = 200,
+ state_key: str = "",
+ ):
+ """Set some state in a room
+
+ Args:
+ room_id:
+ event_type: The type of state event
+ body: Body that is sent when making the request. The content of the state event.
+ tok: The access token to use
+ expect_code: The HTTP code to expect in the response
+ state_key:
+
+ Returns:
+ The response body from the server
+
+ Raises:
+ AssertionError: if expect_code doesn't match the HTTP code we received
+ """
+ return self._read_write_state(
+ room_id, event_type, body, tok, expect_code, state_key, method="PUT"
+ )
+
+ def upload_media(
+ self,
+ resource: Resource,
+ image_data: bytes,
+ tok: str,
+ filename: str = "test.png",
+ expect_code: int = 200,
+ ) -> dict:
+ """Upload a piece of test media to the media repo
+ Args:
+ resource: The resource that will handle the upload request
+ image_data: The image data to upload
+ tok: The user token to use during the upload
+ filename: The filename of the media to be uploaded
+ expect_code: The return code to expect from attempting to upload the media
+ """
+ image_length = len(image_data)
+ path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
+ request, channel = make_request(
+ self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
+ )
+ request.requestHeaders.addRawHeader(
+ b"Content-Length", str(image_length).encode("UTF-8")
+ )
+ request.render(resource)
+ self.hs.get_reactor().pump([100])
+
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 920de41de4..3ab611f618 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -23,8 +23,9 @@ from email.parser import Parser
import pkg_resources
import synapse.rest.admin
-from synapse.api.constants import LoginType
-from synapse.rest.client.v1 import login
+from synapse.api.constants import LoginType, Membership
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from tests import unittest
@@ -45,7 +46,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
return
@@ -178,6 +179,22 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the new password
self.attempt_wrong_password_login("kermit", new_password)
+ @unittest.override_config({"request_token_inhibit_3pid_errors": True})
+ def test_password_reset_bad_email_inhibit_error(self):
+ """Test that triggering a password reset with an email address that isn't bound
+ to an account doesn't leak the lack of binding for that address if configured
+ that way.
+ """
+ self.register_user("kermit", "monkey")
+ self.login("kermit", "monkey")
+
+ email = "test@example.com"
+
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret)
+
+ self.assertIsNotNone(session_id)
+
def _request_token(self, email, client_secret):
request, channel = self.make_request(
"POST",
@@ -244,16 +261,72 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
account.register_servlets,
+ room.register_servlets,
]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
+ self.hs = self.setup_test_homeserver()
+ return self.hs
def test_deactivate_account(self):
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
+ self.deactivate(user_id, tok)
+
+ store = self.hs.get_datastore()
+
+ # Check that the user has been marked as deactivated.
+ self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
+
+ # Check that this access token has been invalidated.
+ request, channel = self.make_request("GET", "account/whoami")
+ self.render(request)
+ self.assertEqual(request.code, 401)
+
+ @unittest.INFO
+ def test_pending_invites(self):
+ """Tests that deactivating a user rejects every pending invite for them."""
+ store = self.hs.get_datastore()
+
+ inviter_id = self.register_user("inviter", "test")
+ inviter_tok = self.login("inviter", "test")
+
+ invitee_id = self.register_user("invitee", "test")
+ invitee_tok = self.login("invitee", "test")
+
+ # Make @inviter:test invite @invitee:test in a new room.
+ room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
+ self.helper.invite(
+ room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
+ )
+
+ # Make sure the invite is here.
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
+ self.assertEqual(len(pending_invites), 1, pending_invites)
+ self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
+
+ # Deactivate @invitee:test.
+ self.deactivate(invitee_id, invitee_tok)
+
+ # Check that the invite isn't there anymore.
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
+ self.assertEqual(len(pending_invites), 0, pending_invites)
+
+ # Check that the membership of @invitee:test in the room is now "leave".
+ memberships = self.get_success(
+ store.get_rooms_for_local_user_where_membership_is(
+ invitee_id, [Membership.LEAVE]
+ )
+ )
+ self.assertEqual(len(memberships), 1, memberships)
+ self.assertEqual(memberships[0].room_id, room_id, memberships)
+
+ def deactivate(self, user_id, tok):
request_data = json.dumps(
{
"auth": {
@@ -270,12 +343,303 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(request.code, 200)
- store = self.hs.get_datastore()
- # Check that the user has been marked as deactivated.
- self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
+class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ account.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Email config.
+ self.email_attempts = []
+
+ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ self.email_attempts.append(msg)
+
+ config["email"] = {
+ "enable_notifs": False,
+ "template_dir": os.path.abspath(
+ pkg_resources.resource_filename("synapse", "res/templates")
+ ),
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ }
+ config["public_baseurl"] = "https://example.com"
+
+ self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.user_id = self.register_user("kermit", "test")
+ self.user_id_tok = self.login("kermit", "test")
+ self.email = "test@example.com"
+ self.url_3pid = b"account/3pid"
+
+ def test_add_email(self):
+ """Test adding an email to profile
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
- # Check that this access token has been invalidated.
- request, channel = self.make_request("GET", "account/whoami")
self.render(request)
- self.assertEqual(request.code, 401)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+
+ def test_add_email_if_disabled(self):
+ """Test adding email to profile when doing so is disallowed
+ """
+ self.hs.config.enable_3pid_changes = False
+
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_delete_email(self):
+ """Test deleting an email from profile
+ """
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address=self.email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/delete",
+ {"medium": "email", "address": self.email},
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_delete_email_if_disabled(self):
+ """Test deleting an email from profile when disallowed
+ """
+ self.hs.config.enable_3pid_changes = False
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address=self.email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/delete",
+ {"medium": "email", "address": self.email},
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+
+ def test_cant_add_email_without_clicking_link(self):
+ """Test that we do actually need to click the link in the email
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(self.email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+
+ # Attempt to add email without clicking the link
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def test_no_valid_token(self):
+ """Test that we do actually need to request a token and can't just
+ make a session up.
+ """
+ client_secret = "foobar"
+ session_id = "weasle"
+
+ # Attempt to add email without even requesting an email
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertFalse(channel.json_body["threepids"])
+
+ def _request_token(self, email, client_secret):
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ return channel.json_body["sid"]
+
+ def _validate_token(self, link):
+ # Remove the host
+ path = link.replace("https://example.com", "")
+
+ request, channel = self.make_request("GET", path, shorthand=False)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def _get_link_from_email(self):
+ assert self.email_attempts, "No emails have been sent"
+
+ raw_msg = self.email_attempts[-1].decode("UTF-8")
+ mail = Parser().parsestr(raw_msg)
+
+ text = None
+ for part in mail.walk():
+ if part.get_content_type() == "text/plain":
+ text = part.get_payload(decode=True).decode("UTF-8")
+ break
+
+ if not text:
+ self.fail("Could not find text portion of email to parse")
+
+ match = re.search(r"https://example.com\S+", text)
+ assert match, "Could not find link in email"
+
+ return match.group(0)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b9ef46e8fb..293ccfba2b 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,22 +12,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import List, Union
from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
-from synapse.rest.client.v2_alpha import auth, register
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import auth, devices, register
+from synapse.types import JsonDict
from tests import unittest
+from tests.server import FakeChannel
+
+
+class DummyRecaptchaChecker(UserInteractiveAuthChecker):
+ def __init__(self, hs):
+ super().__init__(hs)
+ self.recaptcha_attempts = []
+
+ def check_auth(self, authdict, clientip):
+ self.recaptcha_attempts.append((authdict, clientip))
+ return succeed(True)
+
+
+class DummyPasswordChecker(UserInteractiveAuthChecker):
+ def check_auth(self, authdict, clientip):
+ return succeed(authdict["identifier"]["user"])
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
auth.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
register.register_servlets,
]
hijack_auth = False
@@ -44,28 +63,55 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
+ self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
+ auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
- self.recaptcha_attempts = []
+ def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
+ """Make a register request."""
+ request, channel = self.make_request(
+ "POST", "register", body
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
- def _recaptcha(authdict, clientip):
- self.recaptcha_attempts.append((authdict, clientip))
- return succeed(True)
+ self.assertEqual(request.code, expected_response)
+ return channel
- auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+ def recaptcha(
+ self, session: str, expected_post_response: int, post_session: str = None
+ ) -> None:
+ """Get and respond to a fallback recaptcha. Returns the second request."""
+ if post_session is None:
+ post_session = session
- @unittest.INFO
- def test_fallback_captcha(self):
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ ) # type: SynapseRequest, FakeChannel
+ self.render(request)
+ self.assertEqual(request.code, 200)
request, channel = self.make_request(
"POST",
- "register",
- {"username": "user", "type": "m.login.password", "password": "bar"},
+ "auth/m.login.recaptcha/fallback/web?session="
+ + post_session
+ + "&g-recaptcha-response=a",
)
self.render(request)
+ self.assertEqual(request.code, expected_post_response)
+ # The recaptcha handler is called with the response given
+ attempts = self.recaptcha_checker.recaptcha_attempts
+ self.assertEqual(len(attempts), 1)
+ self.assertEqual(attempts[0][0]["response"], "a")
+
+ @unittest.INFO
+ def test_fallback_captcha(self):
+ """Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
- self.assertEqual(request.code, 401)
+ channel = self.register(
+ 401, {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+
# Grab the session
session = channel.json_body["session"]
# Assert our configured public key is being given
@@ -73,39 +119,198 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
)
- request, channel = self.make_request(
- "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ # Complete the recaptcha step.
+ self.recaptcha(session, 200)
+
+ # also complete the dummy auth
+ self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
+
+ # Now we should have fulfilled a complete auth flow, including
+ # the recaptcha fallback step, we can then send a
+ # request to the register API with the session in the authdict.
+ channel = self.register(200, {"auth": {"session": session}})
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ def test_complete_operation_unknown_session(self):
+ """
+ Attempting to mark an invalid session as complete should error.
+ """
+ # Make the initial request to register. (Later on a different password
+ # will be used.)
+ # Returns a 401 as per the spec
+ channel = self.register(
+ 401, {"username": "user", "type": "m.login.password", "password": "bar"}
)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ # Attempt to complete the recaptcha step with an unknown session.
+ # This results in an error.
+ self.recaptcha(session, 400, session + "unknown")
+
+
+class UIAuthTests(unittest.HomeserverTestCase):
+ servlets = [
+ auth.register_servlets,
+ devices.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ register.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ auth_handler = hs.get_auth_handler()
+ auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
+
+ self.user_pass = "pass"
+ self.user = self.register_user("test", self.user_pass)
+ self.user_tok = self.login("test", self.user_pass)
+
+ def get_device_ids(self) -> List[str]:
+ # Get the list of devices so one can be deleted.
+ request, channel = self.make_request(
+ "GET", "devices", access_token=self.user_tok,
+ ) # type: SynapseRequest, FakeChannel
self.render(request)
+
+ # Get the ID of the device.
self.assertEqual(request.code, 200)
+ return [d["device_id"] for d in channel.json_body["devices"]]
+ def delete_device(
+ self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ ) -> FakeChannel:
+ """Delete an individual device."""
request, channel = self.make_request(
- "POST",
- "auth/m.login.recaptcha/fallback/web?session="
- + session
- + "&g-recaptcha-response=a",
- )
+ "DELETE", "devices/" + device, body, access_token=self.user_tok
+ ) # type: SynapseRequest, FakeChannel
self.render(request)
- self.assertEqual(request.code, 200)
- # The recaptcha handler is called with the response given
- self.assertEqual(len(self.recaptcha_attempts), 1)
- self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+ # Ensure the response is sane.
+ self.assertEqual(request.code, expected_response)
- # also complete the dummy auth
+ return channel
+
+ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel:
+ """Delete 1 or more devices."""
+ # Note that this uses the delete_devices endpoint so that we can modify
+ # the payload half-way through some tests.
request, channel = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
- )
+ "POST", "delete_devices", body, access_token=self.user_tok,
+ ) # type: SynapseRequest, FakeChannel
self.render(request)
- # Now we should have fufilled a complete auth flow, including
- # the recaptcha fallback step, we can then send a
- # request to the register API with the session in the authdict.
- request, channel = self.make_request(
- "POST", "register", {"auth": {"session": session}}
+ # Ensure the response is sane.
+ self.assertEqual(request.code, expected_response)
+
+ return channel
+
+ def test_ui_auth(self):
+ """
+ Test user interactive authentication outside of registration.
+ """
+ device_id = self.get_device_ids()[0]
+
+ # Attempt to delete this device.
+ # Returns a 401 as per the spec
+ channel = self.delete_device(device_id, 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ device_id,
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
)
- self.render(request)
- self.assertEqual(channel.code, 200)
- # We're given a registered user.
- self.assertEqual(channel.json_body["user_id"], "@user:test")
+ def test_can_change_body(self):
+ """
+ The client dict can be modified during the user interactive authentication session.
+
+ Note that it is not spec compliant to modify the client dict during a
+ user interactive authentication session, but many clients currently do.
+
+ When Synapse is updated to be spec compliant, the call to re-use the
+ session ID should be rejected.
+ """
+ # Create a second login.
+ self.login("test", self.user_pass)
+
+ device_ids = self.get_device_ids()
+ self.assertEqual(len(device_ids), 2)
+
+ # Attempt to delete the first device.
+ # Returns a 401 as per the spec
+ channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow, but try to delete the
+ # second device.
+ self.delete_devices(
+ 200,
+ {
+ "devices": [device_ids[1]],
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
+ def test_cannot_change_uri(self):
+ """
+ The initial requested URI cannot be modified during the user interactive authentication session.
+ """
+ # Create a second login.
+ self.login("test", self.user_pass)
+
+ device_ids = self.get_device_ids()
+ self.assertEqual(len(device_ids), 2)
+
+ # Attempt to delete the first device.
+ # Returns a 401 as per the spec
+ channel = self.delete_device(device_ids[0], 401)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # Make another request providing the UI auth flow, but try to delete the
+ # second device. This results in an error.
+ self.delete_device(
+ device_ids[1],
+ 403,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index f42a8efbf4..e0e9e94fbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.render(request)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.result["code"], b"404")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
new file mode 100644
index 0000000000..c57072f50c
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, password_policy, register
+
+from tests import unittest
+
+
+class PasswordPolicyTestCase(unittest.HomeserverTestCase):
+ """Tests the password policy feature and its compliance with MSC2000.
+
+ When validating a password, Synapse does the necessary checks in this order:
+
+ 1. Password is long enough
+ 2. Password contains digit(s)
+ 3. Password contains symbol(s)
+ 4. Password contains uppercase letter(s)
+ 5. Password contains lowercase letter(s)
+
+ For each test below that checks whether a password triggers the right error code,
+ that test provides a password good enough to pass the previous tests, but not the
+ one it is currently testing (nor any test that comes afterward).
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ password_policy.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.register_url = "/_matrix/client/r0/register"
+ self.policy = {
+ "enabled": True,
+ "minimum_length": 10,
+ "require_digit": True,
+ "require_symbol": True,
+ "require_lowercase": True,
+ "require_uppercase": True,
+ }
+
+ config = self.default_config()
+ config["password_config"] = {
+ "policy": self.policy,
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_get_policy(self):
+ """Tests if the /password_policy endpoint returns the configured policy."""
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/password_policy"
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.minimum_length": 10,
+ "m.require_digit": True,
+ "m.require_symbol": True,
+ "m.require_lowercase": True,
+ "m.require_uppercase": True,
+ },
+ channel.result,
+ )
+
+ def test_password_too_short(self):
+ request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+ )
+
+ def test_password_no_digit(self):
+ request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+ )
+
+ def test_password_no_symbol(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+ )
+
+ def test_password_no_uppercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+ )
+
+ def test_password_no_lowercase(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+ )
+
+ def test_password_compliant(self):
+ request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request, channel = self.make_request("POST", self.register_url, request_data)
+ self.render(request)
+
+ # Getting a 401 here means the password has passed validation and the server has
+ # responded with a list of registration flows.
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_password_change(self):
+ """This doesn't test every possible use case, only that hitting /account/password
+ triggers the password validation code.
+ """
+ compliant_password = "C0mpl!antpassword"
+ not_compliant_password = "notcompliantpassword"
+
+ user_id = self.register_user("kermit", compliant_password)
+ tok = self.login("kermit", compliant_password)
+
+ request_data = json.dumps(
+ {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
+ )
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/account/password",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index ab4d7d70d0..7deaf5b24a 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -25,28 +25,26 @@ import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest
+from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
- servlets = [register.register_servlets]
-
- def make_homeserver(self, reactor, clock):
-
- self.url = b"/_matrix/client/r0/register"
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets,
+ ]
+ url = b"/_matrix/client/r0/register"
- self.hs = self.setup_test_homeserver()
- self.hs.config.enable_registration = True
- self.hs.config.registrations_require_3pid = []
- self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.allow_guest_access = True
-
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+ config["allow_guest_access"] = True
+ return config
def test_POST_appservice_registration_valid(self):
user_id = "@as_user_kermit:test"
@@ -149,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+ @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting_guest(self):
- self.hs.config.rc_registration.burst_count = 5
- self.hs.config.rc_registration.per_second = 0.17
-
for i in range(0, 6):
url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}")
@@ -171,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self):
- self.hs.config.rc_registration.burst_count = 5
- self.hs.config.rc_registration.per_second = 0.17
-
for i in range(0, 6):
params = {
"username": "kermit" + str(i),
@@ -199,6 +193,115 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_advertised_flows(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we only expect the dummy flow
+ self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "enable_registration_captcha": True,
+ "user_consent": {
+ "version": "1",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ "account_threepid_delegates": {
+ "email": "https://id_server",
+ "msisdn": "https://id_server",
+ },
+ }
+ )
+ def test_advertised_flows_captcha_and_terms_and_3pids(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ self.assertCountEqual(
+ [
+ ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+ [
+ "m.login.recaptcha",
+ "m.login.terms",
+ "m.login.msisdn",
+ "m.login.email.identity",
+ ],
+ ],
+ (f["stages"] for f in flows),
+ )
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "registrations_require_3pid": ["email"],
+ "disable_msisdn_registration": True,
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_advertised_flows_no_msisdn_email_required(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we expect all four combinations of 3pid
+ self.assertCountEqual(
+ [["m.login.email.identity"]], (f["stages"] for f in flows)
+ )
+
+ @unittest.override_config(
+ {
+ "request_token_inhibit_3pid_errors": True,
+ "public_baseurl": "https://test_server",
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_request_token_existing_email_inhibit_error(self):
+ """Test that requesting a token via this endpoint doesn't leak existing
+ associations if configured that way.
+ """
+ user_id = self.register_user("kermit", "monkey")
+ self.login("kermit", "monkey")
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.hs.get_datastore().user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ self.assertIsNotNone(channel.json_body.get("sid"))
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
@@ -207,6 +310,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
sync.register_servlets,
+ logout.register_servlets,
account_validity.register_servlets,
]
@@ -299,6 +403,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
+ def test_logging_out_expired_user(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Try to log the user out
+ request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Log the user in again (allowed for expired accounts)
+ tok = self.login("kermit", "monkey")
+
+ # Try to log out all of the user's sessions
+ request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -330,9 +467,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(*args, **kwargs):
+ async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
- return
config["email"] = {
"enable_notifs": True,
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094bd..fa3a3ec1bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
+import json
import synapse.rest.admin
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
@@ -26,14 +27,12 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
- servlets = [sync.register_servlets]
-
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
- )
- return hs
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
@@ -41,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
- set(
- [
- "next_batch",
- "rooms",
- "presence",
- "account_data",
- "to_device",
- "device_lists",
- ]
- ).issubset(set(channel.json_body.keys()))
+ {
+ "next_batch",
+ "rooms",
+ "presence",
+ "account_data",
+ "to_device",
+ "device_lists",
+ }.issubset(set(channel.json_body.keys()))
)
def test_sync_presence_disabled(self):
@@ -64,11 +61,149 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
- set(
- ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
- ).issubset(set(channel.json_body.keys()))
+ {
+ "next_batch",
+ "rooms",
+ "account_data",
+ "to_device",
+ "device_lists",
+ }.issubset(set(channel.json_body.keys()))
+ )
+
+
+class SyncFilterTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def test_sync_filter_labels(self):
+ """Test that we can filter by a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 2, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+ def test_sync_filter_not_labels(self):
+ """Test that we can filter by the absence of a label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.not_labels": ["#fun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 3, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+ self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+ self.assertEqual(
+ events[2]["content"]["body"], "with two wrong labels", events[2]
+ )
+
+ def test_sync_filter_labels_not_labels(self):
+ """Test that we can filter by both a label and the absence of another label."""
+ sync_filter = json.dumps(
+ {
+ "room": {
+ "timeline": {
+ "types": [EventTypes.Message],
+ "org.matrix.labels": ["#work"],
+ "org.matrix.not_labels": ["#notfun"],
+ }
+ }
+ }
+ )
+
+ events = self._test_sync_filter_labels(sync_filter)
+
+ self.assertEqual(len(events), 1, [event["content"] for event in events])
+ self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+ def _test_sync_filter_labels(self, sync_filter):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "without label"},
+ tok=tok,
)
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with wrong label",
+ EventContentFields.LABELS: ["#work"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with two wrong labels",
+ EventContentFields.LABELS: ["#work", "#notfun"],
+ },
+ tok=tok,
+ )
+
+ self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "with right label",
+ EventContentFields.LABELS: ["#fun"],
+ },
+ tok=tok,
+ )
+
+ request, channel = self.make_request(
+ "GET", "/sync?filter=%s" % sync_filter, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+
class SyncTypingTests(unittest.HomeserverTestCase):
diff --git a/tests/rest/key/__init__.py b/tests/rest/key/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/rest/key/__init__.py
diff --git a/tests/rest/key/v2/__init__.py b/tests/rest/key/v2/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/rest/key/v2/__init__.py
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
new file mode 100644
index 0000000000..99eb477149
--- /dev/null
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import urllib.parse
+from io import BytesIO, StringIO
+
+from mock import Mock
+
+import signedjson.key
+from canonicaljson import encode_canonical_json
+from nacl.signing import SigningKey
+from signedjson.sign import sign_json
+
+from twisted.web.resource import NoResource
+
+from synapse.crypto.keyring import PerspectivesKeyFetcher
+from synapse.http.site import SynapseRequest
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.storage.keys import FetchKeyResult
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+from tests.server import FakeChannel, wait_until_result
+from tests.utils import default_config
+
+
+class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ return self.setup_test_homeserver(http_client=self.http_client)
+
+ def create_test_json_resource(self):
+ return create_resource_tree(
+ {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+ )
+
+ def expect_outgoing_key_request(
+ self, server_name: str, signing_key: SigningKey
+ ) -> None:
+ """
+ Tell the mock http client to expect an outgoing GET request for the given key
+ """
+
+ def get_json(destination, path, ignore_backoff=False, **kwargs):
+ self.assertTrue(ignore_backoff)
+ self.assertEqual(destination, server_name)
+ key_id = "%s:%s" % (signing_key.alg, signing_key.version)
+ self.assertEqual(
+ path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
+ )
+
+ response = {
+ "server_name": server_name,
+ "old_verify_keys": {},
+ "valid_until_ts": 200 * 1000,
+ "verify_keys": {
+ key_id: {
+ "key": signedjson.key.encode_verify_key_base64(
+ signing_key.verify_key
+ )
+ }
+ },
+ }
+ sign_json(response, server_name, signing_key)
+ return response
+
+ self.http_client.get_json.side_effect = get_json
+
+
+class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
+ def make_notary_request(self, server_name: str, key_id: str) -> dict:
+ """Send a GET request to the test server requesting the given key.
+
+ Checks that the response is a 200 and returns the decoded json body.
+ """
+ channel = FakeChannel(self.site, self.reactor)
+ req = SynapseRequest(channel)
+ req.content = BytesIO(b"")
+ req.requestReceived(
+ b"GET",
+ b"/_matrix/key/v2/query/%s/%s"
+ % (server_name.encode("utf-8"), key_id.encode("utf-8")),
+ b"1.1",
+ )
+ wait_until_result(self.reactor, req)
+ self.assertEqual(channel.code, 200)
+ resp = channel.json_body
+ return resp
+
+ def test_get_key(self):
+ """Fetch a remote key"""
+ SERVER_NAME = "remote.server"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ self.expect_outgoing_key_request(SERVER_NAME, testkey)
+
+ resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1")
+ keys = resp["server_keys"]
+ self.assertEqual(len(keys), 1)
+
+ self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+ self.assertEqual(len(keys[0]["verify_keys"]), 1)
+
+ # it should be signed by both the origin server and the notary
+ self.assertIn(SERVER_NAME, keys[0]["signatures"])
+ self.assertIn(self.hs.hostname, keys[0]["signatures"])
+
+ def test_get_own_key(self):
+ """Fetch our own key"""
+ testkey = signedjson.key.generate_signing_key("ver1")
+ self.expect_outgoing_key_request(self.hs.hostname, testkey)
+
+ resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1")
+ keys = resp["server_keys"]
+ self.assertEqual(len(keys), 1)
+
+ # it should be signed by both itself, and the notary signing key
+ sigs = keys[0]["signatures"]
+ self.assertEqual(len(sigs), 1)
+ self.assertIn(self.hs.hostname, sigs)
+ oursigs = sigs[self.hs.hostname]
+ self.assertEqual(len(oursigs), 2)
+
+ # the requested key should be present in the verify_keys section
+ self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+
+
+class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
+ """End-to-end tests of the perspectives fetch case
+
+ The idea here is to actually wire up a PerspectivesKeyFetcher to the notary
+ endpoint, to check that the two implementations are compatible.
+ """
+
+ def default_config(self):
+ config = super().default_config()
+
+ # replace the signing key with our own
+ self.hs_signing_key = signedjson.key.generate_signing_key("kssk")
+ strm = StringIO()
+ signedjson.key.write_signing_keys(strm, [self.hs_signing_key])
+ config["signing_key"] = strm.getvalue()
+
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # make a second homeserver, configured to use the first one as a key notary
+ self.http_client2 = Mock()
+ config = default_config(name="keyclient")
+ config["trusted_key_servers"] = [
+ {
+ "server_name": self.hs.hostname,
+ "verify_keys": {
+ "ed25519:%s"
+ % (
+ self.hs_signing_key.version,
+ ): signedjson.key.encode_verify_key_base64(
+ self.hs_signing_key.verify_key
+ )
+ },
+ }
+ ]
+ self.hs2 = self.setup_test_homeserver(
+ http_client=self.http_client2, config=config
+ )
+
+ # wire up outbound POST /key/v2/query requests from hs2 so that they
+ # will be forwarded to hs1
+ def post_json(destination, path, data):
+ self.assertEqual(destination, self.hs.hostname)
+ self.assertEqual(
+ path, "/_matrix/key/v2/query",
+ )
+
+ channel = FakeChannel(self.site, self.reactor)
+ req = SynapseRequest(channel)
+ req.content = BytesIO(encode_canonical_json(data))
+
+ req.requestReceived(
+ b"POST", path.encode("utf-8"), b"1.1",
+ )
+ wait_until_result(self.reactor, req)
+ self.assertEqual(channel.code, 200)
+ resp = channel.json_body
+ return resp
+
+ self.http_client2.post_json.side_effect = post_json
+
+ def test_get_key(self):
+ """Fetch a key belonging to a random server"""
+ # make up a key to be fetched.
+ testkey = signedjson.key.generate_signing_key("abc")
+
+ # we expect hs1 to make a regular key request to the target server
+ self.expect_outgoing_key_request("targetserver", testkey)
+ keyid = "ed25519:%s" % (testkey.version,)
+
+ fetcher = PerspectivesKeyFetcher(self.hs2)
+ d = fetcher.get_keys({"targetserver": {keyid: 1000}})
+ res = self.get_success(d)
+ self.assertIn("targetserver", res)
+ keyres = res["targetserver"][keyid]
+ assert isinstance(keyres, FetchKeyResult)
+ self.assertEqual(
+ signedjson.key.encode_verify_key_base64(keyres.verify_key),
+ signedjson.key.encode_verify_key_base64(testkey.verify_key),
+ )
+
+ def test_get_notary_key(self):
+ """Fetch a key belonging to the notary server"""
+ # make up a key to be fetched. We randomise the keyid to try to get it to
+ # appear before the key server signing key sometimes (otherwise we bail out
+ # before fetching its signature)
+ testkey = signedjson.key.generate_signing_key(random_string(5))
+
+ # we expect hs1 to make a regular key request to itself
+ self.expect_outgoing_key_request(self.hs.hostname, testkey)
+ keyid = "ed25519:%s" % (testkey.version,)
+
+ fetcher = PerspectivesKeyFetcher(self.hs2)
+ d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+ res = self.get_success(d)
+ self.assertIn(self.hs.hostname, res)
+ keyres = res[self.hs.hostname][keyid]
+ assert isinstance(keyres, FetchKeyResult)
+ self.assertEqual(
+ signedjson.key.encode_verify_key_base64(keyres.verify_key),
+ signedjson.key.encode_verify_key_base64(testkey.verify_key),
+ )
+
+ def test_get_notary_keyserver_key(self):
+ """Fetch the notary's keyserver key"""
+ # we expect hs1 to make a regular key request to itself
+ self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
+ keyid = "ed25519:%s" % (self.hs_signing_key.version,)
+
+ fetcher = PerspectivesKeyFetcher(self.hs2)
+ d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}})
+ res = self.get_success(d)
+ self.assertIn(self.hs.hostname, res)
+ keyres = res[self.hs.hostname][keyid]
+ assert isinstance(keyres, FetchKeyResult)
+ self.assertEqual(
+ signedjson.key.encode_verify_key_base64(keyres.verify_key),
+ signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key),
+ )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index bc662b61db..1ca648ef2b 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -18,10 +18,16 @@ import os
import shutil
import tempfile
from binascii import unhexlify
+from io import BytesIO
+from typing import Optional
from mock import Mock
from six.moves.urllib import parse
+import attr
+import PIL.Image as Image
+from parameterized import parameterized_class
+
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -94,6 +100,68 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
+@attr.s
+class _TestImage:
+ """An image for testing thumbnailing with the expected results
+
+ Attributes:
+ data: The raw image to thumbnail
+ content_type: The type of the image as a content type, e.g. "image/png"
+ extension: The extension associated with the format, e.g. ".png"
+ expected_cropped: The expected bytes from cropped thumbnailing, or None if
+ test should just check for success.
+ expected_scaled: The expected bytes from scaled thumbnailing, or None if
+ test should just check for a valid image returned.
+ """
+
+ data = attr.ib(type=bytes)
+ content_type = attr.ib(type=bytes)
+ extension = attr.ib(type=bytes)
+ expected_cropped = attr.ib(type=Optional[bytes])
+ expected_scaled = attr.ib(type=Optional[bytes])
+
+
+@parameterized_class(
+ ("test_image",),
+ [
+ # smol png
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ ),
+ b"image/png",
+ b".png",
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+ b"000000737a7af40000001a49444154789cedc101010000008220"
+ b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+ b"44ae426082"
+ ),
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000d49444154789c636060606000000005"
+ b"0001a5f645400000000049454e44ae426082"
+ ),
+ ),
+ ),
+ # small lossless webp
+ (
+ _TestImage(
+ unhexlify(
+ b"524946461a000000574542505650384c0d0000002f0000001007"
+ b"1011118888fe0700"
+ ),
+ b"image/webp",
+ b".webp",
+ None,
+ None,
+ ),
+ ),
+ ],
+)
class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
@@ -149,19 +217,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
+ self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
- # smol png
- self.end_content = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
+ self.media_id = "example.com/12345"
def _req(self, content_disposition):
- request, channel = self.make_request(
- "GET", "example.com/12345", shorthand=False
- )
+ request, channel = self.make_request("GET", self.media_id, shorthand=False)
request.render(self.download_resource)
self.pump()
@@ -170,19 +232,19 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
- self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+ self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id
)
self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
headers = {
- b"Content-Length": [b"%d" % (len(self.end_content))],
- b"Content-Type": [b"image/png"],
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ b"Content-Type": [self.test_image.content_type],
}
if content_disposition:
headers[b"Content-Disposition"] = [content_disposition]
self.fetches[0][0].callback(
- (self.end_content, (len(self.end_content), headers))
+ (self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
@@ -195,12 +257,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
"""
- channel = self._req(b"inline; filename=out.png")
+ channel = self._req(b"inline; filename=out" + self.test_image.extension)
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [b"inline; filename=out" + self.test_image.extension],
)
def test_disposition_filenamestar_utf8escaped(self):
@@ -210,13 +275,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
response.
"""
filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
- channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+ channel = self._req(
+ b"inline; filename*=utf-8''" + filename + self.test_image.extension
+ )
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
self.assertEqual(
headers.getRawHeaders(b"Content-Disposition"),
- [b"inline; filename*=utf-8''" + filename + b".png"],
+ [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)
def test_disposition_none(self):
@@ -227,5 +296,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self._req(None)
headers = channel.headers
- self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+ )
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
+
+ def test_thumbnail_crop(self):
+ self._test_thumbnail("crop", self.test_image.expected_cropped)
+
+ def test_thumbnail_scale(self):
+ self._test_thumbnail("scale", self.test_image.expected_scaled)
+
+ def _test_thumbnail(self, method, expected_body):
+ params = "?width=32&height=32&method=" + method
+ request, channel = self.make_request(
+ "GET", self.media_id + params, shorthand=False
+ )
+ request.render(self.thumbnail_resource)
+ self.pump()
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.test_image.data))],
+ b"Content-Type": [self.test_image.content_type],
+ }
+ self.fetches[0][0].callback(
+ (self.test_image.data, (len(self.test_image.data), headers))
+ )
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 976652aee8..2826211f32 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -74,6 +74,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
config["url_preview_url_blacklist"] = []
+ config["url_preview_accept_language"] = [
+ "en-UK",
+ "en-US;q=0.9",
+ "fr;q=0.8",
+ "*;q=0.7",
+ ]
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
@@ -247,6 +253,41 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
+ def test_overlong_title(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+ end_content = (
+ b"<html><head>"
+ b"<title>" + b"x" * 2000 + b"</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body
+ # We should only see the `og:description` field, as `title` is too long and should be stripped out
+ self.assertCountEqual(["og:description"], res.keys())
+
def test_ipaddr(self):
"""
IP addresses can be previewed directly.
@@ -472,3 +513,52 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
+
+ def test_accept_language_config_option(self):
+ """
+ Accept-Language header is sent to the remote server
+ """
+ self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+
+ # Build and make a request to the server
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # Extract Synapse's tcp client
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+ # Build a fake remote server to reply with
+ server = AccumulatingProtocol()
+
+ # Connect the two together
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ # Tell Synapse that it has received some data from the remote server
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ # Move the reactor along until we get a response on our original channel
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check that the server received the Accept-Language header as part
+ # of the request from Synapse
+ self.assertIn(
+ (
+ b"Accept-Language: en-UK\r\n"
+ b"Accept-Language: en-US;q=0.9\r\n"
+ b"Accept-Language: fr;q=0.8\r\n"
+ b"Accept-Language: *;q=0.7"
+ ),
+ server.data,
+ )
diff --git a/tests/server.py b/tests/server.py
index e397ebe8fa..1644710aa0 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -20,6 +20,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote
from twisted.web.http_headers import Headers
+from twisted.web.server import Site
from synapse.http.site import SynapseRequest
from synapse.util import Clock
@@ -42,6 +43,7 @@ class FakeChannel(object):
wire).
"""
+ site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict))
_producer = None
@@ -161,7 +163,11 @@ def make_request(
path = path.encode("ascii")
# Decorate it to be the full path, if we're using shorthand
- if shorthand and not path.startswith(b"/_matrix"):
+ if (
+ shorthand
+ and not path.startswith(b"/_matrix")
+ and not path.startswith(b"/_synapse")
+ ):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
@@ -172,9 +178,9 @@ def make_request(
content = content.encode("utf8")
site = FakeSite()
- channel = FakeChannel(reactor)
+ channel = FakeChannel(site, reactor)
- req = request(site, channel)
+ req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
req.postpath = list(map(unquote, path[1:].split(b"/")))
@@ -298,41 +304,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
- d = _sth(cleanup_func, *args, **kwargs).result
+ server = _sth(cleanup_func, *args, **kwargs)
- if isinstance(d, Failure):
- d.raiseException()
+ database = server.config.database.get_single_database()
# Make the thread pool synchronous.
- clock = d.get_clock()
- pool = d.get_db_pool()
-
- def runWithConnection(func, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runWithConnection,
- func,
- *args,
- **kwargs
- )
-
- def runInteraction(interaction, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runInteraction,
- interaction,
- *args,
- **kwargs
- )
+ clock = server.get_clock()
+
+ for database in server.get_datastores().databases:
+ pool = database._db_pool
+
+ def runWithConnection(func, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs
+ )
+
+ def runInteraction(interaction, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ interaction,
+ *args,
+ **kwargs
+ )
- if pool:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
- return d
+
+ return server
def get_clock():
@@ -375,6 +382,7 @@ class FakeTransport(object):
disconnecting = False
disconnected = False
+ connected = True
buffer = attr.ib(default=b"")
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)
@@ -391,11 +399,25 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.connected = False
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -426,6 +448,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -470,6 +495,10 @@ class FakeTransport(object):
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
+
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
"""
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index cdf89e3383..99908edba3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -17,27 +17,43 @@ from mock import Mock
from twisted.internet import defer
-from synapse.api.constants import EventTypes, ServerNoticeMsgType
+from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
from tests import unittest
+from tests.unittest import override_config
+from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- hs_config = self.default_config("test")
- hs_config["server_notices"] = {
- "system_mxid_localpart": "server",
- "system_mxid_display_name": "test display name",
- "system_mxid_avatar_url": None,
- "room_name": "Server Notices",
- }
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "admin_contact": "mailto:user@test.com",
+ "limit_usage_by_mau": True,
+ "server_notices": {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ "room_name": "Server Notices",
+ },
+ }
+ )
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- hs = self.setup_test_homeserver(config=hs_config)
- return hs
+ return config
def prepare(self, reactor, clock, hs):
self.server_notices_sender = self.hs.get_server_notices_sender()
@@ -52,54 +68,41 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
- self._send_notice = self._rlsn._server_notices_manager.send_notice
- self._rlsn._server_notices_manager.send_notice = Mock()
- self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
-
+ self._rlsn._server_notices_manager.send_notice = Mock(
+ return_value=defer.succeed(Mock())
+ )
self._send_notice = self._rlsn._server_notices_manager.send_notice
- self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
- # self.server_notices_mxid = "@server:test"
- # self.server_notices_mxid_display_name = None
- # self.server_notices_mxid_avatar_url = None
- # self.server_notices_room_name = "Server Notices"
-
- self._rlsn._server_notices_manager.get_notice_room_for_user = Mock(
- returnValue=""
+ self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
+ return_value=defer.succeed("!something:localhost")
)
- self._rlsn._store.add_tag_to_room = Mock()
- self._rlsn._store.get_tags_for_room = Mock(return_value={})
- self.hs.config.admin_contact = "mailto:user@test.com"
-
- def test_maybe_send_server_notice_to_user_flag_off(self):
- """Tests cases where the flags indicate nothing to do"""
- # test hs disabled case
- self.hs.config.hs_disabled = True
+ self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+ self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
+ @override_config({"hs_disabled": True})
+ def test_maybe_send_server_notice_disabled_hs(self):
+ """If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
-
self._send_notice.assert_not_called()
- # Test when mau limiting disabled
- self.hs.config.hs_disabled = False
- self.hs.config.limit_usage_by_mau = False
- self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+ @override_config({"limit_usage_by_mau": False})
+ def test_maybe_send_server_notice_to_user_flag_off(self):
+ """If mau limiting is disabled, we should not send notices"""
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
-
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once()
@@ -109,7 +112,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
- side_effect=ResourceLimitError(403, "foo")
+ return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
mock_event = Mock(
@@ -118,6 +121,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
+
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@@ -126,20 +130,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
-
self._rlsn._auth.check_auth_blocking = Mock(
- side_effect=ResourceLimitError(403, "foo")
+ return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check contents, but 2 calls == set blocking event
- self.assertTrue(self._send_notice.call_count == 2)
+ self.assertEqual(self._send_notice.call_count, 2)
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -150,7 +153,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
@@ -158,8 +161,87 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
+ @override_config({"mau_limit_alerting": False})
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+ """
+ Test that when server is over MAU limit and alerting is suppressed, then
+ an alert message is not sent into the room
+ """
+ self._rlsn._auth.check_auth_blocking = Mock(
+ return_value=defer.succeed(None),
+ side_effect=ResourceLimitError(
+ 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
+ ),
+ )
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+
+ self.assertEqual(self._send_notice.call_count, 0)
+
+ @override_config({"mau_limit_alerting": False})
+ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+ """
+ Test that when a server is disabled, that MAU limit alerting is ignored.
+ """
+ self._rlsn._auth.check_auth_blocking = Mock(
+ return_value=defer.succeed(None),
+ side_effect=ResourceLimitError(
+ 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
+ ),
+ )
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+
+ # Would be better to check contents, but 2 calls == set blocking event
+ self.assertEqual(self._send_notice.call_count, 2)
+
+ @override_config({"mau_limit_alerting": False})
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+ """
+ When the room is already in a blocked state, test that when alerting
+ is suppressed that the room is returned to an unblocked state.
+ """
+ self._rlsn._auth.check_auth_blocking = Mock(
+ return_value=defer.succeed(None),
+ side_effect=ResourceLimitError(
+ 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
+ ),
+ )
+
+ self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
+ return_value=defer.succeed((True, []))
+ )
+
+ mock_event = Mock(
+ type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
+ )
+ self._rlsn._store.get_events = Mock(
+ return_value=defer.succeed({"123": mock_event})
+ )
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+
+ self._send_notice.assert_called_once()
+
class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def default_config(self):
+ c = super().default_config()
+ c["server_notices"] = {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": None,
+ "system_mxid_avatar_url": None,
+ "room_name": "Test Server Notice Room",
+ }
+ c["limit_usage_by_mau"] = True
+ c["max_mau_value"] = 5
+ c["admin_contact"] = "mailto:user@test.com"
+ return c
+
def prepare(self, reactor, clock, hs):
self.store = self.hs.get_datastore()
self.server_notices_sender = self.hs.get_server_notices_sender()
@@ -173,22 +255,14 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.hs_disabled = False
- self.hs.config.max_mau_value = 5
- self.hs.config.server_notices_mxid = "@server:test"
- self.hs.config.server_notices_mxid_display_name = None
- self.hs.config.server_notices_mxid_avatar_url = None
- self.hs.config.server_notices_room_name = "Test Server Notice Room"
-
self.user_id = "@user_id:test"
- self.hs.config.admin_contact = "mailto:user@test.com"
-
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
- self.store.user_last_seen_monthly_active = Mock(return_value=1000)
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(1000)
+ )
# Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -198,7 +272,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Now lets get the last load of messages in the service notice room and
# check that there is only one server notice
room_id = self.get_success(
- self.server_notices_manager.get_notice_room_for_user(self.user_id)
+ self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
)
token = self.get_success(self.event_source.get_current_token())
@@ -218,3 +292,86 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
count += 1
self.assertEqual(count, 1)
+
+ def test_no_invite_without_notice(self):
+ """Tests that a user doesn't get invited to a server notices room without a
+ server notice being sent.
+
+ The scenario for this test is a single user on a server where the MAU limit
+ hasn't been reached (since it's the only user and the limit is 5), so users
+ shouldn't receive a server notice.
+ """
+ self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ self.render(request)
+
+ invites = channel.json_body["rooms"]["invite"]
+ self.assertEqual(len(invites), 0, invites)
+
+ def test_invite_with_notice(self):
+ """Tests that, if the MAU limit is hit, the server notices user invites each user
+ to a room in which it has sent a notice.
+ """
+ user_id, tok, room_id = self._trigger_notice_and_join()
+
+ # Sync again to retrieve the events in the room, so we can check whether this
+ # room has a notice in it.
+ request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+ self.render(request)
+
+ # Scan the events in the room to search for a message from the server notices
+ # user.
+ events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+ notice_in_room = False
+ for event in events:
+ if (
+ event["type"] == EventTypes.Message
+ and event["sender"] == self.hs.config.server_notices_mxid
+ ):
+ notice_in_room = True
+
+ self.assertTrue(notice_in_room, "No server notice in room")
+
+ def _trigger_notice_and_join(self):
+ """Creates enough active users to hit the MAU limit and trigger a system notice
+ about it, then joins the system notices room with one of the users created.
+
+ Returns:
+ user_id (str): The ID of the user that joined the room.
+ tok (str): The access token of the user that joined the room.
+ room_id (str): The ID of the room that's been joined.
+ """
+ user_id = None
+ tok = None
+ invites = []
+
+ # Register as many users as the MAU limit allows.
+ for i in range(self.hs.config.max_mau_value):
+ localpart = "user%d" % i
+ user_id = self.register_user(localpart, "password")
+ tok = self.login(localpart, "password")
+
+ # Sync with the user's token to mark the user as active.
+ request, channel = self.make_request(
+ "GET", "/sync?timeout=0", access_token=tok,
+ )
+ self.render(request)
+
+ # Also retrieves the list of invites for this user. We don't care about that
+ # one except if we're processing the last user, which should have received an
+ # invite to a room with a server notice about the MAU limit being reached.
+ # We could also pick another user and sync with it, which would return an
+ # invite to a system notices room, but it doesn't matter which user we're
+ # using so we use the last one because it saves us an extra sync.
+ invites = channel.json_body["rooms"]["invite"]
+
+ # Make sure we have an invite to process.
+ self.assertEqual(len(invites), 1, invites)
+
+ # Join the room.
+ room_id = list(invites.keys())[0]
+ self.helper.join(room=room_id, user=user_id, tok=tok)
+
+ return user_id, tok, room_id
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8d3845c870..a44960203e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -22,7 +22,7 @@ import attr
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
from synapse.types import EventID
@@ -58,6 +58,7 @@ class FakeEvent(object):
self.type = type
self.state_key = state_key
self.content = content
+ self.room_id = ROOM_ID
def to_event(self, auth_events, prev_events):
"""Given the auth_events and prev_events, convert to a Frozen Event
@@ -88,7 +89,7 @@ class FakeEvent(object):
if self.state_key is not None:
event_dict["state_key"] = self.state_key
- return FrozenEvent(event_dict)
+ return make_event_from_dict(event_dict)
# All graphs start with this set of events
@@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
+ ROOM_ID,
RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
event_map=event_map,
@@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
+ ROOM_ID,
RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],
event_map=None,
@@ -600,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
- def get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -614,7 +617,6 @@ class TestStateResolutionStore(object):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
-
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
@@ -634,3 +636,9 @@ class TestStateResolutionStore(object):
stack.append(aid)
return list(result)
+
+ def get_auth_chain_difference(self, auth_sets):
+ chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
+
+ common = set(chains[0]).intersection(*chains[1:])
+ return set(chains[0]).union(*chains[1:]) - common
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index dd49a14524..5a50e4fdd4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached
from tests import unittest
-class CacheTestCase(unittest.TestCase):
- def setUp(self):
+class CacheTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
@@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
-class CacheDecoratorTestCase(unittest.TestCase):
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A(object):
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a.func.prefill(("foo",), ObservableDeferred(d))
- self.assertEquals(a.func("foo"), d.result)
+ self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
@@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
- @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
+ @cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"test",
- self.storage._simple_upsert_many_txn,
+ self.storage.db.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -367,13 +367,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage._simple_select_list(
+ self.storage.db.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
self.assertEqual(
set(self._dump_to_tuple(res)),
- set([(1, "user1", "hello"), (2, "user2", "there")]),
+ {(1, "user1", "hello"), (2, "user2", "there")},
)
# Update only user2
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
- self.storage.runInteraction(
+ self.storage.db.runInteraction(
"test",
- self.storage._simple_upsert_many_txn,
+ self.storage.db.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -394,11 +394,11 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage._simple_select_list(
+ self.storage.db.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
self.assertEqual(
set(self._dump_to_tuple(res)),
- set([(1, "user1", "hello"), (2, "user2", "bleb")]),
+ {(1, "user1", "hello"), (2, "user2", "bleb")},
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 622b16a071..ef296e7dab 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,10 +24,11 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
+from synapse.storage.database import Database, make_conn
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -42,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_token = "token1"
@@ -54,7 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ self.store = ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -65,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
pass
def _add_appservice(self, as_token, id, url, hs_token, sender):
- as_yaml = dict(
- url=url,
- as_token=as_token,
- hs_token=hs_token,
- id=id,
- sender_localpart=sender,
- namespaces={},
- )
+ as_yaml = {
+ "url": url,
+ "as_token": as_token,
+ "hs_token": hs_token,
+ "id": id,
+ "sender_localpart": sender,
+ "namespaces": {},
+ }
# use the token as the filename
with open(as_token, "w") as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -106,12 +110,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
- self.db_pool = hs.get_db_pool()
- self.engine = hs.database_engine
-
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
{"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -123,17 +124,25 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- self.store = TestTransactionStore(hs.get_db_conn(), hs)
+ # We assume there is only one database in these tests
+ database = hs.get_datastores().databases[0]
+ self.db_pool = database._db_pool
+ self.engine = database.engine
- def _add_service(self, url, as_token, id):
- as_yaml = dict(
- url=url,
- as_token=as_token,
- hs_token="something",
- id=id,
- sender_localpart="a_sender",
- namespaces={},
+ db_config = hs.config.get_single_database()
+ self.store = TestTransactionStore(
+ database, make_conn(db_config, self.engine), hs
)
+
+ def _add_service(self, url, as_token, id):
+ as_yaml = {
+ "url": url,
+ "as_token": as_token,
+ "hs_token": "something",
+ "id": id,
+ "sender_localpart": "a_sender",
+ "namespaces": {},
+ }
# use the token as the filename
with open(as_token, "w") as outfile:
outfile.write(yaml.dump(as_yaml))
@@ -375,15 +384,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
self.assertEquals(2, len(services))
self.assertEquals(
- set([self.as_list[2]["id"], self.as_list[0]["id"]]),
- set([services[0].id, services[1].id]),
+ {self.as_list[2]["id"], self.as_list[0]["id"]},
+ {services[0].id, services[1].id},
)
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, db_conn, hs):
- super(TestTransactionStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(TestTransactionStore, self).__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -413,10 +422,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -428,11 +440,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
@@ -449,11 +464,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 9fabe3fbc0..940b166129 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -2,74 +2,90 @@ from mock import Mock
from twisted.internet import defer
+from synapse.storage.background_updates import BackgroundUpdater
+
from tests import unittest
-from tests.utils import setup_test_homeserver
-class BackgroundUpdateTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
+class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
+ # the base test class should have run the real bg updates for us
+ self.assertTrue(
+ self.get_success(self.updates.has_completed_background_updates())
+ )
self.update_handler = Mock()
-
- yield self.store.register_background_update_handler(
+ self.updates.register_background_update_handler(
"test_update", self.update_handler
)
- # run the real background updates, to get them out the way
- # (perhaps we should run them as part of the test HS setup, since we
- # run all of the other schema setup stuff there?)
- while True:
- res = yield self.store.do_next_background_update(1000)
- if res is None:
- break
-
- @defer.inlineCallbacks
def test_do_background_update(self):
- desired_count = 1000
+ # the time we claim each update takes
duration_ms = 42
+ # the target runtime for each bg update
+ target_background_update_duration_ms = 50000
+
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db.simple_insert(
+ "background_updates",
+ values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+ )
+ )
+
# first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
- self.clock.advance_time_msec(count * duration_ms)
+ yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield self.store.runInteraction(
+ yield store.db.runInteraction(
"update_progress",
- self.store._background_update_progress_txn,
+ self.updates._background_update_progress_txn,
"test_update",
progress,
)
return count
self.update_handler.side_effect = update
-
- yield self.store.start_background_update("test_update", {"my_key": 1})
-
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(duration_ms * desired_count)
- self.assertIsNotNone(result)
+ res = self.get_success(
+ self.updates.do_next_background_update(
+ target_background_update_duration_ms
+ ),
+ by=0.1,
+ )
+ self.assertFalse(res)
+
+ # on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
+ # we should now get run with a much bigger number of items to update
@defer.inlineCallbacks
def update(progress, count):
- yield self.store._end_background_update("test_update")
+ self.assertEqual(progress, {"my_key": 2})
+ self.assertAlmostEqual(
+ count, target_background_update_duration_ms / duration_ms, places=0,
+ )
+ yield self.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(duration_ms * desired_count)
- self.assertIsNotNone(result)
- self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
+ result = self.get_success(
+ self.updates.do_next_background_update(target_background_update_duration_ms)
+ )
+ self.assertFalse(result)
+ self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = yield self.store.do_next_background_update(duration_ms * desired_count)
- self.assertIsNone(result)
+ result = self.get_success(
+ self.updates.do_next_background_update(target_background_update_duration_ms)
+ )
+ self.assertTrue(result)
self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index c778de1f0c..278961c331 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from tests import unittest
@@ -50,22 +51,25 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
- config.event_cache_size = 1
- config.database_config = {"name": "sqlite3"}
- engine = create_engine(config.database_config)
+ config.caches = Mock()
+ config.caches.event_cache_size = 1
+ hs = TestHomeServer("test", config=config)
+
+ sqlite_config = {"name": "sqlite3"}
+ engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- hs = TestHomeServer(
- "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
- )
- self.datastore = SQLBaseStore(None, hs)
+ db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db._db_pool = self.db_pool
+
+ self.datastore = SQLBaseStore(db, None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_insert(
+ yield self.datastore.db.simple_insert(
table="tablename", values={"columname": "Value"}
)
@@ -77,7 +81,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_insert(
+ yield self.datastore.db.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -92,7 +96,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore._simple_select_one_onecol(
+ value = yield self.datastore.db.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
@@ -106,7 +110,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore._simple_select_one(
+ ret = yield self.datastore.db.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
@@ -122,7 +126,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore._simple_select_one(
+ ret = yield self.datastore.db.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
@@ -137,7 +141,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore._simple_select_list(
+ ret = yield self.datastore.db.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
@@ -150,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_update_one(
+ yield self.datastore.db.simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
@@ -165,7 +169,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_update_one(
+ yield self.datastore.db.simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -180,7 +184,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore._simple_delete_one(
+ yield self.datastore.db.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index e9e2d5337c..43425c969a 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -14,7 +14,13 @@
# limitations under the License.
import os.path
+from unittest.mock import patch
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import login, room
from synapse.storage import prepare_database
from synapse.types import Requester, UserID
@@ -33,17 +39,21 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, None, None)
- info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
def run_background_update(self):
"""Re run the background update to clean up the extremities.
"""
# Make sure we don't clash with in progress updates.
- self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+ self.assertTrue(
+ self.store.db.updates._all_done, "Background updates are still ongoing"
+ )
schema_path = os.path.join(
prepare_database.dir_path,
+ "data_stores",
+ "main",
"schema",
"delta",
"54",
@@ -54,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
- self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+ self.store.db.runInteraction(
+ "test_delete_forward_extremities", run_delta_file
+ )
)
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
def test_soft_failed_extremities_handled_correctly(self):
"""Test that extremities are correctly calculated in the presence of
@@ -118,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+ self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -156,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+ self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -211,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(
- set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
- )
+ self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
# Run the background update and check it did the right thing
self.run_background_update()
@@ -221,10 +235,18 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
+ self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
+ CONSENT_VERSION = "1"
+ EXTREMITIES_COUNT = 50
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True
@@ -233,28 +255,39 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.room_creator = homeserver.get_room_creation_handler()
+ self.event_creator_handler = homeserver.get_event_creation_handler()
# Create a test user and room
- self.user = UserID("alice", "test")
+ self.user = UserID.from_string(self.register_user("user1", "password"))
+ self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None)
- info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
+ self.event_creator = homeserver.get_event_creation_handler()
+ homeserver.config.user_consent_version = self.CONSENT_VERSION
def test_send_dummy_event(self):
- # Create a bushy graph with 50 extremities.
+ self._create_extremity_rich_graph()
- event_id_start = self.create_and_send_event(self.room_id, self.user)
-
- for _ in range(50):
- self.create_and_send_event(
- self.room_id, self.user, prev_event_ids=[event_id_start]
- )
+ # Pump the reactor repeatedly so that the background updates have a
+ # chance to run.
+ self.pump(10 * 60)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(len(latest_event_ids), 50)
+ self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+ @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
+ def test_send_dummy_events_when_insufficient_power(self):
+ self._create_extremity_rich_graph()
+ # Criple power levels
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ body={"users": {str(self.user): -1}},
+ tok=self.token1,
+ )
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
self.pump(10 * 60)
@@ -262,4 +295,108 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
+ # Check that the room has not been pruned
+ self.assertTrue(len(latest_event_ids) > 10)
+
+ # New user with regular levels
+ user2 = self.register_user("user2", "password")
+ token2 = self.login("user2", "password")
+ self.helper.join(self.room_id, user2, tok=token2)
+ self.pump(10 * 60)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+
+ @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
+ def test_send_dummy_event_without_consent(self):
+ self._create_extremity_rich_graph()
+ self._enable_consent_checking()
+
+ # Pump the reactor repeatedly so that the background updates have a
+ # chance to run. Attempt to add dummy event with user that has not consented
+ # Check that dummy event send fails.
+ self.pump(10 * 60)
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
+
+ # Create new user, and add consent
+ user2 = self.register_user("user2", "password")
+ token2 = self.login("user2", "password")
+ self.get_success(
+ self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
+ )
+ self.helper.join(self.room_id, user2, tok=token2)
+
+ # Background updates should now cause a dummy event to be added to the graph
+ self.pump(10 * 60)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+
+ @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
+ def test_expiry_logic(self):
+ """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
+ expires old entries correctly.
+ """
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "1"
+ ] = 100000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "2"
+ ] = 200000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "3"
+ ] = 300000
+ self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
+ # All entries within time frame
+ self.assertEqual(
+ len(
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+ ),
+ 3,
+ )
+ # Oldest room to expire
+ self.pump(1)
+ self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
+ self.assertEqual(
+ len(
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+ ),
+ 2,
+ )
+ # All rooms to expire
+ self.pump(2)
+ self.assertEqual(
+ len(
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+ ),
+ 0,
+ )
+
+ def _create_extremity_rich_graph(self):
+ """Helper method to create bushy graph on demand"""
+
+ event_id_start = self.create_and_send_event(self.room_id, self.user)
+
+ for _ in range(self.EXTREMITIES_COUNT):
+ self.create_and_send_event(
+ self.room_id, self.user, prev_event_ids=[event_id_start]
+ )
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(len(latest_event_ids), 50)
+
+ def _enable_consent_checking(self):
+ """Helper method to enable consent checking"""
+ self.event_creator._block_events_without_consent_error = "No consent from user"
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creator._consent_uri_builder = consent_uri_builder
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 09305c3bf1..3b483bc7f0 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
@@ -37,9 +38,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(12345678)
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -47,15 +52,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
- "access_token": "access_token",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
@@ -82,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -113,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -134,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -146,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
lots_of_users = 100
user_id = "@user:server"
@@ -163,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -181,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
@@ -201,6 +201,173 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
+ def test_devices_last_seen_bg_update(self):
+ # First make sure we have completed all updates.
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
+
+ user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", device_id
+ )
+ )
+ # Force persisting to disk
+ self.reactor.advance(200)
+
+ # But clear the associated entry in devices table
+ self.get_success(
+ self.store.db.simple_update(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues={"last_seen": None, "ip": None, "user_agent": None},
+ desc="test_devices_last_seen_bg_update",
+ )
+ )
+
+ # We should now get nulls when querying
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, device_id)
+ )
+
+ r = result[(user_id, device_id)]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": None,
+ "user_agent": None,
+ "last_seen": None,
+ },
+ r,
+ )
+
+ # Register the background update to run again.
+ self.get_success(
+ self.store.db.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "devices_last_seen",
+ "progress_json": "{}",
+ "depends_on": None,
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
+
+ # We should now get the correct result again
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, device_id)
+ )
+
+ r = result[(user_id, device_id)]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 0,
+ },
+ r,
+ )
+
+ def test_old_user_ips_pruned(self):
+ # First make sure we have completed all updates.
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
+
+ user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", device_id
+ )
+ )
+
+ # Force persisting to disk
+ self.reactor.advance(200)
+
+ # We should see that in the DB
+ result = self.get_success(
+ self.store.db.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(
+ result,
+ [
+ {
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "device_id": device_id,
+ "last_seen": 0,
+ }
+ ],
+ )
+
+ # Now advance by a couple of months
+ self.reactor.advance(60 * 24 * 60 * 60)
+
+ # We should get no results.
+ result = self.get_success(
+ self.store.db.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(result, [])
+
+ # But we should still get the correct values for the device
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, device_id)
+ )
+
+ r = result[(user_id, device_id)]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 0,
+ },
+ r,
+ )
+
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
new file mode 100644
index 0000000000..5a77c84962
--- /dev/null
+++ b/tests/storage/test_database.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.database import make_tuple_comparison_clause
+from synapse.storage.engines import BaseDatabaseEngine
+
+from tests import unittest
+
+
+def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
+ # returns a DatabaseEngine, circumventing the abc mechanism
+ # any kwargs are set as attributes on the class before instantiating it
+ t = type(
+ "TestBaseDatabaseEngine",
+ (BaseDatabaseEngine,),
+ dict(BaseDatabaseEngine.__dict__),
+ )
+ # defeat the abc mechanism
+ t.__abstractmethods__ = set()
+ for k, v in kwargs.items():
+ setattr(t, k, v)
+ return t(None, None)
+
+
+class TupleComparisonClauseTestCase(unittest.TestCase):
+ def test_native_tuple_comparison(self):
+ db_engine = _stub_db_engine(supports_tuple_comparison=True)
+ clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)])
+ self.assertEqual(clause, "(a,b) > (?,?)")
+ self.assertEqual(args, [1, 2])
+
+ def test_emulated_tuple_comparison(self):
+ db_engine = _stub_db_engine(supports_tuple_comparison=False)
+ clause, args = make_tuple_comparison_clause(
+ db_engine, [("a", 1), ("b", 2), ("c", 3)]
+ )
+ self.assertEqual(
+ clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))"
+ )
+ self.assertEqual(args, [1, 1, 2, 2, 3])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..c2539b353a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
- def test_get_devices_by_remote(self):
+ def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
@@ -81,63 +81,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- @defer.inlineCallbacks
- def test_get_devices_by_remote_limited(self):
- # Test breaking the update limit in 1, 101, and 1 device_id segments
-
- # first add one device
- device_ids1 = ["device_id0"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids1, ["someotherhost"]
- )
-
- # then add 101
- device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids2, ["someotherhost"]
- )
-
- # then one more
- device_ids3 = ["newdevice"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids3, ["someotherhost"]
- )
-
- #
- # now read them back.
- #
-
- # first we should get a single update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
- "someotherhost", -1, limit=100
- )
- self._check_devices_in_updates(device_ids1, device_updates)
-
- # Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self.assertEqual(len(device_updates), 0)
-
- # The 101 devices should've been cleared, so we should now just get one device
- # update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self._check_devices_in_updates(device_ids3, device_updates)
-
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
- received_device_ids = {update["device_id"] for update in device_updates}
+ received_device_ids = {
+ update["device_id"] for edu_type, update in device_updates
+ }
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
new file mode 100644
index 0000000000..35dafbb904
--- /dev/null
+++ b/tests/storage/test_e2e_room_keys.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests import unittest
+
+# sample room_key data for use in the tests
+room_key = {
+ "first_message_index": 1,
+ "forwarded_count": 1,
+ "is_verified": False,
+ "session_data": "SSBBTSBBIEZJU0gK",
+}
+
+
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_room_keys_version_delete(self):
+ # test that deleting a room key backup deletes the keys
+ version1 = self.get_success(
+ self.store.create_e2e_room_keys_version(
+ "user_id", {"algorithm": "rot13", "auth_data": {}}
+ )
+ )
+
+ self.get_success(
+ self.store.add_e2e_room_keys(
+ "user_id", version1, [("room", "session", room_key)]
+ )
+ )
+
+ version2 = self.get_success(
+ self.store.create_e2e_room_keys_version(
+ "user_id", {"algorithm": "rot13", "auth_data": {}}
+ )
+ )
+
+ self.get_success(
+ self.store.add_e2e_room_keys(
+ "user_id", version2, [("room", "session", room_key)]
+ )
+ )
+
+ # make sure the keys were stored properly
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+ self.assertEqual(len(keys["rooms"]), 1)
+
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+ self.assertEqual(len(keys["rooms"]), 1)
+
+ # delete version1
+ self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1))
+
+ # make sure the key from version1 is gone, and the key from version2 is
+ # still there
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+ self.assertEqual(len(keys["rooms"]), 0)
+
+ keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+ self.assertEqual(len(keys["rooms"]), 1)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index c8ece15284..398d546280 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
- self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
+ self.assertDictContainsSubset(json, dev)
@defer.inlineCallbacks
def test_reupload_key(self):
@@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset(
- {"keys": json, "device_display_name": "display_name"}, dev
+ {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
@defer.inlineCallbacks
@@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.store_device("user2", "device1", None)
yield self.store.store_device("user2", "device2", None)
- yield self.store.set_e2e_device_keys("user1", "device1", now, "json11")
- yield self.store.set_e2e_device_keys("user1", "device2", now, "json12")
- yield self.store.set_e2e_device_keys("user2", "device1", now, "json21")
- yield self.store.set_e2e_device_keys("user2", "device2", now, "json22")
+ yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+ yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+ yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+ yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
res = yield self.store.get_e2e_device_keys(
(("user1", "device1"), ("user2", "device2"))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 86c7ac350d..3aeec0dc0f 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,19 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import tests.unittest
import tests.utils
-class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_get_prev_events_for_room(self):
room_id = "@ROOM:local"
@@ -57,21 +52,182 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
"(event_id, algorithm, hash) "
"VALUES (?, 'sha256', ?)"
),
- (event_id, b"ffff"),
+ (event_id, bytearray(b"ffff")),
)
- for i in range(0, 11):
- yield self.store.runInteraction("insert", insert_event, i)
+ for i in range(0, 20):
+ self.get_success(self.store.db.runInteraction("insert", insert_event, i))
- # this should get the last five and five others
- r = yield self.store.get_prev_events_for_room(room_id)
+ # this should get the last ten
+ r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r))
- for i in range(0, 5):
- el = r[i]
- depth = el[2]
- self.assertEqual(10 - i, depth)
-
- for i in range(5, 5):
- el = r[i]
- depth = el[2]
- self.assertLessEqual(5, depth)
+ for i in range(0, 10):
+ self.assertEqual("$event_%i:local" % (19 - i), r[i])
+
+ def test_get_rooms_with_many_extremities(self):
+ room1 = "#room1"
+ room2 = "#room2"
+ room3 = "#room3"
+
+ def insert_event(txn, i, room_id):
+ event_id = "$event_%i:local" % i
+ txn.execute(
+ (
+ "INSERT INTO event_forward_extremities (room_id, event_id) "
+ "VALUES (?, ?)"
+ ),
+ (room_id, event_id),
+ )
+
+ for i in range(0, 20):
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room1)
+ )
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room2)
+ )
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room3)
+ )
+
+ # Test simple case
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
+ self.assertEqual(len(r), 3)
+
+ # Does filter work?
+
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
+ self.assertTrue(room2 in r)
+ self.assertTrue(room3 in r)
+ self.assertEqual(len(r), 2)
+
+ r = self.get_success(
+ self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+ )
+ self.assertEqual(r, [room3])
+
+ # Does filter and limit work?
+
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
+ self.assertTrue(r == [room2] or r == [room3])
+
+ def test_auth_difference(self):
+ room_id = "@ROOM:local"
+
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ´ |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
+
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
+
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+
+ def insert_event(txn, event_id, stream_ordering):
+
+ depth = depth_map[event_id]
+
+ self.store.db.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.store.db.simple_insert_many_txn(
+ txn,
+ table="event_auth",
+ values=[
+ {"event_id": event_id, "room_id": room_id, "auth_id": a}
+ for a in auth_graph[event_id]
+ ],
+ )
+
+ next_stream_ordering = 0
+ for event_id in auth_graph:
+ next_stream_ordering += 1
+ self.get_success(
+ self.store.db.runInteraction(
+ "insert", insert_event, event_id, next_stream_ordering
+ )
+ )
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ self.assertSetEqual(difference, set())
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index f26ff57a18..a7b85004e5 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
- info = self.get_success(room_creator.create_room(requester, {}))
+ info, _ = self.get_success(room_creator.create_room(requester, {}))
room_id = info["room_id"]
last_event = None
@@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
)
)
- expected = set(
- [
- b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
- b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
- b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
- b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
- b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
- b"synapse_forward_extremities_count 3.0",
- b"synapse_forward_extremities_sum 10.0",
- ]
- )
+ expected = {
+ b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
+ b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
+ b"synapse_forward_extremities_count 3.0",
+ b"synapse_forward_extremities_sum 10.0",
+ }
self.assertEqual(items, expected)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b114c6fb1d..b45bc9c115 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
@@ -55,7 +56,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.runInteraction(
+ counts = yield self.store.db.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
@@ -74,20 +75,20 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
)
- yield self.store.runInteraction(
+ yield self.store.db.runInteraction(
"",
- self.store._set_push_actions_for_event_and_users_txn,
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
[(event, None)],
)
def _rotate(stream):
- return self.store.runInteraction(
+ return self.store.db.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
- return self.store.runInteraction(
+ return self.store.db.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
room_id,
@@ -116,7 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store._simple_delete(
+ yield self.store.db.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
@@ -135,7 +136,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store._simple_insert(
+ return self.store.db.simple_insert(
"events",
{
"stream_ordering": so,
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
+ def _insert_rows(self, instance_name: str, number: int):
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index e07ff01201..95f309fbbc 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
# limitations under the License.
import signedjson.key
+import unpaddedbase64
from twisted.internet.defer import Deferred
@@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest
-KEY_1 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+
+def decode_verify_key_base64(key_id: str, key_base64: str):
+ key_bytes = unpaddedbase64.decode_base64(key_base64)
+ return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
+
+
+KEY_1 = decode_verify_key_base64(
+ "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
-KEY_2 = signedjson.key.decode_verify_key_base64(
- "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+KEY_2 = decode_verify_key_base64(
+ "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
new file mode 100644
index 0000000000..ab0df5ea93
--- /dev/null
+++ b/tests/storage/test_main.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Awesome Technologies Innovationslabor GmbH
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.types import UserID
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+
+class DataStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield setup_test_homeserver(self.addCleanup)
+
+ self.store = hs.get_datastore()
+
+ self.user = UserID.from_string("@abcde:test")
+ self.displayname = "Frank"
+
+ @defer.inlineCallbacks
+ def test_get_users_paginate(self):
+ yield self.store.register_user(self.user.to_string(), "pass")
+ yield self.store.create_profile(self.user.localpart)
+ yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+
+ users, total = yield self.store.get_users_paginate(
+ 0, 10, name="bc", guests=False
+ )
+
+ self.assertEquals(1, total)
+ self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 1494650d10..9c04e92577 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,152 +19,222 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
+def gen_3pids(count):
+ """Generate `count` threepids as a list."""
+ return [
+ {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
+ ]
+
+
class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def default_config(self):
+ config = default_config("test")
+
+ config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
- hs.config.limit_usage_by_mau = True
- hs.config.max_mau_value = 50
+ return config
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
- return hs
-
+ @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
- self.hs.config.max_mau_value = 5
+ threepids = self.hs.config.mau_limits_reserved_threepids
+
+ # register three users, of which two have reserved 3pids, and a third
+ # which is a support user.
user1 = "@user1:server"
- user1_email = "user1@matrix.org"
+ user1_email = threepids[0]["address"]
user2 = "@user2:server"
- user2_email = "user2@matrix.org"
+ user2_email = threepids[1]["address"]
user3 = "@user3:server"
- user3_email = "user3@matrix.org"
-
- threepids = [
- {"medium": "email", "address": user1_email},
- {"medium": "email", "address": user2_email},
- {"medium": "email", "address": user3_email},
- ]
- # -1 because user3 is a support user and does not count
- user_num = len(threepids) - 1
- self.store.register_user(user_id=user1, password_hash=None)
- self.store.register_user(user_id=user2, password_hash=None)
- self.store.register_user(
- user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT
+ self.get_success(self.store.register_user(user_id=user1))
+ self.get_success(self.store.register_user(user_id=user2))
+ self.get_success(
+ self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
)
- self.pump()
now = int(self.hs.get_clock().time_msec())
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.get_success(
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ )
- self.store.runInteraction(
- "initialise", self.store._initialise_reserved_users, threepids
+ # XXX why are we doing this here? this function is only run at startup
+ # so it is odd to re-run it here.
+ self.get_success(
+ self.store.db.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
)
- self.pump()
- active_count = self.store.get_monthly_active_count()
+ # the number of users we expect will be counted against the mau limit
+ # -1 because user3 is a support user and does not count
+ user_num = len(threepids) - 1
- # Test total counts, ensure user3 (support user) is not counted
- self.assertEquals(self.get_success(active_count), user_num)
+ # Check the number of active users. Ensure user3 (support user) is not counted
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num)
- # Test user is marked as active
- timestamp = self.store.user_last_seen_monthly_active(user1)
- self.assertTrue(self.get_success(timestamp))
- timestamp = self.store.user_last_seen_monthly_active(user2)
- self.assertTrue(self.get_success(timestamp))
+ # Test each of the registered users is marked as active
+ timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1))
+ self.assertGreater(timestamp, 0)
+ timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2))
+ self.assertGreater(timestamp, 0)
- # Test that users are never removed from the db.
+ # Test that users with reserved 3pids are not removed from the MAU table
+ # XXX some of this is redundant. poking things into the config shouldn't
+ # work, and in any case it's not obvious what we expect to happen when
+ # we advance the reactor.
self.hs.config.max_mau_value = 0
-
self.reactor.advance(FORTY_DAYS)
+ self.hs.config.max_mau_value = 5
- self.store.reap_monthly_active_users()
- self.pump()
+ self.get_success(self.store.reap_monthly_active_users())
- active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), user_num)
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num)
- # Test that regular users are removed from the db
+ # Add some more users and check they are counted as active
ru_count = 2
- self.store.upsert_monthly_active_user("@ru1:server")
- self.store.upsert_monthly_active_user("@ru2:server")
- self.pump()
- active_count = self.store.get_monthly_active_count()
- self.assertEqual(self.get_success(active_count), user_num + ru_count)
- self.hs.config.max_mau_value = user_num
- self.store.reap_monthly_active_users()
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user("@ru1:server"))
+ self.get_success(self.store.upsert_monthly_active_user("@ru2:server"))
+
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, user_num + ru_count)
- active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), user_num)
+ # now run the reaper and check that the number of active users is reduced
+ # to max_mau_value
+ self.get_success(self.store.reap_monthly_active_users())
+
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(active_count, 3)
def test_can_insert_and_count_mau(self):
- count = self.store.get_monthly_active_count()
- self.assertEqual(0, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
- self.store.upsert_monthly_active_user("@user:server")
- self.pump()
+ d = self.store.upsert_monthly_active_user("@user:server")
+ self.get_success(d)
- count = self.store.get_monthly_active_count()
- self.assertEqual(1, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 1)
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
- result = self.store.user_last_seen_monthly_active(user_id1)
- self.assertFalse(self.get_success(result) == 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ self.assertNotEqual(result, 0)
- self.store.upsert_monthly_active_user(user_id1)
- self.store.upsert_monthly_active_user(user_id2)
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.store.upsert_monthly_active_user(user_id2))
- result = self.store.user_last_seen_monthly_active(user_id1)
- self.assertGreater(self.get_success(result), 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+ self.assertGreater(result, 0)
- result = self.store.user_last_seen_monthly_active(user_id3)
- self.assertNotEqual(self.get_success(result), 0)
+ result = self.get_success(self.store.user_last_seen_monthly_active(user_id3))
+ self.assertNotEqual(result, 0)
+ @override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
- self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
- self.store.upsert_monthly_active_user("@user%d:server" % i)
- self.pump()
+ self.get_success(
+ self.store.upsert_monthly_active_user("@user%d:server" % i)
+ )
- count = self.store.get_monthly_active_count()
- self.assertTrue(self.get_success(count), initial_users)
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, initial_users)
- self.store.reap_monthly_active_users()
- self.pump()
- count = self.store.get_monthly_active_count()
- self.assertEquals(
- self.get_success(count), initial_users - self.hs.config.max_mau_value
- )
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, self.hs.config.max_mau_value)
self.reactor.advance(FORTY_DAYS)
- self.store.reap_monthly_active_users()
- self.pump()
- count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(count), 0)
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ # Note that below says mau_limit (no s), this is the name of the config
+ # value, although it gets stored on the config object as mau_limits.
+ @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
+ def test_reap_monthly_active_users_reserved_users(self):
+ """ Tests that reaping correctly handles reaping where reserved users are
+ present"""
+ threepids = self.hs.config.mau_limits_reserved_threepids
+ initial_users = len(threepids)
+ reserved_user_number = initial_users - 1
+ for i in range(initial_users):
+ user = "@user%d:server" % i
+ email = "user%d@matrix.org" % i
+
+ self.get_success(self.store.upsert_monthly_active_user(user))
+
+ # Need to ensure that the most recent entries in the
+ # monthly_active_users table are reserved
+ now = int(self.hs.get_clock().time_msec())
+ if i != 0:
+ self.get_success(
+ self.store.register_user(user_id=user, password_hash=None)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user, "email", email, now, now)
+ )
+
+ d = self.store.db.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, initial_users)
+
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEqual(len(users), reserved_user_number)
+
+ d = self.store.reap_monthly_active_users()
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, self.hs.config.max_mau_value)
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "@user_id:host"
- self.store.register_user(user_id=user_id, password_hash=None, make_guest=True)
+
+ d = self.store.register_user(
+ user_id=user_id, password_hash=None, make_guest=True
+ )
+ self.get_success(d)
+
self.store.upsert_monthly_active_user = Mock()
- self.store.populate_monthly_active_users(user_id)
- self.pump()
+
+ d = self.store.populate_monthly_active_users(user_id)
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
@@ -175,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
- self.store.populate_monthly_active_users("user_id")
- self.pump()
+ d = self.store.populate_monthly_active_users("user_id")
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
@@ -186,80 +257,132 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
- self.store.populate_monthly_active_users("user_id")
- self.pump()
+
+ d = self.store.populate_monthly_active_users("user_id")
+ self.get_success(d)
+
self.store.upsert_monthly_active_user.assert_not_called()
def test_get_reserved_real_user_account(self):
# Test no reserved users, or reserved threepids
- count = self.store.get_registered_reserved_users_count()
- self.assertEquals(self.get_success(count), 0)
- # Test reserved users but no registered users
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEqual(len(users), 0)
+ # Test reserved users but no registered users
user1 = "@user1:example.com"
user2 = "@user2:example.com"
+
user1_email = "user1@example.com"
user2_email = "user2@example.com"
threepids = [
{"medium": "email", "address": user1_email},
{"medium": "email", "address": user2_email},
]
+
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.runInteraction(
+ d = self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
+ self.get_success(d)
- self.pump()
- count = self.store.get_registered_reserved_users_count()
- self.assertEquals(self.get_success(count), 0)
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEqual(len(users), 0)
- # Test reserved registed users
- self.store.register_user(user_id=user1, password_hash=None)
- self.store.register_user(user_id=user2, password_hash=None)
- self.pump()
+ # Test reserved registered users
+ self.get_success(self.store.register_user(user_id=user1, password_hash=None))
+ self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
- count = self.store.get_registered_reserved_users_count()
- self.assertEquals(self.get_success(count), len(threepids))
+
+ users = self.get_success(self.store.get_registered_reserved_users())
+ self.assertEqual(len(users), len(threepids))
def test_support_user_not_add_to_mau_limits(self):
support_user_id = "@support:test"
- count = self.store.get_monthly_active_count()
- self.pump()
- self.assertEqual(self.get_success(count), 0)
- self.store.register_user(
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.register_user(
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
+ self.get_success(d)
- self.store.upsert_monthly_active_user(support_user_id)
- count = self.store.get_monthly_active_count()
- self.pump()
- self.assertEqual(self.get_success(count), 0)
+ d = self.store.upsert_monthly_active_user(support_user_id)
+ self.get_success(d)
- def test_track_monthly_users_without_cap(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = True
- self.hs.config.max_mau_value = 1 # should not matter
+ d = self.store.get_monthly_active_count()
+ count = self.get_success(d)
+ self.assertEqual(count, 0)
- count = self.store.get_monthly_active_count()
- self.assertEqual(0, self.get_success(count))
+ # Note that the max_mau_value setting should not matter.
+ @override_config(
+ {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
+ )
+ def test_track_monthly_users_without_cap(self):
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(0, count)
- self.store.upsert_monthly_active_user("@user1:server")
- self.store.upsert_monthly_active_user("@user2:server")
- self.pump()
+ self.get_success(self.store.upsert_monthly_active_user("@user1:server"))
+ self.get_success(self.store.upsert_monthly_active_user("@user2:server"))
- count = self.store.get_monthly_active_count()
- self.assertEqual(2, self.get_success(count))
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(2, count)
+ @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = False
self.store.upsert_monthly_active_user = Mock()
- self.store.populate_monthly_active_users("@user:sever")
- self.pump()
+ self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called()
+
+ def test_get_monthly_active_count_by_service(self):
+ appservice1_user1 = "@appservice1_user1:example.com"
+ appservice1_user2 = "@appservice1_user2:example.com"
+
+ appservice2_user1 = "@appservice2_user1:example.com"
+ native_user1 = "@native_user1:example.com"
+
+ service1 = "service1"
+ service2 = "service2"
+ native = "native"
+
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user1, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice1_user2, password_hash=None, appservice_id=service1
+ )
+ )
+ self.get_success(
+ self.store.register_user(
+ user_id=appservice2_user1, password_hash=None, appservice_id=service2
+ )
+ )
+ self.get_success(
+ self.store.register_user(user_id=native_user1, password_hash=None)
+ )
+
+ count = self.get_success(self.store.get_monthly_active_count_by_service())
+ self.assertEqual(count, {})
+
+ self.get_success(self.store.upsert_monthly_active_user(native_user1))
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+ self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+ self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 4)
+
+ d = self.store.get_monthly_active_count_by_service()
+ result = self.get_success(d)
+
+ self.assertEqual(result[service1], 2)
+ self.assertEqual(result[service2], 1)
+ self.assertEqual(result[native], 1)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 45824bd3b2..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
-from synapse.storage.profile import ProfileStore
from synapse.types import UserID
from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = ProfileStore(hs.get_db_conn(), hs)
+ self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index f671599cb8..b9fafaa1a6 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- storage = self.hs.get_datastore()
+ store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
# Get the topological token
- event = storage.get_topological_token_for_event(last["event_id"])
+ event = store.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = storage.purge_events.purge_history(self.room_id, event, True)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
+ get_first = store.get_event(first["event_id"])
+ get_second = store.get_event(second["event_id"])
+ get_third = store.get_event(third["event_id"])
+ get_last = store.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index deecfad9fb..db3667dc43 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -116,7 +117,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
+
+ return event
def test_redact(self):
self.get_success(
@@ -235,8 +238,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
built_event = yield self._base_builder.build(prev_event_ids)
- built_event.event_id = self._event_id
- built_event._event_dict["event_id"] = self._event_id
+
+ built_event._event_id = self._event_id
+ built_event._dict["event_id"] = self._event_id
+ assert built_event.event_id == self._event_id
+
return built_event
@property
@@ -261,7 +267,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.store.persist_event(event_1, context_1))
+ self.get_success(self.storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -280,7 +286,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.store.persist_event(event_2, context_2))
+ self.get_success(self.storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -335,7 +341,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
- self.store._simple_select_one_onecol(
+ self.store.db.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -353,7 +359,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
- self.store._simple_select_one_onecol(
+ self.store.db.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -361,3 +367,72 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.assert_dict({"content": {}}, json.loads(event_json))
+
+ def test_redact_redaction(self):
+ """Tests that we can redact a redaction and can fetch it again.
+ """
+
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
+
+ msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+
+ first_redact_event = self.get_success(
+ self.inject_redaction(
+ self.room1, msg_event.event_id, self.u_alice, "Redacting message"
+ )
+ )
+
+ self.get_success(
+ self.inject_redaction(
+ self.room1,
+ first_redact_event.event_id,
+ self.u_alice,
+ "Redacting redaction",
+ )
+ )
+
+ # Now lets jump to the future where we have censored the redaction event
+ # in the DB.
+ self.reactor.advance(60 * 60 * 24 * 31)
+
+ # We just want to check that fetching the event doesn't raise an exception.
+ self.get_success(
+ self.store.get_event(first_redact_event.event_id, allow_none=True)
+ )
+
+ def test_store_redacted_redaction(self):
+ """Tests that we can store a redacted redaction.
+ """
+
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
+
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "foo"},
+ },
+ )
+
+ redaction_event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.storage.persistence.persist_event(redaction_event, context)
+ )
+
+ # Now lets jump to the future where we have censored the redaction event
+ # in the DB.
+ self.reactor.advance(60 * 60 * 24 * 31)
+
+ # We just want to check that fetching the event doesn't raise an exception.
+ self.get_success(
+ self.store.get_event(redaction_event.event_id, allow_none=True)
+ )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..71a40a0a49 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.db_pool = hs.get_db_pool()
self.store = hs.get_datastore()
@@ -44,12 +43,14 @@ class RegistrationStoreTestCase(unittest.TestCase):
# TODO(paul): Surely this field should be 'user_id', not 'name'
"name": self.user_id,
"password_hash": self.pwhash,
+ "admin": 0,
"is_guest": 0,
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 1000,
"user_type": None,
+ "deactivated": 0,
},
(yield self.store.get_user_by_id(self.user_id)),
)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee45706f..3b78d48896 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
from tests import unittest
@@ -40,6 +41,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
is_public=True,
+ room_version=RoomVersions.V1,
)
@defer.inlineCallbacks
@@ -53,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats(self):
+ self.assertDictContainsSubset(
+ {
+ "room_id": self.room.to_string(),
+ "creator": self.u_creator.to_string(),
+ "public": True,
+ },
+ (yield self.store.get_room_with_stats(self.room.to_string())),
+ )
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
@@ -62,17 +75,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
yield self.store.store_room(
- self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.store.persist_event(
+ yield self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6ffb..5dd46005e6 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -16,13 +16,14 @@
from unittest.mock import Mock
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import RoomVersions
+from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, UserID
from tests import unittest
+from tests.test_utils import event_injection
+from tests.utils import TestHomeServer
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -39,13 +40,11 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor, clock, hs: TestHomeServer):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
- self.event_builder_factory = hs.get_event_builder_factory()
- self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = self.register_user("alice", "pass")
self.t_alice = self.login("alice", "pass")
@@ -54,33 +53,13 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# User elsewhere on another host
self.u_charlie = UserID.from_string("@charlie:elsewhere")
- def inject_room_member(self, room, user, membership, replaces_state=None):
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Member,
- "sender": user,
- "state_key": user,
- "room_id": room,
- "content": {"membership": membership},
- },
- )
-
- event, context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
-
- self.get_success(self.store.persist_event(event, context))
-
- return event
-
def test_one_member(self):
# Alice creates the room, and is automatically joined
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
rooms_for_user = self.get_success(
- self.store.get_rooms_for_user_where_membership_is(
+ self.store.get_rooms_for_local_user_where_membership_is(
self.u_alice, [Membership.JOIN]
)
)
@@ -137,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
+ def test_get_joined_users_from_context(self):
+ room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ bob_event = event_injection.inject_member_event(
+ self.hs, room, self.u_bob, Membership.JOIN
+ )
+
+ # first, create a regular event
+ event, context = event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.1",
+ content={},
+ )
+
+ users = self.get_success(
+ self.store.get_joined_users_from_context(event, context)
+ )
+ self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+
+ # Regression test for #7376: create a state event whose key matches bob's
+ # user_id, but which is *not* a membership event, and persist that; then check
+ # that `get_joined_users_from_context` returns the correct users for the next event.
+ non_member_event = event_injection.inject_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_bob,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.2",
+ state_key=self.u_bob,
+ content={},
+ )
+ event, context = event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[non_member_event.event_id],
+ type="m.test.3",
+ content={},
+ )
+ users = self.get_success(
+ self.store.get_joined_users_from_context(event, context)
+ )
+ self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
@@ -145,8 +170,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
@@ -155,7 +184,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@@ -166,8 +195,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store._all_done = False
+ self.store.db.updates._all_done = False
# Now let's actually drive the updates to completion
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2db..0b88308ff4 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -43,7 +45,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
yield self.store.store_room(
- self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
)
@defer.inlineCallbacks
@@ -63,7 +68,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
builder
)
- yield self.store.persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@@ -82,7 +87,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups_ids(
+ state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
)
self.assertEqual(len(state_group_map), 1)
@@ -101,7 +106,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+ state_group_map = yield self.storage.state.get_state_groups(
+ self.room, [e2.event_id]
+ )
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -141,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.store.get_state_for_event(e5.event_id)
+ state = yield self.storage.state.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4)
@@ -157,21 +164,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -181,7 +188,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -199,7 +206,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -215,13 +222,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+ group_ids = yield self.storage.state.get_state_groups_ids(
+ room_id, [e5.event_id]
+ )
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -237,8 +249,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -250,8 +265,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -267,8 +285,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -287,8 +308,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -304,8 +328,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -317,8 +344,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -331,9 +361,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
@@ -346,21 +378,23 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
state_dict_ids.pop((e2.type, e2.state_key))
- self.store._state_group_cache.invalidate(group)
- self.store._state_group_cache.update(
- sequence=self.store._state_group_cache.sequence,
+ self.state_datastore._state_group_cache.invalidate(group)
+ self.state_datastore._state_group_cache.update(
+ sequence=self.state_datastore._state_group_cache.sequence,
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=((e1.type, e1.state_key),),
)
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False)
- self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
+ self.assertEqual(known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
@@ -369,8 +403,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -381,8 +418,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -394,8 +434,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -405,8 +448,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -424,8 +470,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -435,8 +484,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -448,8 +500,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -459,8 +514,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index a771d5af29..8e817e2c7f 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.util.retryutils import MAX_RETRY_INTERVAL
+
from tests.unittest import HomeserverTestCase
@@ -45,3 +47,12 @@ class TransactionStoreTestCase(HomeserverTestCase):
"""
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
+
+ def test_large_destination_retry(self):
+ d = self.store.set_destination_retry_timings(
+ "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
+ )
+ self.get_success(d)
+
+ d = self.store.get_destination_retry_timings("example.com")
+ self.get_success(d)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index d7d244ce97..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
from twisted.internet import defer
-from synapse.storage import UserDirectoryStore
-
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+ self.store = self.hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 8b2741d277..69b4c5d6c2 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -18,7 +18,8 @@ import unittest
from synapse import event_auth
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
+from synapse.types import get_domain_from_id
class EventAuthTestCase(unittest.TestCase):
@@ -37,7 +38,7 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send state
event_auth.check(
- RoomVersions.V1.identifier,
+ RoomVersions.V1,
_random_state_event(creator),
auth_events,
do_sig_check=False,
@@ -47,11 +48,11 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises(
AuthError,
event_auth.check,
- RoomVersions.V1.identifier,
+ RoomVersions.V1,
_random_state_event(joiner),
auth_events,
do_sig_check=False,
- ),
+ )
def test_state_default_level(self):
"""
@@ -76,7 +77,7 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises(
AuthError,
event_auth.check,
- RoomVersions.V1.identifier,
+ RoomVersions.V1,
_random_state_event(pleb),
auth_events,
do_sig_check=False,
@@ -84,11 +85,112 @@ class EventAuthTestCase(unittest.TestCase):
# king should be able to send state
event_auth.check(
- RoomVersions.V1.identifier,
- _random_state_event(king),
+ RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False,
+ )
+
+ def test_alias_event(self):
+ """Alias events have special behavior up through room version 6."""
+ creator = "@creator:example.com"
+ other = "@other:example.com"
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ }
+
+ # creator should be able to send aliases
+ event_auth.check(
+ RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False,
+ )
+
+ # Reject an event with no state key.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V1,
+ _alias_event(creator, state_key=""),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # If the domain of the sender does not match the state key, reject.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V1,
+ _alias_event(creator, state_key="test.com"),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Note that the member does *not* need to be in the room.
+ event_auth.check(
+ RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False,
+ )
+
+ def test_msc2432_alias_event(self):
+ """After MSC2432, alias events have no special behavior."""
+ creator = "@creator:example.com"
+ other = "@other:example.com"
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ }
+
+ # creator should be able to send aliases
+ event_auth.check(
+ RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
+ )
+
+ # No particular checks are done on the state key.
+ event_auth.check(
+ RoomVersions.V6,
+ _alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
+ event_auth.check(
+ RoomVersions.V6,
+ _alias_event(creator, state_key="test.com"),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Per standard auth rules, the member must be in the room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
+ )
+
+ def test_msc2209(self):
+ """
+ Notifications power levels get checked due to MSC2209.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(
+ creator, {"state_default": "30", "users": {pleb: "30"}}
+ ),
+ ("m.room.member", pleb): _join_event(pleb),
+ }
+
+ # pleb should be able to modify the notifications power level.
+ event_auth.check(
+ RoomVersions.V1,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # But an MSC2209 room rejects this change.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
+ auth_events,
+ do_sig_check=False,
+ )
# helpers for making events
@@ -97,7 +199,7 @@ TEST_ROOM_ID = "!test:room"
def _create_event(user_id):
- return FrozenEvent(
+ return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
@@ -109,7 +211,7 @@ def _create_event(user_id):
def _join_event(user_id):
- return FrozenEvent(
+ return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
@@ -122,7 +224,7 @@ def _join_event(user_id):
def _power_levels_event(sender, content):
- return FrozenEvent(
+ return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
@@ -134,8 +236,21 @@ def _power_levels_event(sender, content):
)
+def _alias_event(sender, **kwargs):
+ data = {
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.aliases",
+ "sender": sender,
+ "state_key": get_domain_from_id(sender),
+ "content": {"aliases": []},
+ }
+ data.update(**kwargs)
+ return make_event_from_dict(data)
+
+
def _random_state_event(sender):
- return FrozenEvent(
+ return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index a73f18f88e..c662195eec 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,17 +1,18 @@
from mock import Mock
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
from synapse.util import Clock
+from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
-class MessageAcceptTests(unittest.TestCase):
+class MessageAcceptTests(unittest.HomeserverTestCase):
def setUp(self):
self.http_client = Mock()
@@ -27,20 +28,25 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room = room_creator.create_room(
- our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ room_deferred = ensureDeferred(
+ room_creator.create_room(
+ our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ )
)
self.reactor.advance(0.1)
- self.room_id = self.successResultOf(room)["room_id"]
+ self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
+
+ self.store = self.homeserver.get_datastore()
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
- join_event = FrozenEvent(
+ join_event = make_event_from_dict(
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
@@ -58,15 +64,19 @@ class MessageAcceptTests(unittest.TestCase):
)
self.handler = self.homeserver.get_handlers().federation_handler
- self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+ context
+ )
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
)
# Send the join, it should return None (which is not an error)
- d = self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
+ d = ensureDeferred(
+ self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
@@ -74,9 +84,7 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure we actually joined the room
self.assertEqual(
self.successResultOf(
- maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
- )
+ maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0],
"$join:test.serv",
)
@@ -96,13 +104,11 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
- )
+ maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0]
# Now lie about an event
- lying_event = FrozenEvent(
+ lying_event = make_event_from_dict(
{
"room_id": self.room_id,
"sender": "@baduser:test.serv",
@@ -118,8 +124,10 @@ class MessageAcceptTests(unittest.TestCase):
)
with LoggingContext(request="lying_event"):
- d = self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
+ d = ensureDeferred(
+ self.handler.on_receive_pdu(
+ "test.serv", lying_event, sent_to_us_directly=True
+ )
)
# Step the reactor, so the database fetches come back
@@ -136,7 +144,121 @@ class MessageAcceptTests(unittest.TestCase):
)
# Make sure the invalid event isn't there
- extrem = maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
- )
+ extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+
+ def test_retry_device_list_resync(self):
+ """Tests that device lists are marked as stale if they couldn't be synced, and
+ that stale device lists are retried periodically.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_origin = "test_remote"
+
+ # Track the number of attempts to resync the user's device list.
+ self.resync_attempts = 0
+
+ # When this function is called, increment the number of resync attempts (only if
+ # we're querying devices for the right user ID), then raise a
+ # NotRetryingDestination error to fail the resync gracefully.
+ def query_user_devices(destination, user_id):
+ if user_id == remote_user_id:
+ self.resync_attempts += 1
+
+ raise NotRetryingDestination(0, 0, destination)
+
+ # Register the mock on the federation client.
+ federation_client = self.homeserver.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+
+ # Register a mock on the store so that the incoming update doesn't fail because
+ # we don't share a room with the user.
+ store = self.homeserver.get_datastore()
+ store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+
+ # Manually inject a fake device list update. We need this update to include at
+ # least one prev_id so that the user's device list will need to be retried.
+ device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ self.get_success(
+ device_list_updater.incoming_device_list_update(
+ origin=remote_origin,
+ edu_content={
+ "deleted": False,
+ "device_display_name": "Mobile",
+ "device_id": "QBUAZIFURK",
+ "prev_id": [5],
+ "stream_id": 6,
+ "user_id": remote_user_id,
+ },
+ )
+ )
+
+ # Check that there was one resync attempt.
+ self.assertEqual(self.resync_attempts, 1)
+
+ # Check that the resync attempt failed and caused the user's device list to be
+ # marked as stale.
+ need_resync = self.get_success(
+ store.get_user_ids_requiring_device_list_resync()
+ )
+ self.assertIn(remote_user_id, need_resync)
+
+ # Check that waiting for 30 seconds caused Synapse to retry resyncing the device
+ # list.
+ self.reactor.advance(30)
+ self.assertEqual(self.resync_attempts, 2)
+
+ def test_cross_signing_keys_retry(self):
+ """Tests that resyncing a device list correctly processes cross-signing keys from
+ the remote server.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ # Register mock device list retrieval on the federation client.
+ federation_client = self.homeserver.get_federation_client()
+ federation_client.query_user_devices = Mock(
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ # Resync the device list.
+ device_handler = self.homeserver.get_device_handler()
+ self.get_success(
+ device_handler.device_list_updater.user_device_resync(remote_user_id),
+ )
+
+ # Retrieve the cross-signing keys for this user.
+ keys = self.get_success(
+ self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
+ )
+ self.assertTrue(remote_user_id in keys)
+
+ # Check that the master key is the one returned by the mock.
+ master_key = keys[remote_user_id]["master"]
+ self.assertEqual(len(master_key["keys"]), 1)
+ self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
+ self.assertTrue(remote_master_key in master_key["keys"].values())
+
+ # Check that the self-signing key is the one returned by the mock.
+ self_signing_key = keys[remote_user_id]["self_signing"]
+ self.assertEqual(len(self_signing_key["keys"]), 1)
+ self.assertTrue(
+ "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
+ )
+ self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 1fbe0d51ff..49667ed7f4 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -17,40 +17,44 @@
import json
-from mock import Mock
-
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
+from tests.unittest import override_config
+from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def default_config(self):
+ config = default_config("test")
- self.hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ config.update(
+ {
+ "registrations_require_3pid": [],
+ "limit_usage_by_mau": True,
+ "max_mau_value": 2,
+ "mau_trial_days": 0,
+ "server_notices": {
+ "system_mxid_localpart": "server",
+ "room_name": "Test Server Notice Room",
+ },
+ }
)
- self.store = self.hs.get_datastore()
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- self.hs.config.registrations_require_3pid = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.recaptcha_public_key = []
+ return config
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.hs_disabled = False
- self.hs.config.max_mau_value = 2
- self.hs.config.mau_trial_days = 0
- self.hs.config.server_notices_mxid = "@server:red"
- self.hs.config.server_notices_mxid_display_name = None
- self.hs.config.server_notices_mxid_avatar_url = None
- self.hs.config.server_notices_room_name = "Test Server Notice Room"
- return self.hs
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated
@@ -59,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
+ # check we're testing what we think we are: there should be two active users
+ self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
@@ -78,7 +85,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# Advance time by 31 days
self.reactor.advance(31 * 24 * 60 * 60)
- self.store.reap_monthly_active_users()
+ self.get_success(self.store.reap_monthly_active_users())
self.reactor.advance(0)
@@ -86,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
+ @override_config({"mau_trial_days": 1})
def test_trial_delay(self):
- self.hs.config.mau_trial_days = 1
-
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -120,6 +126,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ @override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self):
self.hs.config.mau_trial_days = 1
@@ -140,8 +147,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# Advance by 2 months so everyone falls out of MAU
self.reactor.advance(60 * 24 * 60 * 60)
- self.store.reap_monthly_active_users()
- self.reactor.advance(0)
+ self.get_success(self.store.reap_monthly_active_users())
# We can create as many new users as we want
token4 = self.create_user("kermit4")
@@ -168,11 +174,11 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ @override_config(
+ # max_mau_value should not matter
+ {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
+ )
def test_tracked_but_not_limited(self):
- self.hs.config.max_mau_value = 1 # should not matter
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = True
-
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 270f853d60..f5f63d8ed6 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,6 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
+from synapse.util.caches.descriptors import Cache
from tests import unittest
@@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase):
self.assertTrue(b"osversion=" in items[0])
self.assertTrue(b"pythonversion=" in items[0])
self.assertTrue(b"version=" in items[0])
+
+
+class CacheMetricsTests(unittest.HomeserverTestCase):
+ def test_cache_metric(self):
+ """
+ Caches produce metrics reflecting their state when scraped.
+ """
+ CACHE_NAME = "cache_metrics_test_fgjkbdfg"
+ cache = Cache(CACHE_NAME, max_entries=777)
+
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "0.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+ cache.prefill("1", "hi")
+
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
new file mode 100644
index 0000000000..7657bddea5
--- /dev/null
+++ b/tests/test_phone_home.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import resource
+
+import mock
+
+from synapse.app.homeserver import phone_stats_home
+
+from tests.unittest import HomeserverTestCase
+
+
+class PhoneHomeStatsTestCase(HomeserverTestCase):
+ def test_performance_frozen_clock(self):
+ """
+ If time doesn't move, don't error out.
+ """
+ past_stats = [
+ (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+ ]
+ stats = {}
+ self.get_success(phone_stats_home(self.hs, stats, past_stats))
+ self.assertEqual(stats["cpu_average"], 0)
+
+ def test_performance_100(self):
+ """
+ 1 second of usage over 1 second is 100% CPU usage.
+ """
+ real_res = resource.getrusage(resource.RUSAGE_SELF)
+ old_resource = mock.Mock(spec=real_res)
+ old_resource.ru_utime = real_res.ru_utime - 1
+ old_resource.ru_stime = real_res.ru_stime
+ old_resource.ru_maxrss = real_res.ru_maxrss
+
+ past_stats = [(self.hs.get_clock().time(), old_resource)]
+ stats = {}
+ self.reactor.advance(1)
+ self.get_success(phone_stats_home(self.hs, stats, past_stats))
+ self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5)
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..e9a43b1e45 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,13 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ JsonResource,
+ OptionsResource,
+ wrap_html_request_handler,
+)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -164,6 +169,157 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class OptionsResourceTests(unittest.TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ class DummyResource(Resource):
+ isLeaf = True
+
+ def render(self, request):
+ return request.path
+
+ # Setup a resource with some children.
+ self.resource = OptionsResource()
+ self.resource.putChild(b"res", DummyResource())
+
+ def _make_request(self, method, path):
+ """Create a request from the method/path and return a channel with the response."""
+ request, channel = make_request(self.reactor, method, path, shorthand=False)
+ request.prepath = [] # This doesn't get set properly by make_request.
+
+ # Create a site and query for the resource.
+ site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ request.site = site
+ resource = site.getResourceFor(request)
+
+ # Finally, render the resource and return the channel.
+ render(request, resource, self.reactor)
+ return channel
+
+ def test_unknown_options_request(self):
+ """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/foo/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_known_options_request(self):
+ """An OPTIONS requests to an known URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_unknown_request(self):
+ """A non-OPTIONS request to an unknown URL should 404."""
+ channel = self._make_request(b"GET", b"/foo/")
+ self.assertEqual(channel.result["code"], b"404")
+
+ def test_known_request(self):
+ """A non-OPTIONS request to an known URL should query the proper resource."""
+ channel = self._make_request(b"GET", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"/res/")
+
+
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+ class TestResource(DirectServeResource):
+ callback = None
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self.callback(request)
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def test_good_response(self):
+ def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ body = channel.result["body"]
+ self.assertEqual(body, b"response")
+
+ def test_redirect_exception(self):
+ """
+ If the callback raises a RedirectException, it is turned into a 30x
+ with the right location.
+ """
+
+ def callback(request, **kwargs):
+ raise RedirectException(b"/look/an/eagle", 301)
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"301")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+ def test_redirect_exception_with_cookie(self):
+ """
+ If the callback raises a RedirectException which sets a cookie, that is
+ returned too
+ """
+
+ def callback(request, **kwargs):
+ e = RedirectException(b"/no/over/there", 304)
+ e.cookies.append(b"session=yespls")
+ raise e
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"304")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/no/over/there"])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb46..66f22f6813 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -20,7 +20,8 @@ from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
@@ -65,7 +66,7 @@ def create_event(
d.update(kwargs)
- event = FrozenEvent(d)
+ event = make_event_from_dict(d)
return event
@@ -118,7 +119,7 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
- def get_room_version(self, room_id):
+ def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
@@ -158,10 +159,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = StateGroupStore()
+ storage = Mock(main=self.store, state=self.store)
hs = Mock(
spec_set=[
"config",
"get_datastore",
+ "get_storage",
"get_auth",
"get_state_handler",
"get_clock",
@@ -174,6 +177,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+ hs.get_storage.return_value = storage
self.state = StateHandler(hs)
self.event_id = 0
@@ -195,16 +199,22 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk())
- context_store = {}
+ context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
self.assertEqual(2, len(prev_state_ids))
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
@@ -238,11 +248,16 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between B and C
- self.assertSetEqual(
- {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
- )
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
+ self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
+
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
@@ -289,11 +304,16 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between C and D because bans win over other
+ # changes
- self.assertSetEqual(
- {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
- )
+ ctx_c = context_store["C"]
+ ctx_e = context_store["E"]
+
+ prev_state_ids = yield ctx_e.get_prev_state_ids()
+ self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
+ self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+ self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
@@ -357,11 +377,17 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # B ends up winning the resolution between B and C because power levels
+ # win over other changes.
- self.assertSetEqual(
- {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
- )
+ ctx_b = context_store["B"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
+ self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
+
+ self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
def _add_depths(self, nodes, edges):
def _get_depth(ev):
@@ -387,13 +413,16 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(current_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids()
+ self.assertCountEqual(
+ (e.event_id for e in old_state), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -407,12 +436,19 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(prev_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids()
+ self.assertCountEqual(
+ (e.event_id for e in old_state + [event]), current_state_ids.values()
)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertNotEqual(context.state_group_before_event, context.state_group)
+ self.assertEqual(context.state_group_before_event, context.prev_group)
+ self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
@@ -437,10 +473,10 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(
- set([e.event_id for e in old_state]), set(current_state_ids.values())
+ {e.event_id for e in old_state}, set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -469,11 +505,9 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
- self.assertEqual(
- set([e.event_id for e in old_state]), set(prev_state_ids.values())
- )
+ self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
self.assertIsNotNone(context.state_group)
@@ -510,7 +544,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(len(current_state_ids), 6)
@@ -552,7 +586,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(len(current_state_ids), 6)
@@ -607,7 +641,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -635,7 +669,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 52739fbabc..5c2817cf28 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -28,6 +28,21 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
+ def default_config(self):
+ config = super().default_config()
+ config.update(
+ {
+ "public_baseurl": "https://example.org/",
+ "user_consent": {
+ "version": "1.0",
+ "policy_name": "My Cool Privacy Policy",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ }
+ )
+ return config
+
def prepare(self, reactor, clock, hs):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
@@ -35,19 +50,11 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler = Mock()
self.auth_handler = Mock()
self.device_handler = Mock()
- hs.config.enable_registration = True
- hs.config.registrations_require_3pid = []
- hs.config.auto_join_rooms = []
- hs.config.enable_registration_captcha = False
def test_ui_auth(self):
- self.hs.config.user_consent_at_registration = True
- self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
- self.hs.config.public_baseurl = "https://example.org/"
- self.hs.config.user_consent_version = "1.0"
-
# Do a UI auth request
- request, channel = self.make_request(b"POST", self.url, b"{}")
+ request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/test_types.py b/tests/test_types.py
index 9ab5f829b0..480bea1bdc 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
-from tests.utils import TestHomeServer
-mock_homeserver = TestHomeServer(hostname="my.domain")
-
-class UserIDTestCase(unittest.TestCase):
+class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self):
- user = UserID.from_string("@1234abcd:my.domain")
+ user = UserID.from_string("@1234abcd:test")
self.assertEquals("1234abcd", user.localpart)
- self.assertEquals("my.domain", user.domain)
- self.assertEquals(True, mock_homeserver.is_mine(user))
+ self.assertEquals("test", user.domain)
+ self.assertEquals(True, self.hs.is_mine(user))
def test_pase_empty(self):
with self.assertRaises(SynapseError):
@@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase):
self.assertTrue(userA != userB)
-class RoomAliasTestCase(unittest.TestCase):
+class RoomAliasTestCase(unittest.HomeserverTestCase):
def test_parse(self):
- room = RoomAlias.from_string("#channel:my.domain")
+ room = RoomAlias.from_string("#channel:test")
self.assertEquals("channel", room.localpart)
- self.assertEquals("my.domain", room.domain)
- self.assertEquals(True, mock_homeserver.is_mine(room))
+ self.assertEquals("test", room.domain)
+ self.assertEquals(True, self.hs.is_mine(room))
def test_build(self):
room = RoomAlias("channel", "my.domain")
@@ -78,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase):
self.fail("Parsing '%s' should raise exception" % id_string)
except SynapseError as exc:
self.assertEqual(400, exc.code)
- self.assertEqual("M_UNKNOWN", exc.errcode)
+ self.assertEqual("M_INVALID_PARAM", exc.errcode)
class MapUsernameTestCase(unittest.TestCase):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a7310cf12a..7b345b03bb 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,3 +17,22 @@
"""
Utilities for running the unit tests
"""
+from typing import Awaitable, TypeVar
+
+TV = TypeVar("TV")
+
+
+def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
+ """Get the result from an Awaitable which should have completed
+
+ Asserts that the given awaitable has a result ready, and returns its value
+ """
+ i = awaitable.__await__()
+ try:
+ next(i)
+ except StopIteration as e:
+ # awaitable returned a result
+ return e.value
+
+ # if next didn't raise, the awaitable hasn't completed.
+ raise Exception("awaitable has not yet completed")
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
new file mode 100644
index 0000000000..431e9f8e5e
--- /dev/null
+++ b/tests/test_utils/event_injection.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import synapse.server
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection
+
+from tests.test_utils import get_awaitable_result
+
+
+"""
+Utility functions for poking events into the storage of the server under test.
+"""
+
+
+def inject_member_event(
+ hs: synapse.server.HomeServer,
+ room_id: str,
+ sender: str,
+ membership: str,
+ target: Optional[str] = None,
+ extra_content: Optional[dict] = None,
+ **kwargs
+) -> EventBase:
+ """Inject a membership event into a room."""
+ if target is None:
+ target = sender
+
+ content = {"membership": membership}
+ if extra_content:
+ content.update(extra_content)
+
+ return inject_event(
+ hs,
+ room_id=room_id,
+ type=EventTypes.Member,
+ sender=sender,
+ state_key=target,
+ content=content,
+ **kwargs
+ )
+
+
+def inject_event(
+ hs: synapse.server.HomeServer,
+ room_version: Optional[str] = None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ **kwargs
+) -> EventBase:
+ """Inject a generic event into a room
+
+ Args:
+ hs: the homeserver under test
+ room_version: the version of the room we're inserting into.
+ if not specified, will be looked up
+ prev_event_ids: prev_events for the event. If not specified, will be looked up
+ kwargs: fields for the event to be created
+ """
+ test_reactor = hs.get_reactor()
+
+ event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+
+ d = hs.get_storage().persistence.persist_event(event, context)
+ test_reactor.advance(0)
+ get_awaitable_result(d)
+
+ return event
+
+
+def create_event(
+ hs: synapse.server.HomeServer,
+ room_version: Optional[str] = None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ **kwargs
+) -> Tuple[EventBase, EventContext]:
+ test_reactor = hs.get_reactor()
+
+ if room_version is None:
+ d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
+ test_reactor.advance(0)
+ room_version = get_awaitable_result(d)
+
+ builder = hs.get_event_builder_factory().for_room_version(
+ KNOWN_ROOM_VERSIONS[room_version], kwargs
+ )
+ d = hs.get_event_creation_handler().create_new_client_event(
+ builder, prev_event_ids=prev_event_ids
+ )
+ test_reactor.advance(0)
+ event, context = get_awaitable_result(d)
+
+ return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a0035d..f7381b2885 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+from mock import Mock
+
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens.
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
for i in range(0, len(events_to_filter)):
@@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering")
start = time.time()
+
+ storage = Mock()
+ storage.main = test_store
+ storage.state = test_store
+
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 561cebc223..6b6f224e9c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,29 +14,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import hashlib
import hmac
+import inspect
import logging
import time
+from typing import Optional, Tuple, Type, TypeVar, Union
from mock import Mock
from canonicaljson import json
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
-from synapse.logging.context import LoggingContext
+from synapse.http.site import SynapseRequest, SynapseSite
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ current_context,
+ set_current_context,
+)
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
-
-from tests.server import get_clock, make_request, render, setup_test_homeserver
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests.server import (
+ FakeChannel,
+ get_clock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
+from tests.test_utils import event_injection
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -64,6 +83,9 @@ def around(target):
return _around
+T = TypeVar("T")
+
+
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
@@ -80,10 +102,10 @@ class TestCase(unittest.TestCase):
def setUp(orig):
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
- if LoggingContext.current_context() is not LoggingContext.sentinel:
+ if current_context():
self.fail(
"Test starting with non-sentinel logging context %s"
- % (LoggingContext.current_context(),)
+ % (current_context(),)
)
old_level = logging.getLogger().level
@@ -105,7 +127,7 @@ class TestCase(unittest.TestCase):
# force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs)
gc.collect()
- LoggingContext.set_current_context(LoggingContext.sentinel)
+ set_current_context(SENTINEL_CONTEXT)
return ret
@@ -203,6 +225,15 @@ class HomeserverTestCase(TestCase):
# Register the resources
self.resource = self.create_test_json_resource()
+ # create a site to wrap the resource.
+ self.site = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="test",
+ config={},
+ resource=self.resource,
+ server_version_string="1",
+ )
+
from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
@@ -285,14 +316,11 @@ class HomeserverTestCase(TestCase):
return resource
- def default_config(self, name="test"):
+ def default_config(self):
"""
Get a default HomeServer config dict.
-
- Args:
- name (str): The homeserver name/domain.
"""
- config = default_config(name)
+ config = default_config("test")
# apply any additional config which was specified via the override_config
# decorator.
@@ -318,14 +346,14 @@ class HomeserverTestCase(TestCase):
def make_request(
self,
- method,
- path,
- content=b"",
- access_token=None,
- request=SynapseRequest,
- shorthand=True,
- federation_auth_origin=None,
- ):
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, dict] = b"",
+ access_token: Optional[str] = None,
+ request: Type[T] = SynapseRequest,
+ shorthand: bool = True,
+ federation_auth_origin: str = None,
+ ) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
given content.
@@ -392,13 +420,17 @@ class HomeserverTestCase(TestCase):
config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj
+ async def run_bg_updates():
+ with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ while not await stor.db.updates.has_completed_background_updates():
+ await stor.db.updates.do_next_background_update(1)
+
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
- # Run the database background updates.
- if hasattr(stor, "do_next_background_update"):
- while not self.get_success(stor.has_completed_background_updates()):
- self.get_success(stor.do_next_background_update(1))
+ # Run the database background updates, when running against "master".
+ if hs.__class__.__name__ == "TestHomeServer":
+ self.get_success(run_bg_updates())
return hs
@@ -409,6 +441,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@@ -418,6 +452,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
@@ -441,7 +477,7 @@ class HomeserverTestCase(TestCase):
# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -461,6 +497,7 @@ class HomeserverTestCase(TestCase):
"password": password,
"admin": admin,
"mac": want_mac,
+ "inhibit_login": True,
}
)
request, channel = self.make_request(
@@ -509,10 +546,6 @@ class HomeserverTestCase(TestCase):
secrets = self.hs.get_secrets()
requester = Requester(user, None, False, None, None)
- prev_events_and_hashes = None
- if prev_event_ids:
- prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
event, context = self.get_success(
event_creator.create_event(
requester,
@@ -522,7 +555,7 @@ class HomeserverTestCase(TestCase):
"sender": user.to_string(),
"content": {"body": secrets.token_hex(), "msgtype": "m.text"},
},
- prev_events_and_hashes=prev_events_and_hashes,
+ prev_event_ids=prev_event_ids,
)
)
@@ -538,7 +571,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore()._simple_insert(
+ self.hs.get_datastore().db.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
@@ -559,6 +592,46 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+ def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ """
+ Inject a membership event into a room.
+
+ Deprecated: use event_injection.inject_room_member directly
+
+ Args:
+ room: Room ID to inject the event into.
+ user: MXID of the user to inject the membership for.
+ membership: The membership type.
+ """
+ event_injection.inject_member_event(self.hs, room, user, membership)
+
+
+class FederatingHomeserverTestCase(HomeserverTestCase):
+ """
+ A federating homeserver that authenticates incoming requests as `other.example.com`.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return succeed("other.example.com")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ federation_server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ return super().prepare(reactor, clock, homeserver)
+
def override_config(extra_config):
"""A decorator which can be applied to test functions to give additional HS config
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 5713870f48..4d2b9e0d64 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -22,8 +22,10 @@ from twisted.internet import defer, reactor
from synapse.api.errors import SynapseError
from synapse.logging.context import (
+ SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
@@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase):
with LoggingContext() as c1:
c1.name = "c1"
r = yield obj.fn(1)
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
return r
def check_result(r):
@@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
d1.addCallback(check_result)
# and another
d2 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
d2.addCallback(check_result)
# let the lookup complete
@@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase):
try:
d = obj.fn(1)
self.assertEqual(
- LoggingContext.current_context(), LoggingContext.sentinel
+ current_context(), SENTINEL_CONTEXT,
)
yield d
self.fail("No exception thrown")
except SynapseError:
pass
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
@@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup
d1 = do_lookup()
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
return d1
@@ -310,14 +312,14 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
@@ -325,9 +327,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
@@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
- assert LoggingContext.current_context().request == "c1"
+ assert current_context().request == "c1"
# we want this to behave like an asynchronous function
yield run_on_reactor()
- assert LoggingContext.current_context().request == "c1"
+ assert current_context().request == "c1"
return self.mock(args1, arg2)
with LoggingContext() as c1:
@@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
- self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
r = yield d1
- self.assertEqual(LoggingContext.current_context(), c1)
+ self.assertEqual(current_context(), c1)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: "fish", 20: "chips"})
obj.mock.reset_mock()
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
index f60918069a..17fd86d02d 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_utils.py
@@ -16,7 +16,12 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.internet.task import Clock
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+ SENTINEL_CONTEXT,
+ LoggingContext,
+ PreserveLoggingContext,
+ current_context,
+)
from synapse.util.async_helpers import timeout_deferred
from tests.unittest import TestCase
@@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase):
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
- LoggingContext.current_context(),
+ current_context(),
context_one,
"errback %s run in unexpected logcontext %s"
- % (deferred_name, LoggingContext.current_context()),
+ % (deferred_name, current_context()),
)
return res
@@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase):
original_deferred.addErrback(errback, "orig")
timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
self.assertNoResult(timing_out_d)
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
timing_out_d.addErrback(errback, "timingout")
self.clock.pump((1.0,))
@@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase):
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
- self.assertIs(LoggingContext.current_context(), context_one)
+ self.assertIs(current_context(), context_one)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 50bc7702d2..49ffeebd0e 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -21,7 +21,7 @@ from tests.utils import MockClock
from .. import unittest
-class ExpiringCacheTestCase(unittest.TestCase):
+class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
new file mode 100644
index 0000000000..0ab0a91483
--- /dev/null
+++ b/tests/util/test_itertools.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util.iterutils import chunk_seq
+
+from tests.unittest import TestCase
+
+
+class ChunkSeqTests(TestCase):
+ def test_short_seq(self):
+ parts = chunk_seq("123", 8)
+
+ self.assertEqual(
+ list(parts), ["123"],
+ )
+
+ def test_long_seq(self):
+ parts = chunk_seq("abcdefghijklmnop", 8)
+
+ self.assertEqual(
+ list(parts), ["abcdefgh", "ijklmnop"],
+ )
+
+ def test_uneven_parts(self):
+ parts = chunk_seq("abcdefghijklmnop", 5)
+
+ self.assertEqual(
+ list(parts), ["abcde", "fghij", "klmno", "p"],
+ )
+
+ def test_empty_input(self):
+ parts = chunk_seq([], 5)
+
+ self.assertEqual(
+ list(parts), [],
+ )
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 0ec8ef90ce..ca3858b184 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,7 +19,7 @@ from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import LoggingContext, current_context
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@@ -45,6 +45,38 @@ class LinearizerTestCase(unittest.TestCase):
with (yield d2):
pass
+ @defer.inlineCallbacks
+ def test_linearizer_is_queued(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ # Since d1 gets called immediately, "is_queued" should return false.
+ self.assertFalse(linearizer.is_queued(key))
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ # Now d2 is queued up behind successful completion of cm1
+ self.assertTrue(linearizer.is_queued(key))
+
+ with cm1:
+ self.assertFalse(d2.called)
+
+ # cm1 still not done, so d2 still queued.
+ self.assertTrue(linearizer.is_queued(key))
+
+ # And now d2 is called and nothing is in the queue again
+ self.assertFalse(linearizer.is_queued(key))
+
+ with (yield d2):
+ self.assertFalse(linearizer.is_queued(key))
+
+ self.assertFalse(linearizer.is_queued(key))
+
def test_lots_of_queued_things(self):
# we have one slow thing, and lots of fast things queued up behind it.
# it should *not* explode the stack.
@@ -54,11 +86,11 @@ class LinearizerTestCase(unittest.TestCase):
def func(i, sleep=False):
with LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")):
- self.assertEqual(LoggingContext.current_context(), lc)
+ self.assertEqual(current_context(), lc)
if sleep:
yield Clock(reactor).sleep(0)
- self.assertEqual(LoggingContext.current_context(), lc)
+ self.assertEqual(current_context(), lc)
func(0, sleep=True)
for i in range(1, 100):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..95301c013c 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -2,8 +2,10 @@ import twisted.python.failure
from twisted.internet import defer, reactor
from synapse.logging.context import (
+ SENTINEL_CONTEXT,
LoggingContext,
PreserveLoggingContext,
+ current_context,
make_deferred_yieldable,
nested_logging_context,
run_in_background,
@@ -15,7 +17,7 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(LoggingContext.current_context().request, value)
+ self.assertEquals(current_context().request, value)
def test_with_context(self):
with LoggingContext() as context_one:
@@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def _test_run_in_background(self, function):
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
callback_completed = [False]
@@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase):
# make sure that the context was reset before it got thrown back
# into the reactor
try:
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
d2.callback(None)
except BaseException:
d2.errback(twisted.python.failure.Failure())
@@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc():
self._check_test_key("one")
d = Clock(reactor).sleep(0)
- self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("one")
@@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase):
reactor.callLater(0, d.callback, None)
return d
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
@@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self):
- sentinel_context = LoggingContext.current_context()
+ sentinel_context = current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assertIs(current_context(), sentinel_context)
yield d1
@@ -179,6 +181,30 @@ class LoggingContextTestCase(unittest.TestCase):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable_with_await(self):
+ # an async function which retuns an incomplete coroutine, but doesn't
+ # follow the synapse rules.
+
+ async def blocking_function():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, None)
+ await d
+
+ sentinel_context = current_context()
+
+ with LoggingContext() as context_one:
+ context_one.request = "one"
+
+ d1 = make_deferred_yieldable(blocking_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 786947375d..0adb2174af 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache
from .. import unittest
-class LruCacheTestCase(unittest.TestCase):
+class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
cache = LruCache(1)
cache["key"] = "value"
@@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(len(cache), 0)
-class LruCacheCallbacksTestCase(unittest.TestCase):
+class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
@@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m3.call_count, 1)
-class LruCacheSizedTestCase(unittest.TestCase):
+class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
- def setUp(self):
- self.cache = SnapshotCache()
- self.cache.DURATION_MS = 1
-
- def test_get_set(self):
- # Check that getting a missing key returns None
- self.assertEquals(self.cache.get(0, "key"), None)
-
- # Check that setting a key with a deferred returns
- # a deferred that resolves when the initial deferred does
- d = Deferred()
- set_result = self.cache.set(0, "key", d)
- self.assertIsNotNone(set_result)
- self.assertFalse(set_result.called)
-
- # Check that getting the key before the deferred has resolved
- # returns a deferred that resolves when the initial deferred does.
- get_result_at_10 = self.cache.get(10, "key")
- self.assertIsNotNone(get_result_at_10)
- self.assertFalse(get_result_at_10.called)
-
- # Check that the returned deferreds resolve when the initial deferred
- # does.
- d.callback("v")
- self.assertTrue(set_result.called)
- self.assertTrue(get_result_at_10.called)
-
- # Check that getting the key after the deferred has resolved
- # before the cache expires returns a resolved deferred.
- get_result_at_11 = self.cache.get(11, "key")
- self.assertIsNotNone(get_result_at_11)
- if isinstance(get_result_at_11, Deferred):
- # The cache may return the actual result rather than a deferred
- self.assertTrue(get_result_at_11.called)
-
- # Check that getting the key after the deferred has resolved
- # after the cache expires returns None
- get_result_at_12 = self.cache.get(12, "key")
- self.assertIsNone(get_result_at_12)
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index f2be63706b..13b753e367 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,11 +1,9 @@
-from mock import patch
-
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest
-class StreamChangeCacheTests(unittest.TestCase):
+class StreamChangeCacheTests(unittest.HomeserverTestCase):
"""
Tests for StreamChangeCache.
"""
@@ -28,26 +26,33 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.entity_has_changed("user@foo.com", 6)
cache.entity_has_changed("bar@baz.net", 7)
+ # also test multiple things changing on the same stream ID
+ cache.entity_has_changed("user2@foo.com", 8)
+ cache.entity_has_changed("bar2@baz.net", 8)
+
# If it's been changed after that stream position, return True
self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("user2@foo.com", 4))
# If it's been changed at that stream position, return False
self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+ self.assertFalse(cache.has_entity_changed("user2@foo.com", 8))
# If there's no changes after that stream position, return False
self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+ self.assertFalse(cache.has_entity_changed("user2@foo.com", 9))
# If the entity does not exist, return False.
- self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+ self.assertFalse(cache.has_entity_changed("not@here.website", 9))
# If we request before the stream cache's earliest known position,
# return True, whether it's a known entity or not.
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
- @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
- def test_has_entity_changed_pops_off_start(self):
+ def test_entity_has_changed_pops_off_start(self):
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@@ -64,11 +69,20 @@ class StreamChangeCacheTests(unittest.TestCase):
# The oldest item has been popped off
self.assertTrue("user@foo.com" not in cache._entity_to_key)
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"],
+ )
+ self.assertIsNone(cache.get_all_entities_changed(1))
+
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
self.assertEqual(
- set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+ {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"],
+ )
+ self.assertIsNone(cache.get_all_entities_changed(1))
def test_get_all_entities_changed(self):
"""
@@ -80,18 +94,52 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.entity_has_changed("user@foo.com", 2)
cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("anotheruser@foo.com", 3)
cache.entity_has_changed("user@elsewhere.org", 4)
- self.assertEqual(
- cache.get_all_entities_changed(1),
- ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
- )
- self.assertEqual(
- cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
- )
+ r = cache.get_all_entities_changed(1)
+
+ # either of these are valid
+ ok1 = [
+ "user@foo.com",
+ "bar@baz.net",
+ "anotheruser@foo.com",
+ "user@elsewhere.org",
+ ]
+ ok2 = [
+ "user@foo.com",
+ "anotheruser@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ ]
+ self.assertTrue(r == ok1 or r == ok2)
+
+ r = cache.get_all_entities_changed(2)
+ self.assertTrue(r == ok1[1:] or r == ok2[1:])
+
self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
self.assertEqual(cache.get_all_entities_changed(0), None)
+ # ... later, things gest more updates
+ cache.entity_has_changed("user@foo.com", 5)
+ cache.entity_has_changed("bar@baz.net", 5)
+ cache.entity_has_changed("anotheruser@foo.com", 6)
+
+ ok1 = [
+ "user@elsewhere.org",
+ "user@foo.com",
+ "bar@baz.net",
+ "anotheruser@foo.com",
+ ]
+ ok2 = [
+ "user@elsewhere.org",
+ "bar@baz.net",
+ "user@foo.com",
+ "anotheruser@foo.com",
+ ]
+ r = cache.get_all_entities_changed(3)
+ self.assertTrue(r == ok1 or r == ok2)
+
def test_has_any_entity_changed(self):
"""
StreamChangeCache.has_any_entity_changed will return True if any
@@ -137,7 +185,7 @@ class StreamChangeCacheTests(unittest.TestCase):
cache.get_entities_changed(
["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
),
- set(["bar@baz.net", "user@elsewhere.org"]),
+ {"bar@baz.net", "user@elsewhere.org"},
)
# Query all the entries mid-way through the stream, but include one
@@ -153,7 +201,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
- set(["bar@baz.net", "user@elsewhere.org"]),
+ {"bar@baz.net", "user@elsewhere.org"},
)
# Query all the entries, but before the first known point. We will get
@@ -168,21 +216,13 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=0,
),
- set(
- [
- "user@foo.com",
- "bar@baz.net",
- "user@elsewhere.org",
- "not@here.website",
- ]
- ),
+ {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
)
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
- set(["bar@baz.net"]),
+ cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
)
def test_max_pos(self):
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
new file mode 100644
index 0000000000..4f4da29a98
--- /dev/null
+++ b/tests/util/test_stringutils.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.errors import SynapseError
+from synapse.util.stringutils import assert_valid_client_secret
+
+from .. import unittest
+
+
+class StringUtilsTestCase(unittest.TestCase):
+ def test_client_secret_regex(self):
+ """Ensure that client_secret does not contain illegal characters"""
+ good = [
+ "abcde12345",
+ "ABCabc123",
+ "_--something==_",
+ "...--==-18913",
+ "8Dj2odd-e9asd.cd==_--ddas-secret-",
+ # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
+ # To be removed in a future release
+ "SECRET:1234567890",
+ ]
+
+ bad = [
+ "--+-/secret",
+ "\\dx--dsa288",
+ "",
+ "AAS//",
+ "asdj**",
+ ">X><Z<!!-)))",
+ "a@b.com",
+ ]
+
+ for client_secret in good:
+ assert_valid_client_secret(client_secret)
+
+ for client_secret in bad:
+ with self.assertRaises(SynapseError):
+ assert_valid_client_secret(client_secret)
diff --git a/tests/utils.py b/tests/utils.py
index 46ef2959f2..59c020a051 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,19 +30,16 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.storage.prepare_database import (
- _get_or_create_schema_state,
- _setup_new_database,
- prepare_database,
-)
+from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
@@ -77,7 +74,10 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
- cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
+ cur.execute(
+ "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
+ "template=template0;" % (POSTGRES_BASE_DB,)
+ )
cur.close()
db_conn.close()
@@ -88,11 +88,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
- cur = db_conn.cursor()
- _get_or_create_schema_state(cur, db_engine)
- _setup_new_database(cur, db_engine)
- db_conn.commit()
- cur.close()
+ prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
@@ -117,6 +113,7 @@ def default_config(name, parse=False):
"""
config_dict = {
"server_name": name,
+ "send_federation": False,
"media_store_path": "media",
"uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config
@@ -145,7 +142,6 @@ def default_config(name, parse=False):
"limit_usage_by_mau": False,
"hs_disabled": False,
"hs_disabled_message": "",
- "hs_disabled_limit_type": "",
"max_mau_value": 50,
"mau_trial_days": 0,
"mau_stats_only": False,
@@ -171,6 +167,7 @@ def default_config(name, parse=False):
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
"update_user_directory": False,
+ "caches": {"global_factor": 1},
}
if parse:
@@ -185,7 +182,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
-@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
@@ -222,7 +218,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
- config.database_config = {
+ database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -234,12 +230,15 @@ def setup_test_homeserver(
},
}
else:
- config.database_config = {
+ database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
- db_engine = create_engine(config.database_config)
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -259,39 +258,30 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # we need to configure the connection pool to run the on_new_connection
- # function, so that we can test code that uses custom sqlite functions
- # (like rank).
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
if datastore is None:
hs = homeserverToUse(
name,
config=config,
- db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
)
- # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
- # date db
- if not isinstance(db_engine, PostgresEngine):
- db_conn = hs.get_db_conn()
- yield prepare_database(db_conn, db_engine, config)
- db_conn.commit()
- db_conn.close()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_master()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- else:
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
- hs.get_db_pool().close()
+ database._db_pool.close()
dropped = False
@@ -330,17 +320,12 @@ def setup_test_homeserver(
# Register the cleanup hook
cleanup_func(cleanup)
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
else:
hs = homeserverToUse(
name,
- db_pool=None,
datastore=datastore,
config=config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
@@ -351,10 +336,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().validate_hash = (
- lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
- )
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -463,7 +453,9 @@ class MockHttpResource(HttpServer):
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
- (code, response) = yield func(mock_request, *args)
+ (code, response) = yield defer.ensureDeferred(
+ func(mock_request, *args)
+ )
return code, response
except CodeMessageException as e:
return (e.code, cs_error(e.msg, code=e.errcode))
@@ -510,10 +502,10 @@ class MockClock(object):
return self.time() * 1000
def call_later(self, delay, callback, *args, **kwargs):
- current_context = LoggingContext.current_context()
+ ctx = current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
+ set_current_context(ctx)
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
@@ -521,8 +513,8 @@ class MockClock(object):
return t
- def looping_call(self, function, interval):
- self.loopers.append([function, interval / 1000.0, self.now])
+ def looping_call(self, function, interval, *args, **kwargs):
+ self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -552,9 +544,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
- func, interval, last = looped
+ func, interval, last, args, kwargs = looped
if last + interval < self.now:
- func()
+ func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):
@@ -655,10 +647,18 @@ def create_room(hs, room_id, creator_id):
creator_id (str)
"""
+ persistence_store = hs.get_storage().persistence
store = hs.get_datastore()
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
+ yield store.store_room(
+ room_id=room_id,
+ room_creator_user_id=creator_id,
+ is_public=False,
+ room_version=RoomVersions.V1,
+ )
+
builder = event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -672,4 +672,4 @@ def create_room(hs, room_id, creator_id):
event, context = yield event_creation_handler.create_new_client_event(builder)
- yield store.persist_event(event, context)
+ yield persistence_store.persist_event(event, context)
|