diff --git a/tests/__init__.py b/tests/__init__.py
index d3181f9403..f7fc502f01 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -21,4 +21,4 @@ import tests.patch_inline_callbacks
# attempt to do the patch before we load any synapse code
tests.patch_inline_callbacks.do_patch()
-util.DEFAULT_TIMEOUT_DURATION = 10
+util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 2a7044801a..6ba623de13 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -109,7 +109,6 @@ class FilteringTestCase(unittest.TestCase):
"event_format": "client",
"event_fields": ["type", "content", "sender"],
},
-
# a single backslash should be permitted (though it is debatable whether
# it should be permitted before anything other than `.`, and what that
# actually means)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 30a255d441..dbdd427cac 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -10,19 +10,19 @@ class TestRatelimiter(unittest.TestCase):
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
- self.assertEquals(10., time_allowed)
+ self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
)
self.assertFalse(allowed)
- self.assertEquals(10., time_allowed)
+ self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
- self.assertEquals(20., time_allowed)
+ self.assertEquals(20.0, time_allowed)
def test_pruning(self):
limiter = Ratelimiter()
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 590abc1e92..48792d1480 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -25,16 +25,18 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=FederationReaderServer,
+ http_client=None, homeserverToUse=FederationReaderServer
)
return hs
- @parameterized.expand([
- (["federation"], "auth_fail"),
- ([], "no_resource"),
- (["openid", "federation"], "auth_fail"),
- (["openid"], "auth_fail"),
- ])
+ @parameterized.expand(
+ [
+ (["federation"], "auth_fail"),
+ ([], "no_resource"),
+ (["openid", "federation"], "auth_fail"),
+ (["openid"], "auth_fail"),
+ ]
+ )
def test_openid_listener(self, names, expectation):
"""
Test different openid listener configurations.
@@ -53,17 +55,14 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = (
- site.resource.children[b"_matrix"].children[b"federation"]
- )
+ self.resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
request, channel = self.make_request(
- "GET",
- "/_matrix/federation/v1/openid/userinfo",
+ "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
@@ -74,16 +73,18 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=SynapseHomeServer,
+ http_client=None, homeserverToUse=SynapseHomeServer
)
return hs
- @parameterized.expand([
- (["federation"], "auth_fail"),
- ([], "no_resource"),
- (["openid", "federation"], "auth_fail"),
- (["openid"], "auth_fail"),
- ])
+ @parameterized.expand(
+ [
+ (["federation"], "auth_fail"),
+ ([], "no_resource"),
+ (["openid", "federation"], "auth_fail"),
+ (["openid"], "auth_fail"),
+ ]
+ )
def test_openid_listener(self, names, expectation):
"""
Test different openid listener configurations.
@@ -102,17 +103,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = (
- site.resource.children[b"_matrix"].children[b"federation"]
- )
+ self.resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
request, channel = self.make_request(
- "GET",
- "/_matrix/federation/v1/openid/userinfo",
+ "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 795b4c298d..5017cbce85 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -45,13 +45,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
)
self.assertSetEqual(
- set(
- [
- "homeserver.yaml",
- "lemurs.win.log.config",
- "lemurs.win.signing.key",
- ]
- ),
+ set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
set(os.listdir(self.dir)),
)
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index 47fffcfeb2..0ec10019b3 100644
--- a/tests/config/test_room_directory.py
+++ b/tests/config/test_room_directory.py
@@ -22,7 +22,8 @@ from tests import unittest
class RoomDirectoryConfigTestCase(unittest.TestCase):
def test_alias_creation_acl(self):
- config = yaml.safe_load("""
+ config = yaml.safe_load(
+ """
alias_creation_rules:
- user_id: "*bob*"
alias: "*"
@@ -38,43 +39,49 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
action: "allow"
room_list_publication_rules: []
- """)
+ """
+ )
rd_config = RoomDirectoryConfig()
rd_config.read_config(config)
- self.assertFalse(rd_config.is_alias_creation_allowed(
- user_id="@bob:example.com",
- room_id="!test",
- alias="#test:example.com",
- ))
-
- self.assertTrue(rd_config.is_alias_creation_allowed(
- user_id="@test:example.com",
- room_id="!test",
- alias="#unofficial_st:example.com",
- ))
-
- self.assertTrue(rd_config.is_alias_creation_allowed(
- user_id="@foobar:example.com",
- room_id="!test",
- alias="#test:example.com",
- ))
-
- self.assertTrue(rd_config.is_alias_creation_allowed(
- user_id="@gah:example.com",
- room_id="!test",
- alias="#goo:example.com",
- ))
-
- self.assertFalse(rd_config.is_alias_creation_allowed(
- user_id="@test:example.com",
- room_id="!test",
- alias="#test:example.com",
- ))
+ self.assertFalse(
+ rd_config.is_alias_creation_allowed(
+ user_id="@bob:example.com", room_id="!test", alias="#test:example.com"
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ room_id="!test",
+ alias="#unofficial_st:example.com",
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_alias_creation_allowed(
+ user_id="@foobar:example.com",
+ room_id="!test",
+ alias="#test:example.com",
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_alias_creation_allowed(
+ user_id="@gah:example.com", room_id="!test", alias="#goo:example.com"
+ )
+ )
+
+ self.assertFalse(
+ rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com", room_id="!test", alias="#test:example.com"
+ )
+ )
def test_room_publish_acl(self):
- config = yaml.safe_load("""
+ config = yaml.safe_load(
+ """
alias_creation_rules: []
room_list_publication_rules:
@@ -92,55 +99,66 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
action: "allow"
- room_id: "!test-deny"
action: "deny"
- """)
+ """
+ )
rd_config = RoomDirectoryConfig()
rd_config.read_config(config)
- self.assertFalse(rd_config.is_publishing_room_allowed(
- user_id="@bob:example.com",
- room_id="!test",
- aliases=["#test:example.com"],
- ))
-
- self.assertTrue(rd_config.is_publishing_room_allowed(
- user_id="@test:example.com",
- room_id="!test",
- aliases=["#unofficial_st:example.com"],
- ))
-
- self.assertTrue(rd_config.is_publishing_room_allowed(
- user_id="@foobar:example.com",
- room_id="!test",
- aliases=[],
- ))
-
- self.assertTrue(rd_config.is_publishing_room_allowed(
- user_id="@gah:example.com",
- room_id="!test",
- aliases=["#goo:example.com"],
- ))
-
- self.assertFalse(rd_config.is_publishing_room_allowed(
- user_id="@test:example.com",
- room_id="!test",
- aliases=["#test:example.com"],
- ))
-
- self.assertTrue(rd_config.is_publishing_room_allowed(
- user_id="@foobar:example.com",
- room_id="!test-deny",
- aliases=[],
- ))
-
- self.assertFalse(rd_config.is_publishing_room_allowed(
- user_id="@gah:example.com",
- room_id="!test-deny",
- aliases=[],
- ))
-
- self.assertTrue(rd_config.is_publishing_room_allowed(
- user_id="@test:example.com",
- room_id="!test",
- aliases=["#unofficial_st:example.com", "#blah:example.com"],
- ))
+ self.assertFalse(
+ rd_config.is_publishing_room_allowed(
+ user_id="@bob:example.com",
+ room_id="!test",
+ aliases=["#test:example.com"],
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_publishing_room_allowed(
+ user_id="@test:example.com",
+ room_id="!test",
+ aliases=["#unofficial_st:example.com"],
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_publishing_room_allowed(
+ user_id="@foobar:example.com", room_id="!test", aliases=[]
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_publishing_room_allowed(
+ user_id="@gah:example.com",
+ room_id="!test",
+ aliases=["#goo:example.com"],
+ )
+ )
+
+ self.assertFalse(
+ rd_config.is_publishing_room_allowed(
+ user_id="@test:example.com",
+ room_id="!test",
+ aliases=["#test:example.com"],
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_publishing_room_allowed(
+ user_id="@foobar:example.com", room_id="!test-deny", aliases=[]
+ )
+ )
+
+ self.assertFalse(
+ rd_config.is_publishing_room_allowed(
+ user_id="@gah:example.com", room_id="!test-deny", aliases=[]
+ )
+ )
+
+ self.assertTrue(
+ rd_config.is_publishing_room_allowed(
+ user_id="@test:example.com",
+ room_id="!test",
+ aliases=["#unofficial_st:example.com", "#blah:example.com"],
+ )
+ )
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index f5836d73ac..de64965a60 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -19,7 +19,6 @@ from tests import unittest
class ServerConfigTestCase(unittest.TestCase):
-
def test_is_threepid_reserved(self):
user1 = {'medium': 'email', 'address': 'user1@example.com'}
user2 = {'medium': 'email', 'address': 'user2@example.com'}
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index c260d3359f..40ca428778 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -26,7 +26,6 @@ class TestConfig(TlsConfig):
class TLSConfigTests(TestCase):
-
def test_warn_self_signed(self):
"""
Synapse will give a warning when it loads a self-signed certificate.
@@ -34,7 +33,8 @@ class TLSConfigTests(TestCase):
config_dir = self.mktemp()
os.mkdir(config_dir)
with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
- f.write("""-----BEGIN CERTIFICATE-----
+ f.write(
+ """-----BEGIN CERTIFICATE-----
MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
@@ -56,11 +56,12 @@ I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
------END CERTIFICATE-----""")
+-----END CERTIFICATE-----"""
+ )
config = {
"tls_certificate_path": os.path.join(config_dir, "cert.pem"),
- "tls_fingerprints": []
+ "tls_fingerprints": [],
}
t = TestConfig()
@@ -75,5 +76,5 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
"Self-signed TLS certificates will not be accepted by "
"Synapse 1.0. Please either provide a valid certificate, "
"or use Synapse's ACME support to provision one."
- )
+ ),
)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 9af0656a83..5a355f00cc 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,12 +19,18 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+ PerspectivesKeyFetcher,
+ ServerKeyFetcher,
+ StoreKeyFetcher,
+)
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -38,7 +44,7 @@ class MockPerspectiveServer(object):
def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key)
- return {"%s:%s" % (vk.alg, vk.version): vk}
+ return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)}
def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
@@ -46,25 +52,31 @@ class MockPerspectiveServer(object):
"server_name": server_name,
"old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600,
- "verify_keys": {
- key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
- },
+ "verify_keys": {key_id: {"key": encode_verify_key_base64(verify_key)}},
}
- return self.get_signed_response(res)
+ self.sign_response(res)
+ return res
- def get_signed_response(self, res):
+ def sign_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
- return res
class KeyringTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
- hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
- keys = self.mock_perspective_server.get_verify_keys()
- hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
- return hs
+
+ config = self.default_config()
+ config["trusted_key_servers"] = [
+ {
+ "server_name": self.mock_perspective_server.server_name,
+ "verify_keys": self.mock_perspective_server.get_verify_keys(),
+ }
+ ]
+
+ return self.setup_test_homeserver(
+ handlers=None, http_client=self.http_client, config=config
+ )
def check_context(self, _, expected):
self.assertEquals(
@@ -80,7 +92,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
@@ -89,7 +101,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
@@ -132,7 +144,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server(
- [("server10", json1), ("server11", {})]
+ [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
# the unsigned json should be rejected pretty quickly
@@ -169,7 +181,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
- [("server10", json1, )]
+ [("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -192,31 +204,169 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_key(
- "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
+ r = self.hs.datastore.store_server_verify_keys(
+ "server9",
+ time.time() * 1000,
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
+ )
+ self.get_success(r)
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server9", key1)
+
+ # should fail immediately on an unsigned object
+ d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+ self.failureResultOf(d, SynapseError)
+
+ # should suceed on a signed object
+ d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+ # self.assertFalse(d.called)
+ self.get_success(d)
+
+ def test_verify_json_for_server_with_null_valid_until_ms(self):
+ """Tests that we correctly handle key requests for keys we've stored
+ with a null `ts_valid_until_ms`
+ """
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+
+ kr = keyring.Keyring(
+ self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+ )
+
+ key1 = signedjson.key.generate_signing_key(1)
+ r = self.hs.datastore.store_server_verify_keys(
+ "server9",
+ time.time() * 1000,
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
)
self.get_success(r)
+
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
- d = _verify_json_for_server(kr, "server9", {})
+ d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
self.failureResultOf(d, SynapseError)
- d = _verify_json_for_server(kr, "server9", json1)
- self.assertFalse(d.called)
+ # should fail on a signed object with a non-zero minimum_valid_until_ms,
+ # as it tries to refetch the keys and fails.
+ d = _verify_json_for_server(
+ kr, "server9", json1, 500, "test signed non-zero min"
+ )
+ self.get_failure(d, SynapseError)
+
+ # We expect the keyring tried to refetch the key once.
+ mock_fetcher.get_keys.assert_called_once_with(
+ {"server9": {get_key_id(key1): 500}}
+ )
+
+ # should succeed on a signed object with a 0 minimum_valid_until_ms
+ d = _verify_json_for_server(
+ kr, "server9", json1, 0, "test signed with zero min"
+ )
self.get_success(d)
+ def test_verify_json_dedupes_key_requests(self):
+ """Two requests for the same key should be deduped."""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys(keys_to_fetch):
+ # there should only be one request object (with the max validity)
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock(side_effect=get_keys)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ # the first request should succeed; the second should fail because the key
+ # has expired
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to the fetcher
+ mock_fetcher.get_keys.assert_called_once()
+
+ def test_verify_json_falls_back_to_other_fetchers(self):
+ """If the first fetcher cannot provide a recent enough key, we fall back"""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys1(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+ }
+ }
+ )
+
+ def get_keys2(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+ mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to each fetcher
+ mock_fetcher1.get_keys.assert_called_once()
+ mock_fetcher2.get_keys.assert_called_once()
+
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ return hs
+
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
+ fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 1000
+ VALID_UNTIL_TS = 200 * 1000
# valid response
response = {
@@ -238,12 +388,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -263,18 +414,37 @@ class KeyringTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- # change the server name: it should cause a rejection
+ # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
- self.get_failure(
- kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertEqual(keys, {})
+
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.mock_perspective_server = MockPerspectiveServer()
+ self.http_client = Mock()
+
+ config = self.default_config()
+ config["trusted_key_servers"] = [
+ {
+ "server_name": self.mock_perspective_server.server_name,
+ "verify_keys": self.mock_perspective_server.get_verify_keys(),
+ }
+ ]
+
+ return self.setup_test_homeserver(
+ handlers=None, http_client=self.http_client, config=config
)
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -292,9 +462,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
},
}
- persp_resp = {
- "server_keys": [self.mock_perspective_server.get_signed_response(response)]
- }
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -303,17 +474,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
- return persp_resp
+ return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -329,25 +501,96 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
self.assertEqual(
- bytes(res["key_json"]),
- canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+ bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
+ def test_invalid_perspectives_responses(self):
+ """Check that invalid responses from the perspectives server are rejected"""
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ def build_response():
+ # valid response
+ response = {
+ "server_name": SERVER_NAME,
+ "old_verify_keys": {},
+ "valid_until_ts": VALID_UNTIL_TS,
+ "verify_keys": {
+ testverifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ }
+ },
+ }
+
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
+ return response
+
+ def get_key_from_perspectives(response):
+ fetcher = PerspectivesKeyFetcher(self.hs)
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+
+ def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+ return {"server_keys": [response]}
+
+ self.http_client.post_json.side_effect = post_json
+
+ return self.get_success(fetcher.get_keys(keys_to_fetch))
+
+ # start with a valid response so we can check we are testing the right thing
+ response = build_response()
+ keys = get_key_from_perspectives(response)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.verify_key, testverifykey)
+
+ # remove the perspectives server's signature
+ response = build_response()
+ del response["signatures"][self.mock_perspective_server.server_name]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
+
+ # remove the origin server's signature
+ response = build_response()
+ del response["signatures"][SERVER_NAME]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+
+
+def get_key_id(key):
+ """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+ return "%s:%s" % (key.alg, key.version)
+
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx"):
+ with LoggingContext("testctx") as ctx:
+ # we set the "request" prop to make it easier to follow what's going on in the
+ # logs.
+ ctx.request = "testctx"
rv = yield f(*args, **kwargs)
defer.returnValue(rv)
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(kr, *args):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
+
@defer.inlineCallbacks
def v():
- rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+ rv1 = yield kr.verify_json_for_server(*args)
defer.returnValue(rv1)
return run_in_context(v)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
new file mode 100644
index 0000000000..1e3e5aec66
--- /dev/null
+++ b/tests/federation/test_complexity.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+
+
+class RoomComplexityTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self, name='test'):
+ config = super(RoomComplexityTests, self).default_config(name=name)
+ config["limit_large_remote_room_joins"] = True
+ config["limit_large_remote_room_complexity"] = 0.05
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return defer.succeed("otherserver.nottld")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ def test_complexity_simple(self):
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Get the room complexity
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ complexity = channel.json_body["v1"]
+ self.assertTrue(complexity > 0, complexity)
+
+ # Artificially raise the complexity
+ store = self.hs.get_datastore()
+ store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+
+ # Get the room complexity again -- make sure it's our artificial value
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ complexity = channel.json_body["v1"]
+ self.assertEqual(complexity, 1.23)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 28e7e27416..7bb106b5f7 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -33,11 +33,15 @@ class FederationSenderTestCases(HomeserverTestCase):
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
- mock_send_transaction = self.hs.get_federation_transport_client().send_transaction
+ mock_send_transaction = (
+ self.hs.get_federation_transport_client().send_transaction
+ )
mock_send_transaction.return_value = defer.succeed({})
sender = self.hs.get_federation_sender()
- receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234})
+ receipt = ReadReceipt(
+ "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+ )
self.successResultOf(sender.send_read_receipt(receipt))
self.pump()
@@ -46,21 +50,24 @@ class FederationSenderTestCases(HomeserverTestCase):
mock_send_transaction.assert_called_once()
json_cb = mock_send_transaction.call_args[0][1]
data = json_cb()
- self.assertEqual(data['edus'], [
- {
- 'edu_type': 'm.receipt',
- 'content': {
- 'room_id': {
- 'm.read': {
- 'user_id': {
- 'event_ids': ['event_id'],
- 'data': {'ts': 1234},
- },
- },
+ self.assertEqual(
+ data['edus'],
+ [
+ {
+ 'edu_type': 'm.receipt',
+ 'content': {
+ 'room_id': {
+ 'm.read': {
+ 'user_id': {
+ 'event_ids': ['event_id'],
+ 'data': {'ts': 1234},
+ }
+ }
+ }
},
- },
- },
- ])
+ }
+ ],
+ )
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
@@ -68,11 +75,15 @@ class FederationSenderTestCases(HomeserverTestCase):
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
- mock_send_transaction = self.hs.get_federation_transport_client().send_transaction
+ mock_send_transaction = (
+ self.hs.get_federation_transport_client().send_transaction
+ )
mock_send_transaction.return_value = defer.succeed({})
sender = self.hs.get_federation_sender()
- receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234})
+ receipt = ReadReceipt(
+ "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+ )
self.successResultOf(sender.send_read_receipt(receipt))
self.pump()
@@ -81,25 +92,30 @@ class FederationSenderTestCases(HomeserverTestCase):
mock_send_transaction.assert_called_once()
json_cb = mock_send_transaction.call_args[0][1]
data = json_cb()
- self.assertEqual(data['edus'], [
- {
- 'edu_type': 'm.receipt',
- 'content': {
- 'room_id': {
- 'm.read': {
- 'user_id': {
- 'event_ids': ['event_id'],
- 'data': {'ts': 1234},
- },
- },
+ self.assertEqual(
+ data['edus'],
+ [
+ {
+ 'edu_type': 'm.receipt',
+ 'content': {
+ 'room_id': {
+ 'm.read': {
+ 'user_id': {
+ 'event_ids': ['event_id'],
+ 'data': {'ts': 1234},
+ }
+ }
+ }
},
- },
- },
- ])
+ }
+ ],
+ )
mock_send_transaction.reset_mock()
# send the second RR
- receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234})
+ receipt = ReadReceipt(
+ "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
+ )
self.successResultOf(sender.send_read_receipt(receipt))
self.pump()
mock_send_transaction.assert_not_called()
@@ -111,18 +127,21 @@ class FederationSenderTestCases(HomeserverTestCase):
mock_send_transaction.assert_called_once()
json_cb = mock_send_transaction.call_args[0][1]
data = json_cb()
- self.assertEqual(data['edus'], [
- {
- 'edu_type': 'm.receipt',
- 'content': {
- 'room_id': {
- 'm.read': {
- 'user_id': {
- 'event_ids': ['other_id'],
- 'data': {'ts': 1234},
- },
- },
+ self.assertEqual(
+ data['edus'],
+ [
+ {
+ 'edu_type': 'm.receipt',
+ 'content': {
+ 'room_id': {
+ 'm.read': {
+ 'user_id': {
+ 'event_ids': ['other_id'],
+ 'data': {'ts': 1234},
+ }
+ }
+ }
},
- },
- },
- ])
+ }
+ ],
+ )
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5b2105bc76..917548bb31 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -115,11 +115,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
# We cheekily override the config to add custom alias creation rules
config = {}
config["alias_creation_rules"] = [
- {
- "user_id": "*",
- "alias": "#unofficial_*",
- "action": "allow",
- }
+ {"user_id": "*", "alias": "#unofficial_*", "action": "allow"}
]
config["room_list_publication_rules"] = []
@@ -162,9 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id)
request, channel = self.make_request(
- "PUT",
- b"directory/list/room/%s" % (room_id.encode('ascii'),),
- b'{}',
+ "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -179,10 +173,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = True
# Room list is enabled so we should get some results
- request, channel = self.make_request(
- "GET",
- b"publicRooms",
- )
+ request, channel = self.make_request("GET", b"publicRooms")
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -191,10 +182,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = False
# Room list disabled so we should get no results
- request, channel = self.make_request(
- "GET",
- b"publicRooms",
- )
+ request, channel = self.make_request("GET", b"publicRooms")
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
@@ -202,9 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list disabled so we shouldn't be allowed to publish rooms
room_id = self.helper.create_room_as(self.user_id)
request, channel = self.make_request(
- "PUT",
- b"directory/list/room/%s" % (room_id.encode('ascii'),),
- b'{}',
+ "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
)
self.render(request)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 1c49bbbc3c..2e72a1dd23 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -36,7 +36,7 @@ room_keys = {
"first_message_index": 1,
"forwarded_count": 1,
"is_verified": False,
- "session_data": "SSBBTSBBIEZJU0gK"
+ "session_data": "SSBBTSBBIEZJU0gK",
}
}
}
@@ -47,15 +47,13 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
- self.hs = None # type: synapse.server.HomeServer
+ self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup,
- handlers=None,
- replication_layer=mock.Mock(),
+ self.addCleanup, handlers=None, replication_layer=mock.Mock()
)
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
self.local_user = "@boris:" + self.hs.hostname
@@ -88,67 +86,86 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self):
"""Check that we can create and then retrieve versions.
"""
- res = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ res = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(res, "1")
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
- self.assertDictEqual(res, {
- "version": "1",
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ self.assertDictEqual(
+ res,
+ {
+ "version": "1",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
# check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1")
- self.assertDictEqual(res, {
- "version": "1",
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ self.assertDictEqual(
+ res,
+ {
+ "version": "1",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ },
+ )
# upload a new one...
- res = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- })
+ res = yield self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
self.assertEqual(res, "2")
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
- self.assertDictEqual(res, {
- "version": "2",
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- })
+ self.assertDictEqual(
+ res,
+ {
+ "version": "2",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
@defer.inlineCallbacks
def test_update_version(self):
"""Check that we can update versions.
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
- res = yield self.handler.update_version(self.local_user, version, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": version
- })
+ res = yield self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
- self.assertDictEqual(res, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": version
- })
+ self.assertDictEqual(
+ res,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ },
+ )
@defer.inlineCallbacks
def test_update_missing_version(self):
@@ -156,11 +173,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.update_version(self.local_user, "1", {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "1"
- })
+ yield self.handler.update_version(
+ self.local_user,
+ "1",
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "1",
+ },
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -170,29 +191,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 400 if the version in the body is missing or
doesn't match
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
res = None
try:
- yield self.handler.update_version(self.local_user, version, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data"
- })
+ yield self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400)
res = None
try:
- yield self.handler.update_version(self.local_user, version, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- "version": "incorrect"
- })
+ yield self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": "incorrect",
+ },
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400)
@@ -223,10 +252,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self):
"""Check that we can create and then delete versions.
"""
- res = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ res = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(res, "1")
# check we can delete it
@@ -255,16 +284,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
res = yield self.handler.get_room_keys(self.local_user, version)
- self.assertDictEqual(res, {
- "rooms": {}
- })
+ self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest
@@ -275,7 +302,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
- yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
+ yield self.handler.upload_room_keys(
+ self.local_user, "no_version", room_keys
+ )
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -285,10 +314,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
res = None
@@ -304,16 +333,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "second_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ },
+ )
self.assertEqual(version, "2")
res = None
@@ -327,10 +359,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
@@ -340,18 +372,13 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check getting room_keys for a given room
res = yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org"
+ self.local_user, version, room_id="!abc:matrix.org"
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
res = yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
self.assertDictEqual(res, room_keys)
@@ -359,10 +386,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
@@ -378,7 +405,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res = yield self.handler.get_room_keys(self.local_user, version)
self.assertEqual(
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
- "SSBBTSBBIEZJU0gK"
+ "SSBBTSBBIEZJU0gK",
)
# test that marking the session as verified however /does/ replace it
@@ -387,8 +414,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res = yield self.handler.get_room_keys(self.local_user, version)
self.assertEqual(
- res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
- "new"
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
)
# test that a session with a higher forwarded_count doesn't replace one
@@ -399,8 +425,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res = yield self.handler.get_room_keys(self.local_user, version)
self.assertEqual(
- res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
- "new"
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
)
# TODO: check edge cases as well as the common variations here
@@ -409,56 +434,36 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session
"""
- version = yield self.handler.create_version(self.local_user, {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "first_version_auth_data",
- })
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
self.assertEqual(version, "1")
# check for bulk-delete
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(self.local_user, version)
res = yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
- self.assertDictEqual(res, {
- "rooms": {}
- })
+ self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
+ self.local_user, version, room_id="!abc:matrix.org"
)
res = yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
- self.assertDictEqual(res, {
- "rooms": {}
- })
+ self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
res = yield self.handler.get_room_keys(
- self.local_user,
- version,
- room_id="!abc:matrix.org",
- session_id="c0ff33",
+ self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
- self.assertDictEqual(res, {
- "rooms": {}
- })
+ self.assertDictEqual(res, {"rooms": {}})
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 94c6080e34..f70c6e7d65 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -424,8 +424,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "server", http_client=None,
- federation_sender=Mock(),
+ "server", http_client=None, federation_sender=Mock()
)
return hs
@@ -457,7 +456,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Mark test2 as online, test will be offline with a last_active of 0
self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
)
self.reactor.pump([0]) # Wait for presence updates to be handled
@@ -506,13 +505,13 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Mark test as online
self.presence_handler.set_state(
- UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE},
+ UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
)
# Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet
self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
)
# Add servers to the room
@@ -541,8 +540,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations=set(("server2", "server3")),
- states=[expected_state]
+ destinations=set(("server2", "server3")), states=[expected_state]
)
def _add_new_user(self, room_id, user_id):
@@ -565,7 +563,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
type=EventTypes.Member,
sender=user_id,
state_key=user_id,
- content={"membership": Membership.JOIN}
+ content={"membership": Membership.JOIN},
)
prev_event_ids = self.get_success(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 017ea0385e..5ffba2ca7a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -37,8 +37,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config = self.default_config("test")
# some of the tests rely on us having a user consent version
- hs_config.user_consent_version = "test_consent_version"
- hs_config.max_mau_value = 50
+ hs_config["user_consent"] = {
+ "version": "test_consent_version",
+ "template_dir": ".",
+ }
+ hs_config["max_mau_value"] = 50
+ hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
return hs
@@ -224,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_not_support_user(self):
res = self.get_success(self.handler.register(localpart='user'))
self.assertFalse(self.store.is_support_user(res[0]))
+
+ def test_invalid_user_id_length(self):
+ invalid_user_id = "x" * 256
+ self.get_failure(
+ self.handler.register(localpart=invalid_user_id),
+ SynapseError
+ )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
new file mode 100644
index 0000000000..2710c991cf
--- /dev/null
+++ b/tests/handlers/test_stats.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class StatsRoomTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+
+ self.store = hs.get_datastore()
+ self.handler = self.hs.get_stats_handler()
+
+ def _add_background_updates(self):
+ """
+ Add the background updates we need to run.
+ """
+ # Ugh, have to reset this flag
+ self.store._all_done = False
+
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_createtables",
+ },
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+
+ def test_initial_room(self):
+ """
+ The background updates will build the table from scratch.
+ """
+ r = self.get_success(self.store.get_all_room_state())
+ self.assertEqual(len(r), 0)
+
+ # Disable stats
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Stats disabled, shouldn't have done anything
+ r = self.get_success(self.store.get_all_room_state())
+ self.assertEqual(len(r), 0)
+
+ # Enable stats
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+
+ # Do the initial population of the user directory via the background update
+ self._add_background_updates()
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ r = self.get_success(self.store.get_all_room_state())
+
+ self.assertEqual(len(r), 1)
+ self.assertEqual(r[0]["topic"], "foo")
+
+ def test_initial_earliest_token(self):
+ """
+ Ingestion via notify_new_event will ignore tokens that the background
+ update have already processed.
+ """
+ self.reactor.advance(86401)
+
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ u2 = self.register_user("u2", "pass")
+ u2_token = self.login("u2", "pass")
+
+ u3 = self.register_user("u3", "pass")
+ u3_token = self.login("u3", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Begin the ingestion by creating the temp tables. This will also store
+ # the position that the deltas should begin at, once they take over.
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+ self.store._all_done = False
+ self.get_success(self.store.update_stats_stream_pos(None))
+
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ )
+ )
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # Now, before the table is actually ingested, add some more events.
+ self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room=room_1, user=u2, tok=u2_token)
+
+ # Now do the initial ingestion.
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+
+ self.store._all_done = False
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ self.reactor.advance(86401)
+
+ # Now add some more events, triggering ingestion. Because of the stream
+ # position being set to before the events sent in the middle, a simpler
+ # implementation would reprocess those events, and say there were four
+ # users, not three.
+ self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
+ self.helper.join(room=room_1, user=u3, tok=u3_token)
+
+ # Get the deltas! There should be two -- day 1, and day 2.
+ r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+
+ # The oldest has 2 joined members
+ self.assertEqual(r[-1]["joined_members"], 2)
+
+ # The newest has 3
+ self.assertEqual(r[0]["joined_members"], 3)
+
+ def test_incorrect_state_transition(self):
+ """
+ If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
+ (JOIN, INVITE, LEAVE, BAN), an error is raised.
+ """
+ events = {
+ "a1": {"membership": Membership.LEAVE},
+ "a2": {"membership": "not a real thing"},
+ }
+
+ def get_event(event_id, allow_none=True):
+ m = Mock()
+ m.content = events[event_id]
+ d = defer.Deferred()
+ self.reactor.callLater(0.0, d.callback, m)
+ return d
+
+ def get_received_ts(event_id):
+ return defer.succeed(1)
+
+ self.store.get_received_ts = get_received_ts
+ self.store.get_event = get_event
+
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user",
+ "room_id": "room",
+ "event_id": "a1",
+ "prev_event_id": "a2",
+ "stream_id": 60,
+ }
+ ]
+
+ f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ self.assertEqual(
+ f.value.args[0], "'not a real thing' is not a valid prev_membership"
+ )
+
+ # And the other way...
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user",
+ "room_id": "room",
+ "event_id": "a2",
+ "prev_event_id": "a1",
+ "stream_id": 100,
+ }
+ ]
+
+ f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ self.assertEqual(
+ f.value.args[0], "'not a real thing' is not a valid membership"
+ )
+
+ def test_redacted_prev_event(self):
+ """
+ If the prev_event does not exist, then it is assumed to be a LEAVE.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+
+ # Do the initial population of the user directory via the background update
+ self._add_background_updates()
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ events = {
+ "a1": None,
+ "a2": {"membership": Membership.JOIN},
+ }
+
+ def get_event(event_id, allow_none=True):
+ if events.get(event_id):
+ m = Mock()
+ m.content = events[event_id]
+ else:
+ m = None
+ d = defer.Deferred()
+ self.reactor.callLater(0.0, d.callback, m)
+ return d
+
+ def get_received_ts(event_id):
+ return defer.succeed(1)
+
+ self.store.get_received_ts = get_received_ts
+ self.store.get_event = get_event
+
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user:test",
+ "room_id": room_1,
+ "event_id": "a2",
+ "prev_event_id": "a1",
+ "stream_id": 100,
+ }
+ ]
+
+ # Handle our fake deltas, which has a user going from LEAVE -> JOIN.
+ self.get_success(self.handler._handle_deltas(deltas))
+
+ # One delta, with two joined members -- the room creator, and our fake
+ # user.
+ r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+ self.assertEqual(len(r), 1)
+ self.assertEqual(r[0]["joined_members"], 2)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5a0b6c201c..cb8b4d2913 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,20 +64,22 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
hs = self.setup_test_homeserver(
- datastore=(Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_devices_by_remote",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- ]
- )),
+ datastore=(
+ Mock(
+ spec=[
+ # Bits that Federation needs
+ "prep_send_transaction",
+ "delivered_txn",
+ "get_received_txn_response",
+ "set_received_txn_response",
+ "get_destination_retry_timings",
+ "get_devices_by_remote",
+ # Bits that user_directory needs
+ "get_user_directory_stream_pos",
+ "get_current_state_deltas",
+ ]
+ )
+ ),
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
@@ -87,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
# the tests assume that we are starting at unix time 1000
- reactor.pump((1000, ))
+ reactor.pump((1000,))
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -114,6 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
+
hs.get_auth().check_joined_room = check_joined_room
def get_joined_hosts_for_room(room_id):
@@ -123,6 +126,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def get_current_users_in_room(room_id):
return set(str(u) for u in self.room_members)
+
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
@@ -141,21 +145,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(self.handler.started_typing(
- target_user=U_APPLE,
- auth_user=U_APPLE,
- room_id=ROOM_ID,
- timeout=20000,
- ))
-
- self.on_new_event.assert_has_calls(
- [call('typing_key', 1, rooms=[ROOM_ID])]
+ self.successResultOf(
+ self.handler.started_typing(
+ target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ )
)
+ self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(
- room_ids=[ROOM_ID], from_key=0
- )
+ events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
self.assertEquals(
events[0],
[
@@ -170,12 +169,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
- self.successResultOf(self.handler.started_typing(
- target_user=U_APPLE,
- auth_user=U_APPLE,
- room_id=ROOM_ID,
- timeout=20000,
- ))
+ self.successResultOf(
+ self.handler.started_typing(
+ target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ )
+ )
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
@@ -216,14 +214,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(channel.code, 200)
- self.on_new_event.assert_has_calls(
- [call('typing_key', 1, rooms=[ROOM_ID])]
- )
+ self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(
- room_ids=[ROOM_ID], from_key=0
- )
+ events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
self.assertEquals(
events[0],
[
@@ -247,14 +241,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(self.handler.stopped_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
- ))
-
- self.on_new_event.assert_has_calls(
- [call('typing_key', 1, rooms=[ROOM_ID])]
+ self.successResultOf(
+ self.handler.stopped_typing(
+ target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+ )
)
+ self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
"farm",
@@ -274,18 +268,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(
- room_ids=[ROOM_ID], from_key=0
- )
+ events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
self.assertEquals(
events[0],
- [
- {
- "type": "m.typing",
- "room_id": ROOM_ID,
- "content": {"user_ids": []},
- }
- ],
+ [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)
def test_typing_timeout(self):
@@ -293,22 +279,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- self.successResultOf(self.handler.started_typing(
- target_user=U_APPLE,
- auth_user=U_APPLE,
- room_id=ROOM_ID,
- timeout=10000,
- ))
-
- self.on_new_event.assert_has_calls(
- [call('typing_key', 1, rooms=[ROOM_ID])]
+ self.successResultOf(
+ self.handler.started_typing(
+ target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ )
)
+
+ self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(
- room_ids=[ROOM_ID], from_key=0
- )
+ events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
self.assertEquals(
events[0],
[
@@ -320,45 +301,30 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- self.reactor.pump([16, ])
+ self.reactor.pump([16])
- self.on_new_event.assert_has_calls(
- [call('typing_key', 2, rooms=[ROOM_ID])]
- )
+ self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 2)
- events = self.event_source.get_new_events(
- room_ids=[ROOM_ID], from_key=1
- )
+ events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
self.assertEquals(
events[0],
- [
- {
- "type": "m.typing",
- "room_id": ROOM_ID,
- "content": {"user_ids": []},
- }
- ],
+ [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)
# SYN-230 - see if we can still set after timeout
- self.successResultOf(self.handler.started_typing(
- target_user=U_APPLE,
- auth_user=U_APPLE,
- room_id=ROOM_ID,
- timeout=10000,
- ))
-
- self.on_new_event.assert_has_calls(
- [call('typing_key', 3, rooms=[ROOM_ID])]
+ self.successResultOf(
+ self.handler.started_typing(
+ target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ )
)
+
+ self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])])
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
- events = self.event_source.get_new_events(
- room_ids=[ROOM_ID], from_key=0
- )
+ events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
self.assertEquals(
events[0],
[
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index f1d0aa42b6..9021e647fe 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -14,8 +14,9 @@
# limitations under the License.
from mock import Mock
+import synapse.rest.admin
from synapse.api.constants import UserTypes
-from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import user_directory
from synapse.storage.roommember import ProfileInfo
@@ -29,14 +30,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config.update_user_directory = True
+ config["update_user_directory"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
@@ -327,12 +328,12 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
user_directory.register_servlets,
room.register_servlets,
login.register_servlets,
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config.update_user_directory = True
+ config["update_user_directory"] = True
hs = self.setup_test_homeserver(config=config)
self.config = hs.config
@@ -351,9 +352,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Assert user directory is not empty
request, channel = self.make_request(
- "POST",
- b"user_directory/search",
- b'{"search_term":"user2"}',
+ "POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -362,9 +361,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Disable user directory and check search returns nothing
self.config.user_directory_search_enabled = False
request, channel = self.make_request(
- "POST",
- b"user_directory/search",
- b'{"search_term":"user2"}',
+ "POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index ee8010f598..2d5dba6464 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -13,30 +13,122 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
+import subprocess
+
+from zope.interface import implementer
from OpenSSL import SSL
+from OpenSSL.SSL import Connection
+from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+
+
+def get_test_ca_cert_file():
+ """Get the path to the test CA cert
+
+ The keypair is generated with:
+
+ openssl genrsa -out ca.key 2048
+ openssl req -new -x509 -key ca.key -days 3650 -out ca.crt \
+ -subj '/CN=synapse test CA'
+ """
+ return os.path.join(os.path.dirname(__file__), "ca.crt")
+
+
+def get_test_key_file():
+ """get the path to the test key
+
+ The key file is made with:
+
+ openssl genrsa -out server.key 2048
+ """
+ return os.path.join(os.path.dirname(__file__), "server.key")
+
+cert_file_count = 0
-def get_test_cert_file():
- """get the path to the test cert"""
+CONFIG_TEMPLATE = b"""\
+[default]
+basicConstraints = CA:FALSE
+keyUsage=nonRepudiation, digitalSignature, keyEncipherment
+subjectAltName = %(sanentries)s
+"""
- # the cert file itself is made with:
- #
- # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
- # -nodes -subj '/CN=testserv'
- return os.path.join(
- os.path.dirname(__file__),
- 'server.pem',
+
+def create_test_cert_file(sanlist):
+ """build an x509 certificate file
+
+ Args:
+ sanlist: list[bytes]: a list of subjectAltName values for the cert
+
+ Returns:
+ str: the path to the file
+ """
+ global cert_file_count
+ csr_filename = "server.csr"
+ cnf_filename = "server.%i.cnf" % (cert_file_count,)
+ cert_filename = "server.%i.crt" % (cert_file_count,)
+ cert_file_count += 1
+
+ # first build a CSR
+ subprocess.check_call(
+ [
+ "openssl",
+ "req",
+ "-new",
+ "-key",
+ get_test_key_file(),
+ "-subj",
+ "/",
+ "-out",
+ csr_filename,
+ ]
)
+ # now a config file describing the right SAN entries
+ sanentries = b",".join(sanlist)
+ with open(cnf_filename, "wb") as f:
+ f.write(CONFIG_TEMPLATE % {b"sanentries": sanentries})
+
+ # finally the cert
+ ca_key_filename = os.path.join(os.path.dirname(__file__), "ca.key")
+ ca_cert_filename = get_test_ca_cert_file()
+ subprocess.check_call(
+ [
+ "openssl",
+ "x509",
+ "-req",
+ "-in",
+ csr_filename,
+ "-CA",
+ ca_cert_filename,
+ "-CAkey",
+ ca_key_filename,
+ "-set_serial",
+ "1",
+ "-extfile",
+ cnf_filename,
+ "-out",
+ cert_filename,
+ ]
+ )
+
+ return cert_filename
+
+
+@implementer(IOpenSSLServerConnectionCreator)
+class TestServerTLSConnectionFactory(object):
+ """An SSL connection creator which returns connections which present a certificate
+ signed by our test CA."""
-class ServerTLSContext(object):
- """A TLS Context which presents our test cert."""
- def __init__(self):
- self.filename = get_test_cert_file()
+ def __init__(self, sanlist):
+ """
+ Args:
+ sanlist: list[bytes]: a list of subjectAltName values for the cert
+ """
+ self._cert_file = create_test_cert_file(sanlist)
- def getContext(self):
+ def serverConnectionForTLS(self, tlsProtocol):
ctx = SSL.Context(SSL.TLSv1_METHOD)
- ctx.use_certificate_file(self.filename)
- ctx.use_privatekey_file(self.filename)
- return ctx
+ ctx.use_certificate_file(self._cert_file)
+ ctx.use_privatekey_file(get_test_key_file())
+ return Connection(ctx, None)
diff --git a/tests/http/ca.crt b/tests/http/ca.crt
new file mode 100644
index 0000000000..730f81e99c
--- /dev/null
+++ b/tests/http/ca.crt
@@ -0,0 +1,19 @@
+-----BEGIN CERTIFICATE-----
+MIIDCjCCAfKgAwIBAgIJAPwHIHgH/jtjMA0GCSqGSIb3DQEBCwUAMBoxGDAWBgNV
+BAMMD3N5bmFwc2UgdGVzdCBDQTAeFw0xOTA2MTAxMTI2NDdaFw0yOTA2MDcxMTI2
+NDdaMBoxGDAWBgNVBAMMD3N5bmFwc2UgdGVzdCBDQTCCASIwDQYJKoZIhvcNAQEB
+BQADggEPADCCAQoCggEBAOZOXCKuylf9jHzJXpU2nS+XEKrnGPgs2SAhQKrzBxg3
+/d8KT2Zsfsj1i3G7oGu7B0ZKO6qG5AxOPCmSMf9/aiSHFilfSh+r8rCpJyWMev2c
+/w/xmhoFHgn+H90NnqlXvWb5y1YZCE3gWaituQSaa93GPKacRqXCgIrzjPUuhfeT
+uwFQt4iyUhMNBYEy3aw4IuIHdyBqi4noUhR2ZeuflLJ6PswdJ8mEiAvxCbBGPerq
+idhWcZwlo0fKu4u1uu5B8TnTsMg2fJgL6c5olBG90Urt22gA6anfP5W/U1ZdVhmB
+T3Rv5SJMkGyMGE6sEUetLFyb2GJpgGD7ePkUCZr+IMMCAwEAAaNTMFEwHQYDVR0O
+BBYEFLg7nTCYsvQXWTyS6upLc0YTlIwRMB8GA1UdIwQYMBaAFLg7nTCYsvQXWTyS
+6upLc0YTlIwRMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADqx
+GX4Ul5OGQlcG+xTt4u3vMCeqGo8mh1AnJ7zQbyRmwjJiNxJVX+/EcqFSTsmkBNoe
+xdYITI7Z6dyoiKw99yCZDE7gALcyACEU7r0XY7VY/hebAaX6uLaw1sZKKAIC04lD
+KgCu82tG85n60Qyud5SiZZF0q1XVq7lbvOYVdzVZ7k8Vssy5p9XnaLJLMggYeOiX
+psHIQjvYGnTTEBZZHzWOrc0WGThd69wxTOOkAbCsoTPEwZL8BGUsdtLWtvhp452O
+npvaUBzKg39R5X3KTdhB68XptiQfzbQkd3FtrwNuYPUywlsg55Bxkv85n57+xDO3
+D9YkgUqEp0RGUXQgCsQ=
+-----END CERTIFICATE-----
diff --git a/tests/http/ca.key b/tests/http/ca.key
new file mode 100644
index 0000000000..5c99cae186
--- /dev/null
+++ b/tests/http/ca.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpgIBAAKCAQEA5k5cIq7KV/2MfMlelTadL5cQqucY+CzZICFAqvMHGDf93wpP
+Zmx+yPWLcbuga7sHRko7qobkDE48KZIx/39qJIcWKV9KH6vysKknJYx6/Zz/D/Ga
+GgUeCf4f3Q2eqVe9ZvnLVhkITeBZqK25BJpr3cY8ppxGpcKAivOM9S6F95O7AVC3
+iLJSEw0FgTLdrDgi4gd3IGqLiehSFHZl65+Usno+zB0nyYSIC/EJsEY96uqJ2FZx
+nCWjR8q7i7W67kHxOdOwyDZ8mAvpzmiUEb3RSu3baADpqd8/lb9TVl1WGYFPdG/l
+IkyQbIwYTqwRR60sXJvYYmmAYPt4+RQJmv4gwwIDAQABAoIBAQCFuFG+wYYy+MCt
+Y65LLN6vVyMSWAQjdMbM5QHLQDiKU1hQPIhFjBFBVXCVpL9MTde3dDqYlKGsk3BT
+ItNs6eoTM2wmsXE0Wn4bHNvh7WMsBhACjeFP4lDCtI6DpvjMkmkidT8eyoIL1Yu5
+aMTYa2Dd79AfXPWYIQrJowfhBBY83KuW5fmYnKKDVLqkT9nf2dgmmQz85RgtNiZC
+zFkIsNmPqH1zRbcw0wORfOBrLFvsMc4Tt8EY5Wz3NnH8Zfgf8Q3MgARH1yspz3Vp
+B+EYHbsK17xZ+P59KPiX3yefvyYWEUjFF7ymVsVnDxLugYl4pXwWUpm19GxeDvFk
+cgBUD5OBAoGBAP7lBdCp6lx6fYtxdxUm3n4MMQmYcac4qZdeBIrvpFMnvOBBuixl
+eavcfFmFdwgAr8HyVYiu9ynac504IYvmtYlcpUmiRBbmMHbvLQEYHl7FYFKNz9ej
+2ue4oJE3RsPdLsD3xIlc+xN8oT1j0knyorwsHdj0Sv77eZzZS9XZZfJzAoGBAOdO
+CibYmoNqK/mqDHkp6PgsnbQGD5/CvPF/BLUWV1QpHxLzUQQeoBOQW5FatHe1H5zi
+mbq3emBefVmsCLrRIJ4GQu4vsTMfjcpGLwviWmaK6pHbGPt8IYeEQ2MNyv59EtA2
+pQy4dX7/Oe6NLAR1UEQjXmCuXf+rxnxF3VJd1nRxAoGBANb9eusl9fusgSnVOTjJ
+AQ7V36KVRv9hZoG6liBNwo80zDVmms4JhRd1MBkd3mkMkzIF4SkZUnWlwLBSANGM
+dX/3eZ5i1AVwgF5Am/f5TNxopDbdT/o1RVT/P8dcFT7s1xuBn+6wU0F7dFBgWqVu
+lt4aY85zNrJcj5XBHhqwdDGLAoGBAIksPNUAy9F3m5C6ih8o/aKAQx5KIeXrBUZq
+v43tK+kbYfRJHBjHWMOBbuxq0G/VmGPf9q9GtGqGXuxZG+w+rYtJx1OeMQZShjIZ
+ITl5CYeahrXtK4mo+fF2PMh3m5UE861LWuKKWhPwpJiWXC5grDNcjlHj1pcTdeip
+PjHkuJPhAoGBAIh35DptqqdicOd3dr/+/m2YQywY8aSpMrR0bC06aAkscD7oq4tt
+s/jwl0UlHIrEm/aMN7OnGIbpfkVdExfGKYaa5NRlgOwQpShwLufIo/c8fErd2zb8
+K3ptlwBxMrayMXpS3DP78r83Z0B8/FSK2guelzdRJ3ftipZ9io1Gss1C
+-----END RSA PRIVATE KEY-----
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index dcf184d3cf..ecce473b01 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,16 +17,19 @@ import logging
from mock import Mock
import treq
+from service_identity import VerificationError
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web._newclient import ResponseNeverReceived
from twisted.web.http import HTTPChannel
from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
+from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import (
MatrixFederationAgent,
@@ -36,12 +39,29 @@ from synapse.http.federation.srv_resolver import Server
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import LoggingContext
-from tests.http import ServerTLSContext
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
+from tests.utils import default_config
logger = logging.getLogger(__name__)
+test_server_connection_factory = None
+
+
+def get_connection_factory():
+ # this needs to happen once, but not until we are ready to run the first test
+ global test_server_connection_factory
+ if test_server_connection_factory is None:
+ test_server_connection_factory = TestServerTLSConnectionFactory(sanlist=[
+ b'DNS:testserv',
+ b'DNS:target-server',
+ b'DNS:xn--bcher-kva.com',
+ b'IP:1.2.3.4',
+ b'IP:::1',
+ ])
+ return test_server_connection_factory
+
class MatrixFederationAgentTests(TestCase):
def setUp(self):
@@ -51,9 +71,16 @@ class MatrixFederationAgentTests(TestCase):
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ config_dict = default_config("test", parse=False)
+ config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ # config_dict["trusted_key_servers"] = []
+
+ self._config = config = HomeServerConfig()
+ config.parse_config_dict(config_dict)
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=ClientTLSOptionsFactory(None),
+ tls_client_options_factory=ClientTLSOptionsFactory(config),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache,
@@ -67,7 +94,7 @@ class MatrixFederationAgentTests(TestCase):
"""
# build the test server
- server_tls_protocol = _build_test_server()
+ server_tls_protocol = _build_test_server(get_connection_factory())
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -78,12 +105,12 @@ class MatrixFederationAgentTests(TestCase):
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(
- FakeTransport(server_tls_protocol, self.reactor, client_protocol),
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
)
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(
- FakeTransport(client_protocol, self.reactor, server_tls_protocol),
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
# give the reactor a pump to get the TLS juices flowing.
@@ -124,7 +151,7 @@ class MatrixFederationAgentTests(TestCase):
_check_logcontext(context)
def _handle_well_known_connection(
- self, client_factory, expected_sni, content, response_headers={},
+ self, client_factory, expected_sni, content, response_headers={}
):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
@@ -138,8 +165,7 @@ class MatrixFederationAgentTests(TestCase):
"""
# make the connection for .well-known
well_known_server = self._make_connection(
- client_factory,
- expected_sni=expected_sni,
+ client_factory, expected_sni=expected_sni
)
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
@@ -153,17 +179,14 @@ class MatrixFederationAgentTests(TestCase):
"""
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/.well-known/matrix/server')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'testserv'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
# send back a response
for k, v in headers.items():
request.setHeader(k, v)
request.write(content)
request.finish()
- self.reactor.pump((0.1, ))
+ self.reactor.pump((0.1,))
def test_get(self):
"""
@@ -183,18 +206,14 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=b"testserv",
- )
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'testserv:8448']
+ request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448']
)
content = request.content.read()
self.assertEqual(content, b'')
@@ -243,19 +262,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=None,
- )
+ http_server = self._make_connection(client_factory, expected_sni=None)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'1.2.3.4'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4'])
# finish the request
request.finish()
@@ -284,19 +297,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=None,
- )
+ http_server = self._make_connection(client_factory, expected_sni=None)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'[::1]'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]'])
# finish the request
request.finish()
@@ -325,25 +332,101 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 80)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=None,
- )
+ http_server = self._make_connection(client_factory, expected_sni=None)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'[::1]:80'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80'])
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ def test_get_hostname_bad_cert(self):
+ """
+ Test the behaviour when the certificate on the server doesn't match the hostname
+ """
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv1"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # No SRV record lookup yet
+ self.mock_resolver.resolve_service.assert_not_called()
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ # fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
+ # .well-known request fails.
+ self.reactor.pump((0.4,))
+
+ # now there should be a SRV lookup
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv1"
+ )
+
+ # we should fall back to a direct connection
+ self.assertEqual(len(clients), 2)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b'testserv1')
+
+ # there should be no requests
+ self.assertEqual(len(http_server.requests), 0)
+
+ # ... and the request should have failed
+ e = self.failureResultOf(test_d, ResponseNeverReceived)
+ failure_reason = e.value.reasons[0]
+ self.assertIsInstance(failure_reason.value, VerificationError)
+
+ def test_get_ip_address_bad_cert(self):
+ """
+ Test the behaviour when the server name contains an explicit IP, but
+ the server cert doesn't cover it
+ """
+ # there will be a getaddrinfo on the IP
+ self.reactor.lookups["1.2.3.5"] = "1.2.3.5"
+
+ test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.5')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=None)
+
+ # there should be no requests
+ self.assertEqual(len(http_server.requests), 0)
+
+ # ... and the request should have failed
+ e = self.failureResultOf(test_d, ResponseNeverReceived)
+ failure_reason = e.value.reasons[0]
+ self.assertIsInstance(failure_reason.value, VerificationError)
+
def test_get_no_srv_no_well_known(self):
"""
Test the behaviour when the server name has no port, no SRV, and no well-known
@@ -376,7 +459,7 @@ class MatrixFederationAgentTests(TestCase):
# now there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv",
+ b"_matrix._tcp.testserv"
)
# we should fall back to a direct connection
@@ -386,19 +469,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=b'testserv',
- )
+ http_server = self._make_connection(client_factory, expected_sni=b'testserv')
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'testserv'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
# finish the request
request.finish()
@@ -426,13 +503,14 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 443)
self._handle_well_known_connection(
- client_factory, expected_sni=b"testserv",
+ client_factory,
+ expected_sni=b"testserv",
content=b'{ "m.server": "target-server" }',
)
# there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server",
+ b"_matrix._tcp.target-server"
)
# now we should get a connection to the target server
@@ -443,8 +521,7 @@ class MatrixFederationAgentTests(TestCase):
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory,
- expected_sni=b'target-server',
+ client_factory, expected_sni=b'target-server'
)
self.assertEqual(len(http_server.requests), 1)
@@ -452,8 +529,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'target-server'],
+ request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
)
# finish the request
@@ -489,8 +565,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 443)
redirect_server = self._make_connection(
- client_factory,
- expected_sni=b"testserv",
+ client_factory, expected_sni=b"testserv"
)
# send a 302 redirect
@@ -499,7 +574,7 @@ class MatrixFederationAgentTests(TestCase):
request.redirect(b'https://testserv/even_better_known')
request.finish()
- self.reactor.pump((0.1, ))
+ self.reactor.pump((0.1,))
# now there should be another connection
clients = self.reactor.tcpClients
@@ -509,8 +584,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 443)
well_known_server = self._make_connection(
- client_factory,
- expected_sni=b"testserv",
+ client_factory, expected_sni=b"testserv"
)
self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
@@ -520,11 +594,11 @@ class MatrixFederationAgentTests(TestCase):
request.write(b'{ "m.server": "target-server" }')
request.finish()
- self.reactor.pump((0.1, ))
+ self.reactor.pump((0.1,))
# there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server",
+ b"_matrix._tcp.target-server"
)
# now we should get a connection to the target server
@@ -535,8 +609,7 @@ class MatrixFederationAgentTests(TestCase):
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory,
- expected_sni=b'target-server',
+ client_factory, expected_sni=b'target-server'
)
self.assertEqual(len(http_server.requests), 1)
@@ -544,8 +617,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'target-server'],
+ request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
)
# finish the request
@@ -584,12 +656,12 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 443)
self._handle_well_known_connection(
- client_factory, expected_sni=b"testserv", content=b'NOT JSON',
+ client_factory, expected_sni=b"testserv", content=b'NOT JSON'
)
# now there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv",
+ b"_matrix._tcp.testserv"
)
# we should fall back to a direct connection
@@ -599,25 +671,62 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=b'testserv',
- )
+ http_server = self._make_connection(client_factory, expected_sni=b'testserv')
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'testserv'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ def test_get_well_known_unsigned_cert(self):
+ """Test the behaviour when the .well-known server presents a cert
+ not signed by a CA
+ """
+
+ # we use the same test server as the other tests, but use an agent
+ # with _well_known_tls_policy left to the default, which will not
+ # trust it (since the presented cert is signed by a test CA)
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ agent = MatrixFederationAgent(
+ reactor=self.reactor,
+ tls_client_options_factory=ClientTLSOptionsFactory(self._config),
+ _srv_resolver=self.mock_resolver,
+ _well_known_cache=self.well_known_cache,
+ )
+
+ test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ http_proto = self._make_connection(
+ client_factory, expected_sni=b"testserv",
+ )
+
+ # there should be no requests
+ self.assertEqual(len(http_proto.requests), 0)
+
+ # and there should be a SRV lookup instead
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv"
+ )
+
def test_get_hostname_srv(self):
"""
Test the behaviour when there is a single SRV record
@@ -634,7 +743,7 @@ class MatrixFederationAgentTests(TestCase):
# the request for a .well-known will have failed with a DNS lookup error.
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.testserv",
+ b"_matrix._tcp.testserv"
)
# Make sure treq is trying to connect
@@ -645,19 +754,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8443)
# make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- expected_sni=b'testserv',
- )
+ http_server = self._make_connection(client_factory, expected_sni=b'testserv')
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'testserv'],
- )
+ self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
# finish the request
request.finish()
@@ -684,17 +787,18 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 443)
self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443),
+ Server(host=b"srvtarget", port=8443)
]
self._handle_well_known_connection(
- client_factory, expected_sni=b"testserv",
+ client_factory,
+ expected_sni=b"testserv",
content=b'{ "m.server": "target-server" }',
)
# there should be a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.target-server",
+ b"_matrix._tcp.target-server"
)
# now we should get a connection to the target of the SRV record
@@ -705,8 +809,7 @@ class MatrixFederationAgentTests(TestCase):
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory,
- expected_sni=b'target-server',
+ client_factory, expected_sni=b'target-server'
)
self.assertEqual(len(http_server.requests), 1)
@@ -714,8 +817,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'target-server'],
+ request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
)
# finish the request
@@ -756,7 +858,7 @@ class MatrixFederationAgentTests(TestCase):
# now there should have been a SRV lookup
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.xn--bcher-kva.com",
+ b"_matrix._tcp.xn--bcher-kva.com"
)
# We should fall back to port 8448
@@ -768,8 +870,7 @@ class MatrixFederationAgentTests(TestCase):
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory,
- expected_sni=b'xn--bcher-kva.com',
+ client_factory, expected_sni=b'xn--bcher-kva.com'
)
self.assertEqual(len(http_server.requests), 1)
@@ -777,8 +878,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'xn--bcher-kva.com'],
+ request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
)
# finish the request
@@ -800,7 +900,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.xn--bcher-kva.com",
+ b"_matrix._tcp.xn--bcher-kva.com"
)
# Make sure treq is trying to connect
@@ -812,8 +912,7 @@ class MatrixFederationAgentTests(TestCase):
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory,
- expected_sni=b'xn--bcher-kva.com',
+ client_factory, expected_sni=b'xn--bcher-kva.com'
)
self.assertEqual(len(http_server.requests), 1)
@@ -821,8 +920,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'xn--bcher-kva.com'],
+ request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
)
# finish the request
@@ -896,74 +994,83 @@ class TestCachePeriodFromHeaders(TestCase):
# uppercase
self.assertEqual(
_cache_period_from_headers(
- Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
- ), 100,
+ Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']})
+ ),
+ 100,
)
# missing value
- self.assertIsNone(_cache_period_from_headers(
- Headers({b'Cache-Control': [b'max-age=, bar']}),
- ))
+ self.assertIsNone(
+ _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']}))
+ )
# hackernews: bogus due to semicolon
- self.assertIsNone(_cache_period_from_headers(
- Headers({b'Cache-Control': [b'private; max-age=0']}),
- ))
+ self.assertIsNone(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'private; max-age=0']})
+ )
+ )
# github
self.assertEqual(
_cache_period_from_headers(
- Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
- ), 0,
+ Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']})
+ ),
+ 0,
)
# google
self.assertEqual(
_cache_period_from_headers(
- Headers({b'cache-control': [b'private, max-age=0']}),
- ), 0,
+ Headers({b'cache-control': [b'private, max-age=0']})
+ ),
+ 0,
)
def test_expires(self):
self.assertEqual(
_cache_period_from_headers(
Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
- time_now=lambda: 1548833700
- ), 33,
+ time_now=lambda: 1548833700,
+ ),
+ 33,
)
# cache-control overrides expires
self.assertEqual(
_cache_period_from_headers(
- Headers({
- b'cache-control': [b'max-age=10'],
- b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
- }),
- time_now=lambda: 1548833700
- ), 10,
+ Headers(
+ {
+ b'cache-control': [b'max-age=10'],
+ b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'],
+ }
+ ),
+ time_now=lambda: 1548833700,
+ ),
+ 10,
)
# invalid expires means immediate expiry
- self.assertEqual(
- _cache_period_from_headers(
- Headers({b'Expires': [b'0']}),
- ), 0,
- )
+ self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0)
def _check_logcontext(context):
current = LoggingContext.current_context()
if current is not context:
- raise AssertionError(
- "Expected logcontext %s but was %s" % (context, current),
- )
+ raise AssertionError("Expected logcontext %s but was %s" % (context, current))
-def _build_test_server():
+def _build_test_server(connection_creator):
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+ Args:
+ connection_creator (IOpenSSLServerConnectionCreator): thing to build
+ SSL connections
+ sanlist (list[bytes]): list of the SAN entries for the cert returned
+ by the server
+
Returns:
TLSMemoryBIOProtocol
"""
@@ -972,7 +1079,7 @@ def _build_test_server():
server_factory.log = _log_request
server_tls_factory = TLSMemoryBIOFactory(
- ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
+ connection_creator, isClient=False, wrappedFactory=server_factory
)
return server_tls_factory.buildProtocol(None)
@@ -985,7 +1092,9 @@ def _log_request(request):
@implementer(IPolicyForHTTPS)
class TrustingTLSPolicyForHTTPS(object):
- """An IPolicyForHTTPS which doesn't do any certificate verification"""
+ """An IPolicyForHTTPS which checks that the certificate belongs to the
+ right server, but doesn't check the certificate chain."""
+
def creatorForNetloc(self, hostname, port):
certificateOptions = OpenSSLCertificateOptions()
return ClientTLSOptions(hostname, certificateOptions.getContext())
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index a872e2441e..034c0db8d2 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -68,9 +68,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.assert_called_once_with(service_name)
- result_deferred.callback(
- ([answer_srv], None, None)
- )
+ result_deferred.callback(([answer_srv], None, None))
servers = self.successResultOf(test_d)
@@ -112,7 +110,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]}
resolver = SrvResolver(
- dns_client=dns_client_mock, cache=cache, get_time=clock.time,
+ dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
servers = yield resolver.resolve_service(service_name)
@@ -168,11 +166,13 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError
- lookup_deferred.callback((
- [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
- None,
- None,
- ))
+ lookup_deferred.callback(
+ (
+ [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
+ None,
+ None,
+ )
+ )
self.failureResultOf(resolve_d, ConnectError)
@@ -191,14 +191,16 @@ class SrvResolverTestCase(unittest.TestCase):
resolve_d = resolver.resolve_service(service_name)
self.assertNoResult(resolve_d)
- lookup_deferred.callback((
- [
- dns.RRHeader(type=dns.A, payload=dns.Record_A()),
- dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
- ],
- None,
- None,
- ))
+ lookup_deferred.callback(
+ (
+ [
+ dns.RRHeader(type=dns.A, payload=dns.Record_A()),
+ dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
+ ],
+ None,
+ None,
+ )
+ )
servers = self.successResultOf(resolve_d)
diff --git a/tests/http/server.key b/tests/http/server.key
new file mode 100644
index 0000000000..c53ee02b21
--- /dev/null
+++ b/tests/http/server.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAvUAWLOE6TEp3FYSfEnJMwYtJg3KIW5BjiAOOvFVOVQfJ5eEa
+vzyJ1Z+8DUgLznFnUkAeD9GjPvP7awl3NPJKLQSMkV5Tp+ea4YyV+Aa4R7flROEa
+zCGvmleydZw0VqN1atVZ0ikEoglM/APJQd70ec7KSR3QoxaV2/VNCHmyAPdP+0WI
+llV54VXX1CZrWSHaCSn1gzo3WjnGbxTOCQE5Z4k5hqJAwLWWhxDv+FX/jD38Sq3H
+gMFNpXJv6FYwwaKU8awghHdSY/qlBPE/1rU83vIBFJ3jW6I1WnQDfCQ69of5vshK
+N4v4hok56ScwdUnk8lw6xvJx1Uav/XQB9qGh4QIDAQABAoIBAQCHLO5p8hotAgdb
+JFZm26N9nxrMPBOvq0ucjEX4ucnwrFaGzynGrNwa7TRqHCrqs0/EjS2ryOacgbL0
+eldeRy26SASLlN+WD7UuI7e+6DXabDzj3RHB+tGuIbPDk+ZCeBDXVTsKBOhdQN1v
+KNkpJrJjCtSsMxKiWvCBow353srJKqCDZcF5NIBYBeDBPMoMbfYn5dJ9JhEf+2h4
+0iwpnWDX1Vqf46pCRa0hwEyMXycGeV2CnfJSyV7z52ZHQrvkz8QspSnPpnlCnbOE
+UAvc8kZ5e8oZE7W+JfkK38vHbEGM1FCrBmrC/46uUGMRpZfDferGs91RwQVq/F0n
+JN9hLzsBAoGBAPh2pm9Xt7a4fWSkX0cDgjI7PT2BvLUjbRwKLV+459uDa7+qRoGE
+sSwb2QBqmQ1kbr9JyTS+Ld8dyUTsGHZK+YbTieAxI3FBdKsuFtcYJO/REN0vik+6
+fMaBHPvDHSU2ioq7spZ4JBFskzqs38FvZ0lX7aa3fguMk8GMLnofQ8QxAoGBAML9
+o5sJLN9Tk9bv2aFgnERgfRfNjjV4Wd99TsktnCD04D1GrP2eDSLfpwFlCnguck6b
+jxikqcolsNhZH4dgYHqRNj+IljSdl+sYZiygO6Ld0XU+dEFO86N3E9NzZhKcQ1at
+85VdwNPCS7JM2fIxEvS9xfbVnsmK6/37ZZ5iI7yxAoGBALw2vRtJGmy60pojfd1A
+hibhAyINnlKlFGkSOI7zdgeuRTf6l9BTIRclvTt4hJpFgzM6hMWEbyE94hJoupsZ
+bm443o/LCWsox2VI05p6urhD6f9znNWKkiyY78izY+elqksvpjgfqEresaTYAeP5
+LQe9KNSK2VuMUP1j4G04M9BxAoGAWe8ITZJuytZOgrz/YIohqPvj1l2tcIYA1a6C
+7xEFSMIIxtpZIWSLZIFJEsCakpHBkPX4iwIveZfmt/JrM1JFTWK6ZZVGyh/BmOIZ
+Bg4lU1oBqJTUo+aZQtTCJS29b2n5OPpkNYkXTdP4e9UsVKNDvfPlYZJneUeEzxDr
+bqCPIRECgYA544KMwrWxDQZg1dsKWgdVVKx80wEFZAiQr9+0KF6ch6Iu7lwGJHFY
+iI6O85paX41qeC/Fo+feIWJVJU2GvG6eBsbO4bmq+KSg4NkABJSYxodgBp9ftNeD
+jo1tfw+gudlNe5jXHu7oSX93tqGjR4Cnlgan/KtfkB96yHOumGmOhQ==
+-----END RSA PRIVATE KEY-----
diff --git a/tests/http/server.pem b/tests/http/server.pem
deleted file mode 100644
index 0584cf1a80..0000000000
--- a/tests/http/server.pem
+++ /dev/null
@@ -1,81 +0,0 @@
------BEGIN PRIVATE KEY-----
-MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0
-x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN
-r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG
-ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN
-Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx
-9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0
-0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0
-8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa
-qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX
-QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX
-oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/
-dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL
-aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m
-5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF
-gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX
-jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw
-QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi
-CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S
-yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg
-j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD
-6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+
-AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd
-uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG
-qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T
-W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC
-DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU
-tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6
-XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8
-REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ
-remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48
-nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/
-B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd
-kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT
-NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ
-nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8
-k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz
-pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ
-L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd
-6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj
-yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD
-gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq
-I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO
-qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f
-/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD
-YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw
-VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9
-QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV
-/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr
-FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka
-KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ==
------END PRIVATE KEY-----
------BEGIN CERTIFICATE-----
-MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV
-BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT
-MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC
-ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN
-T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC
-RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM
-sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l
-mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y
-VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB
-ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42
-D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z
-4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg
-cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+
-uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G
-A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk
-5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC
-AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG
-7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/
-FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9
-fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM
-ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd
-a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx
-V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN
-H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG
-TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy
-JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds
-f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ==
------END CERTIFICATE-----
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index cd8e086f86..ee767f3a5a 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -15,6 +15,8 @@
from mock import Mock
+from netaddr import IPSet
+
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
@@ -36,9 +38,7 @@ from tests.unittest import HomeserverTestCase
def check_logcontext(context):
current = LoggingContext.current_context()
if current is not context:
- raise AssertionError(
- "Expected logcontext %s but was %s" % (context, current),
- )
+ raise AssertionError("Expected logcontext %s but was %s" % (context, current))
class FederationClientTests(HomeserverTestCase):
@@ -54,6 +54,7 @@ class FederationClientTests(HomeserverTestCase):
"""
happy-path test of a GET request
"""
+
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
@@ -175,8 +176,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(
- f.value.inner_exception,
- (ConnectingCancelledError, TimeoutError),
+ f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
)
def test_client_connect_no_response(self):
@@ -211,14 +211,81 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
+ def test_client_ip_range_blacklist(self):
+ """Ensure that Synapse does not try to connect to blacklisted IPs"""
+
+ # Set up the ip_range blacklist
+ self.hs.config.federation_ip_range_blacklist = IPSet([
+ "127.0.0.0/8",
+ "fe80::/64",
+ ])
+ self.reactor.lookups["internal"] = "127.0.0.1"
+ self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
+ self.reactor.lookups["fine"] = "10.20.30.40"
+ cl = MatrixFederationHttpClient(self.hs, None)
+
+ # Try making a GET request to a blacklisted IPv4 address
+ # ------------------------------------------------------
+ # Make the request
+ d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ f = self.failureResultOf(d)
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertIsInstance(f.value.inner_exception, DNSLookupError)
+
+ # Try making a POST request to a blacklisted IPv6 address
+ # -------------------------------------------------------
+ # Make the request
+ d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+
+ # Nothing has happened yet
+ self.assertNoResult(d)
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ # Check that it was due to a blacklisted DNS lookup
+ f = self.failureResultOf(d, RequestSendFailed)
+ self.assertIsInstance(f.value.inner_exception, DNSLookupError)
+
+ # Try making a GET request to a non-blacklisted IPv4 address
+ # ----------------------------------------------------------
+ # Make the request
+ d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+
+ # Nothing has happened yet
+ self.assertNoResult(d)
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was able to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertNotEqual(len(clients), 0)
+
+ # Connection will still fail as this IP address does not resolve to anything
+ f = self.failureResultOf(d, RequestSendFailed)
+ self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
+
def test_client_gets_headers(self):
"""
Once the client gets the headers, _request returns successfully.
"""
request = MatrixFederationRequest(
- method="GET",
- destination="testserv:8008",
- path="foo/bar",
+ method="GET", destination="testserv:8008", path="foo/bar"
)
d = self.cl._send_request(request, timeout=10000)
@@ -258,8 +325,10 @@ class FederationClientTests(HomeserverTestCase):
# Send it the HTTP response
client.dataReceived(
- (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
- b"Server: Fake\r\n\r\n")
+ (
+ b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
+ b"Server: Fake\r\n\r\n"
+ )
)
# Push by enough to time it out
@@ -274,9 +343,7 @@ class FederationClientTests(HomeserverTestCase):
requiring a trailing slash. We need to retry the request with a
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
"""
- d = self.cl.get_json(
- "testserv:8008", "foo/bar", try_trailing_slash_on_400=True,
- )
+ d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
# Send the request
self.pump()
@@ -329,9 +396,7 @@ class FederationClientTests(HomeserverTestCase):
See test_client_requires_trailing_slashes() for context.
"""
- d = self.cl.get_json(
- "testserv:8008", "foo/bar", try_trailing_slash_on_400=True,
- )
+ d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
# Send the request
self.pump()
@@ -368,10 +433,7 @@ class FederationClientTests(HomeserverTestCase):
self.failureResultOf(d)
def test_client_sends_body(self):
- self.cl.post_json(
- "testserv:8008", "foo/bar", timeout=10000,
- data={"a": "b"}
- )
+ self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
self.pump()
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
index 0f613945c8..ee0add3455 100644
--- a/tests/patch_inline_callbacks.py
+++ b/tests/patch_inline_callbacks.py
@@ -45,7 +45,9 @@ def do_patch():
except Exception:
if LoggingContext.current_context() != start_context:
err = "%s changed context from %s to %s on exception" % (
- f, start_context, LoggingContext.current_context()
+ f,
+ start_context,
+ LoggingContext.current_context(),
)
print(err, file=sys.stderr)
raise Exception(err)
@@ -54,7 +56,9 @@ def do_patch():
if not isinstance(res, Deferred) or res.called:
if LoggingContext.current_context() != start_context:
err = "%s changed context from %s to %s" % (
- f, start_context, LoggingContext.current_context()
+ f,
+ start_context,
+ LoggingContext.current_context(),
)
# print the error to stderr because otherwise all we
# see in travis-ci is the 500 error
@@ -66,9 +70,7 @@ def do_patch():
err = (
"%s returned incomplete deferred in non-sentinel context "
"%s (start was %s)"
- ) % (
- f, LoggingContext.current_context(), start_context,
- )
+ ) % (f, LoggingContext.current_context(), start_context)
print(err, file=sys.stderr)
raise Exception(err)
@@ -76,7 +78,9 @@ def do_patch():
if LoggingContext.current_context() != start_context:
err = "%s completion of %s changed context from %s to %s" % (
"Failure" if isinstance(r, Failure) else "Success",
- f, start_context, LoggingContext.current_context(),
+ f,
+ start_context,
+ LoggingContext.current_context(),
)
print(err, file=sys.stderr)
raise Exception(err)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index be3fed8de3..72760a0733 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -15,25 +15,28 @@
import os
+import attr
import pkg_resources
from twisted.internet.defer import Deferred
-from synapse.rest.client.v1 import admin, login, room
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
from tests.unittest import HomeserverTestCase
-try:
- from synapse.push.mailer import load_jinja2_templates
-except Exception:
- load_jinja2_templates = None
+
+@attr.s
+class _User(object):
+ "Helper wrapper for user ID and access token"
+ id = attr.ib()
+ token = attr.ib()
class EmailPusherTests(HomeserverTestCase):
- skip = "No Jinja installed" if not load_jinja2_templates else None
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
@@ -51,46 +54,57 @@ class EmailPusherTests(HomeserverTestCase):
return d
config = self.default_config()
- config.email_enable_notifs = True
- config.start_pushers = True
-
- config.email_template_dir = os.path.abspath(
- pkg_resources.resource_filename('synapse', 'res/templates')
- )
- config.email_notif_template_html = "notif_mail.html"
- config.email_notif_template_text = "notif_mail.txt"
- config.email_smtp_host = "127.0.0.1"
- config.email_smtp_port = 20
- config.require_transport_security = False
- config.email_smtp_user = None
- config.email_smtp_pass = None
- config.email_app_name = "Matrix"
- config.email_notif_from = "test@example.com"
- config.email_riot_base_url = None
+ config["email"] = {
+ "enable_notifs": True,
+ "template_dir": os.path.abspath(
+ pkg_resources.resource_filename('synapse', 'res/templates')
+ ),
+ "expiry_template_html": "notice_expiry.html",
+ "expiry_template_text": "notice_expiry.txt",
+ "notif_template_html": "notif_mail.html",
+ "notif_template_text": "notif_mail.txt",
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "app_name": "Matrix",
+ "notif_from": "test@example.com",
+ "riot_base_url": None,
+ }
+ config["public_baseurl"] = "aaa"
+ config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
return hs
- def test_sends_email(self):
-
+ def prepare(self, reactor, clock, hs):
# Register the user who gets notified
- user_id = self.register_user("user", "pass")
- access_token = self.login("user", "pass")
-
- # Register the user who sends the message
- other_user_id = self.register_user("otheruser", "pass")
- other_access_token = self.login("otheruser", "pass")
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ # Register other users
+ self.others = [
+ _User(
+ id=self.register_user("otheruser1", "pass"),
+ token=self.login("otheruser1", "pass"),
+ ),
+ _User(
+ id=self.register_user("otheruser2", "pass"),
+ token=self.login("otheruser2", "pass"),
+ ),
+ ]
# Register the pusher
user_tuple = self.get_success(
- self.hs.get_datastore().get_user_by_access_token(access_token)
+ self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
token_id = user_tuple["token_id"]
- self.get_success(
+ self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(
- user_id=user_id,
+ user_id=self.user_id,
access_token=token_id,
kind="email",
app_id="m.email",
@@ -102,22 +116,54 @@ class EmailPusherTests(HomeserverTestCase):
)
)
- # Create a room
- room = self.helper.create_room_as(user_id, tok=access_token)
+ def test_simple_sends_email(self):
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id,
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
- # Invite the other person
- self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+ self.helper.send(room, body="There!", tok=self.others[0].token)
- # The other user joins
- self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+ # We should get emailed about that message
+ self._check_for_mail()
- # The other user sends some messages
- self.helper.send(room, body="Hi!", tok=other_access_token)
- self.helper.send(room, body="There!", tok=other_access_token)
+ def test_multiple_members_email(self):
+ # We want to test multiple notifications, so we pause processing of push
+ # while we send messages.
+ self.pusher._pause_processing()
+
+ # Create a simple room with multiple other users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ for other in self.others:
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=other.id,
+ )
+ self.helper.join(room=room, user=other.id, tok=other.token)
+
+ # The other users send some messages
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+ self.helper.send(room, body="There!", tok=self.others[1].token)
+ self.helper.send(room, body="There!", tok=self.others[1].token)
+
+ # Nothing should have happened yet, as we're paused.
+ assert not self.email_attempts
+
+ self.pusher._resume_processing()
+
+ # We should get emailed about those messages
+ self._check_for_mail()
+
+ def _check_for_mail(self):
+ "Check that the user receives an email notification"
# Get the stream ordering before it gets sent
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -127,7 +173,7 @@ class EmailPusherTests(HomeserverTestCase):
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
@@ -144,7 +190,7 @@ class EmailPusherTests(HomeserverTestCase):
# The stream ordering has increased
pushers = self.get_success(
- self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 6dc45e8506..22c3f73ef3 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -17,22 +17,17 @@ from mock import Mock
from twisted.internet.defer import Deferred
-from synapse.rest.client.v1 import admin, login, room
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
from synapse.util.logcontext import make_deferred_yieldable
from tests.unittest import HomeserverTestCase
-try:
- from synapse.push.mailer import load_jinja2_templates
-except Exception:
- load_jinja2_templates = None
-
class HTTPPusherTests(HomeserverTestCase):
- skip = "No Jinja installed" if not load_jinja2_templates else None
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
@@ -53,7 +48,7 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
config = self.default_config()
- config.start_pushers = True
+ config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config, simple_http_client=m)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 1f72a2a04f..104349cdbd 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -74,21 +74,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
master_result,
expected_result,
- "Expected master result to be %r but was %r" % (
- expected_result, master_result
- ),
+ "Expected master result to be %r but was %r"
+ % (expected_result, master_result),
)
self.assertEqual(
slaved_result,
expected_result,
- "Expected slave result to be %r but was %r" % (
- expected_result, slaved_result
- ),
+ "Expected slave result to be %r but was %r"
+ % (expected_result, slaved_result),
)
self.assertEqual(
master_result,
slaved_result,
- "Slave result %r does not match master result %r" % (
- slaved_result, master_result
- ),
+ "Slave result %r does not match master result %r"
+ % (slaved_result, master_result),
)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 65ecff3bd6..a368117b43 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -234,10 +234,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(self.master_store.persist_events([
- (j2, j2ctx),
- (msg, msgctx),
- ]))
+ self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
self.replicate()
event_source = RoomEventSource(self.hs)
@@ -257,15 +254,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
#
# First, we get a list of the rooms we are joined to
joined_rooms = self.get_success(
- self.slaved_store.get_rooms_for_user_with_stream_ordering(
- USER_ID_2,
- ),
+ self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
)
# Then, we get a list of the events since the last sync
membership_changes = self.get_success(
self.slaved_store.get_membership_changes_for_user(
- USER_ID_2, prev_token, current_token,
+ USER_ID_2, prev_token, current_token
)
)
@@ -298,9 +293,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.master_store.persist_events([(event, context)], backfilled=True)
)
else:
- self.get_success(
- self.master_store.persist_event(event, context)
- )
+ self.get_success(self.master_store.persist_event(event, context))
return event
@@ -359,9 +352,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(
- event
- ))
+ context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 38b368a972..ce3835ae6a 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -22,6 +22,7 @@ from tests.server import FakeTransport
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs)
@@ -52,6 +53,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
+
def __init__(self):
self.received_rdata_rows = []
@@ -69,6 +71,4 @@ class TestReplicationClientHandler(object):
def on_rdata(self, stream_name, token, rows):
for r in rows:
- self.received_rdata_rows.append(
- (stream_name, token, r)
- )
+ self.received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/rest/admin/__init__.py b/tests/rest/admin/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/tests/rest/admin/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/admin/test_admin.py
index c00ef21d75..e5fc2fcd15 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -19,50 +19,37 @@ import json
from mock import Mock
+import synapse.rest.admin
from synapse.api.constants import UserTypes
-from synapse.rest.client.v1 import admin, events, login, room
+from synapse.http.server import JsonResource
+from synapse.rest.admin import VersionServlet
+from synapse.rest.client.v1 import events, login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
class VersionTestCase(unittest.HomeserverTestCase):
+ url = '/_synapse/admin/v1/server_version'
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- ]
-
- url = '/_matrix/client/r0/admin/server_version'
+ def create_test_json_resource(self):
+ resource = JsonResource(self.hs)
+ VersionServlet(self.hs).register(resource)
+ return resource
def test_version_string(self):
- self.register_user("admin", "pass", admin=True)
- self.admin_token = self.login("admin", "pass")
-
- request, channel = self.make_request("GET", self.url,
- access_token=self.admin_token)
+ request, channel = self.make_request("GET", self.url, shorthand=False)
self.render(request)
- self.assertEqual(200, int(channel.result["code"]),
- msg=channel.result["body"])
- self.assertEqual({'server_version', 'python_version'},
- set(channel.json_body.keys()))
-
- def test_inaccessible_to_non_admins(self):
- self.register_user("unprivileged-user", "pass", admin=False)
- user_token = self.login("unprivileged-user", "pass")
-
- request, channel = self.make_request("GET", self.url,
- access_token=user_token)
- self.render(request)
-
- self.assertEqual(403, int(channel.result['code']),
- msg=channel.result['body'])
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ {'server_version', 'python_version'}, set(channel.json_body.keys())
+ )
class UserRegisterTestCase(unittest.HomeserverTestCase):
- servlets = [admin.register_servlets]
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor, clock):
@@ -213,9 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(
- nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin"
- )
+ want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
body = json.dumps(
@@ -343,11 +328,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Invalid user_type
- body = json.dumps({
- "nonce": nonce(),
- "username": "a",
- "password": "1234",
- "user_type": "invalid"}
+ body = json.dumps(
+ {
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid",
+ }
)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)
@@ -358,7 +345,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
class ShutdownRoomTestCase(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
events.register_servlets,
room.register_servlets,
@@ -370,9 +357,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
hs.config.user_consent_version = "1"
consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = (
- "http://example.com"
- )
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
self.event_creation_handler._consent_uri_builder = consent_uri_builder
self.store = hs.get_datastore()
@@ -384,9 +369,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self.other_user_token = self.login("user", "pass")
# Mark the admin user as having consented
- self.get_success(
- self.store.user_set_consent_version(self.admin_user, "1"),
- )
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
def test_shutdown_room_consent(self):
"""Test that we can shutdown rooms with local users who have not
@@ -398,9 +381,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
# Assert one user in room
- users_in_room = self.get_success(
- self.store.get_users_in_room(room_id),
- )
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([self.other_user], users_in_room)
# Enable require consent to send events
@@ -408,8 +389,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
# Assert that the user is getting consent error
self.helper.send(
- room_id,
- body="foo", tok=self.other_user_token, expect_code=403,
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
)
# Test that the admin can still send shutdown
@@ -425,12 +405,9 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Assert there is now no longer anyone in the room
- users_in_room = self.get_success(
- self.store.get_users_in_room(room_id),
- )
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room)
- @unittest.DEBUG
def test_shutdown_room_block_peek(self):
"""Test that a world_readable room can no longer be peeked into after
it has been shut down.
@@ -472,30 +449,26 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
url = "rooms/%s/initialSync" % (room_id,)
request, channel = self.make_request(
- "GET",
- url.encode('ascii'),
- access_token=self.admin_user_tok,
+ "GET", url.encode('ascii'), access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"],
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
url = "events?timeout=0&room_id=" + room_id
request, channel = self.make_request(
- "GET",
- url.encode('ascii'),
- access_token=self.admin_user_tok,
+ "GET", url.encode('ascii'), access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"],
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
class DeleteGroupTestCase(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
groups.register_servlets,
]
@@ -515,15 +488,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"POST",
"/create_group".encode('ascii'),
access_token=self.admin_user_tok,
- content={
- "localpart": "test",
- }
+ content={"localpart": "test"},
)
self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"],
- )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
group_id = channel.json_body["group_id"]
@@ -533,27 +502,17 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
request, channel = self.make_request(
- "PUT",
- url.encode('ascii'),
- access_token=self.admin_user_tok,
- content={}
+ "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={}
)
self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"],
- )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
request, channel = self.make_request(
- "PUT",
- url.encode('ascii'),
- access_token=self.other_user_token,
- content={}
+ "PUT", url.encode('ascii'), access_token=self.other_user_token, content={}
)
self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"],
- )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check other user knows they're in the group
self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
@@ -565,15 +524,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"POST",
url.encode('ascii'),
access_token=self.admin_user_tok,
- content={
- "localpart": "test",
- }
+ content={"localpart": "test"},
)
self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"],
- )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check group returns 404
self._check_group(group_id, expect_code=404)
@@ -589,28 +544,22 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
url = "/groups/%s/profile" % (group_id,)
request, channel = self.make_request(
- "GET",
- url.encode('ascii'),
- access_token=self.admin_user_tok,
+ "GET", url.encode('ascii'), access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"],
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)
"""
request, channel = self.make_request(
- "GET",
- "/joined_groups".encode('ascii'),
- access_token=access_token,
+ "GET", "/joined_groups".encode('ascii'), access_token=access_token
)
self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"],
- )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 4294bbec2a..efc5a99db3 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -15,23 +15,18 @@
import os
+import synapse.rest.admin
from synapse.api.urls import ConsentURIBuilder
-from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource
from tests import unittest
from tests.server import render
-try:
- from synapse.push.mailer import load_jinja2_templates
-except Exception:
- load_jinja2_templates = None
-
class ConsentResourceTestCase(unittest.HomeserverTestCase):
- skip = "No Jinja installed" if not load_jinja2_templates else None
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
@@ -41,15 +36,18 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config.user_consent_version = "1"
- config.public_baseurl = ""
- config.form_secret = "123abc"
+ config["public_baseurl"] = "aaaa"
+ config["form_secret"] = "123abc"
# Make some temporary templates...
temp_consent_path = self.mktemp()
os.mkdir(temp_consent_path)
os.mkdir(os.path.join(temp_consent_path, 'en'))
- config.user_consent_template_dir = os.path.abspath(temp_consent_path)
+
+ config["user_consent"] = {
+ "version": "1",
+ "template_dir": os.path.abspath(temp_consent_path),
+ }
with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
f.write("{{version}},{{has_consented}}")
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index ca63b2e6ed..68949307d9 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,7 +15,8 @@
import json
-from synapse.rest.client.v1 import admin, login, room
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
from tests import unittest
@@ -23,7 +24,7 @@ from tests import unittest
class IdentityTestCase(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
@@ -31,7 +32,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config.enable_3pid_lookup = False
+ config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@@ -43,7 +44,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey")
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok,
+ b"POST", "/createRoom", b"{}", access_token=tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -55,11 +56,9 @@ class IdentityTestCase(unittest.HomeserverTestCase):
"address": "test@example.com",
}
request_data = json.dumps(params)
- request_url = (
- "/rooms/%s/invite" % (room_id)
- ).encode('ascii')
+ request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii')
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok,
+ b"POST", request_url, request_data, access_token=tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
new file mode 100644
index 0000000000..7167fc56b6
--- /dev/null
+++ b/tests/rest/client/third_party_rules.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class ThirdPartyRulesTestModule(object):
+ def __init__(self, config):
+ pass
+
+ def check_event_allowed(self, event, context):
+ if event.type == "foo.bar.forbidden":
+ return False
+ else:
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["third_party_event_rules"] = {
+ "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
+ "config": {},
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
+ {},
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
+ {},
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
new file mode 100644
index 0000000000..633b7dbda0
--- /dev/null
+++ b/tests/rest/client/v1/test_directory.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import directory, login, room
+from synapse.types import RoomAlias
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class DirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["require_membership_for_aliases"] = True
+
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ self.user = self.register_user("user", "test")
+ self.user_tok = self.login("user", "test")
+
+ def test_state_event_not_in_room(self):
+ self.ensure_user_left_room()
+ self.set_alias_via_state_event(403)
+
+ def test_directory_endpoint_not_in_room(self):
+ self.ensure_user_left_room()
+ self.set_alias_via_directory(403)
+
+ def test_state_event_in_room_too_long(self):
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(400, alias_length=256)
+
+ def test_directory_in_room_too_long(self):
+ self.ensure_user_joined_room()
+ self.set_alias_via_directory(400, alias_length=256)
+
+ def test_state_event_in_room(self):
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(200)
+
+ def test_directory_in_room(self):
+ self.ensure_user_joined_room()
+ self.set_alias_via_directory(200)
+
+ def test_room_creation_too_long(self):
+ url = "/_matrix/client/r0/createRoom"
+
+ # We use deliberately a localpart under the length threshold so
+ # that we can make sure that the check is done on the whole alias.
+ data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.user_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_room_creation(self):
+ url = "/_matrix/client/r0/createRoom"
+
+ # Check with an alias of allowed length. There should already be
+ # a test that ensures it works in test_register.py, but let's be
+ # as cautious as possible here.
+ data = {"room_alias_name": random_string(5)}
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.user_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def set_alias_via_state_event(self, expected_code, alias_length=5):
+ url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
+ self.room_id,
+ self.hs.hostname,
+ )
+
+ data = {"aliases": [self.random_alias(alias_length)]}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.user_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def set_alias_via_directory(self, expected_code, alias_length=5):
+ url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.user_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def random_alias(self, length):
+ return RoomAlias(random_string(length), self.hs.hostname).to_string()
+
+ def ensure_user_left_room(self):
+ self.ensure_membership("leave")
+
+ def ensure_user_joined_room(self):
+ self.ensure_membership("join")
+
+ def ensure_membership(self, membership):
+ try:
+ if membership == "leave":
+ self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
+ if membership == "join":
+ self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
+ except AssertionError:
+ # We don't care whether the leave request didn't return a 200 (e.g.
+ # if the user isn't already in the room), because we only want to
+ # make sure the user isn't in the room.
+ pass
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 36d8547275..f340b7e851 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -17,7 +17,8 @@
from mock import Mock, NonCallableMock
-from synapse.rest.client.v1 import admin, events, login, room
+import synapse.rest.admin
+from synapse.rest.client.v1 import events, login, room
from tests import unittest
@@ -28,16 +29,16 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
servlets = [
events.register_servlets,
room.register_servlets,
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config.enable_registration_captcha = False
- config.enable_registration = True
- config.auto_join_rooms = []
+ config["enable_registration_captcha"] = False
+ config["enable_registration"] = True
+ config["auto_join_rooms"] = []
hs = self.setup_test_homeserver(
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 86312f1096..0397f91a9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,6 +1,7 @@
import json
-from synapse.rest.client.v1 import admin, login
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
from tests import unittest
@@ -10,7 +11,7 @@ LOGIN_URL = b"/_matrix/client/r0/login"
class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]
@@ -36,10 +37,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
for i in range(0, 6):
params = {
"type": "m.login.password",
- "identifier": {
- "type": "m.id.user",
- "user": "kermit" + str(i),
- },
+ "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
request_data = json.dumps(params)
@@ -56,14 +54,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.)
+ self.reactor.advance(retry_after_ms / 1000.0)
params = {
"type": "m.login.password",
- "identifier": {
- "type": "m.id.user",
- "user": "kermit" + str(i),
- },
+ "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
request_data = json.dumps(params)
@@ -81,10 +76,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
for i in range(0, 6):
params = {
"type": "m.login.password",
- "identifier": {
- "type": "m.id.user",
- "user": "kermit",
- },
+ "identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
request_data = json.dumps(params)
@@ -101,14 +93,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.)
+ self.reactor.advance(retry_after_ms / 1000.0)
params = {
"type": "m.login.password",
- "identifier": {
- "type": "m.id.user",
- "user": "kermit",
- },
+ "identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
request_data = json.dumps(params)
@@ -126,10 +115,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
for i in range(0, 6):
params = {
"type": "m.login.password",
- "identifier": {
- "type": "m.id.user",
- "user": "kermit",
- },
+ "identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
request_data = json.dumps(params)
@@ -146,14 +132,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.)
+ self.reactor.advance(retry_after_ms / 1000.0)
params = {
"type": "m.login.password",
- "identifier": {
- "type": "m.id.user",
- "user": "kermit",
- },
+ "identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
request_data = json.dumps(params)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 1eab9c3bdb..72c7ed93cb 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,24 +14,30 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
+import json
+
from mock import Mock
from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
-from synapse.rest.client.v1 import profile
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, profile, room
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/api/v1"
+PATH_PREFIX = "/_matrix/client/r0"
+
+class MockHandlerProfileTestCase(unittest.TestCase):
+ """ Tests rest layer of profile management.
-class ProfileTestCase(unittest.TestCase):
- """ Tests profile management. """
+ Todo: move these into ProfileTestCase
+ """
@defer.inlineCallbacks
def setUp(self):
@@ -42,6 +48,7 @@ class ProfileTestCase(unittest.TestCase):
"set_displayname",
"get_avatar_url",
"set_avatar_url",
+ "check_profile_query_allowed",
]
)
@@ -155,3 +162,130 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
+
+
+class ProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.owner = self.register_user("owner", "pass")
+ self.owner_tok = self.login("owner", "pass")
+
+ def test_set_displayname(self):
+ request, channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner, ),
+ content=json.dumps({"displayname": "test"}),
+ access_token=self.owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self.get_displayname()
+ self.assertEqual(res, "test")
+
+ def test_set_displayname_too_long(self):
+ """Attempts to set a stupid displayname should get a 400"""
+ request, channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner, ),
+ content=json.dumps({"displayname": "test" * 100}),
+ access_token=self.owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 400, channel.result)
+
+ res = self.get_displayname()
+ self.assertEqual(res, "owner")
+
+ def get_displayname(self):
+ request, channel = self.make_request(
+ "GET",
+ "/profile/%s/displayname" % (self.owner, ),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["displayname"]
+
+
+class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config["require_auth_for_profile_requests"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ # User owning the requested profile.
+ self.owner = self.register_user("owner", "pass")
+ self.owner_tok = self.login("owner", "pass")
+ self.profile_url = "/profile/%s" % (self.owner)
+
+ # User requesting the profile.
+ self.requester = self.register_user("requester", "pass")
+ self.requester_tok = self.login("requester", "pass")
+
+ self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
+
+ def test_no_auth(self):
+ self.try_fetch_profile(401)
+
+ def test_not_in_shared_room(self):
+ self.ensure_requester_left_room()
+
+ self.try_fetch_profile(403, access_token=self.requester_tok)
+
+ def test_in_shared_room(self):
+ self.ensure_requester_left_room()
+
+ self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
+
+ self.try_fetch_profile(200, self.requester_tok)
+
+ def try_fetch_profile(self, expected_code, access_token=None):
+ self.request_profile(expected_code, access_token=access_token)
+
+ self.request_profile(
+ expected_code, url_suffix="/displayname", access_token=access_token
+ )
+
+ self.request_profile(
+ expected_code, url_suffix="/avatar_url", access_token=access_token
+ )
+
+ def request_profile(self, expected_code, url_suffix="", access_token=None):
+ request, channel = self.make_request(
+ "GET", self.profile_url + url_suffix, access_token=access_token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def ensure_requester_left_room(self):
+ try:
+ self.helper.leave(
+ room=self.room_id, user=self.requester, tok=self.requester_tok
+ )
+ except AssertionError:
+ # We don't care whether the leave request didn't return a 200 (e.g.
+ # if the user isn't already in the room), because we only want to
+ # make sure the user isn't in the room.
+ pass
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 015c144248..5f75ad7579 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,8 +23,9 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
+import synapse.rest.admin
from synapse.api.constants import Membership
-from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.client.v1 import login, profile, room
from tests import unittest
@@ -803,7 +805,7 @@ class RoomMessageListTestCase(RoomBase):
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
@@ -903,3 +905,102 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.assertEqual(
context["profile_info"][self.other_user_id]["displayname"], "otheruser"
)
+
+
+class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/publicRooms"
+
+ config = self.default_config()
+ config["restrict_public_rooms_to_local_users"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def test_restricted_no_auth(self):
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_restricted_auth(self):
+ self.register_user("user", "pass")
+ tok = self.login("user", "pass")
+
+ request, channel = self.make_request("GET", self.url, access_token=tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+
+class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["allow_per_room_profiles"] = False
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+
+ # Set a profile for the test user
+ self.displayname = "test user"
+ data = {
+ "displayname": self.displayname,
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
+ request_data,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_per_room_profile_forbidden(self):
+ data = {
+ "membership": "join",
+ "displayname": "other test user"
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ self.room_id, self.user_id,
+ ),
+ request_data,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res_displayname = channel.json_body["content"]["displayname"]
+ self.assertEqual(res_displayname, self.displayname, channel.result)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 05b0143c42..f7133fc12e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -127,3 +127,20 @@ class RestHelper(object):
)
return channel.json_body
+
+ def send_state(self, room_id, event_type, body, tok, expect_code=200):
+ path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8')
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
new file mode 100644
index 0000000000..a60a4a3b87
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -0,0 +1,286 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import re
+from email.parser import Parser
+
+import pkg_resources
+
+import synapse.rest.admin
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, register
+
+from tests import unittest
+
+
+class PasswordResetTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ register.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Email config.
+ self.email_attempts = []
+
+ def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ self.email_attempts.append(msg)
+ return
+
+ config["email"] = {
+ "enable_notifs": False,
+ "template_dir": os.path.abspath(
+ pkg_resources.resource_filename("synapse", "res/templates")
+ ),
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ }
+ config["public_baseurl"] = "https://example.com"
+
+ hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ def test_basic_password_reset(self):
+ """Test basic password reset flow
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ # Assert we can log in with the new password
+ self.login("kermit", new_password)
+
+ # Assert we can't log in with the old password
+ self.attempt_wrong_password_login("kermit", old_password)
+
+ def test_cant_reset_password_without_clicking_link(self):
+ """Test that we do actually need to click the link in the email
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+
+ # Attempt to reset password without clicking the link
+ self._reset_password(
+ new_password, session_id, client_secret, expected_code=401,
+ )
+
+ # Assert we can log in with the old password
+ self.login("kermit", old_password)
+
+ # Assert we can't log in with the new password
+ self.attempt_wrong_password_login("kermit", new_password)
+
+ def test_no_valid_token(self):
+ """Test that we do actually need to request a token and can't just
+ make a session up.
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ client_secret = "foobar"
+ session_id = "weasle"
+
+ # Attempt to reset password without even requesting an email
+ self._reset_password(
+ new_password, session_id, client_secret, expected_code=401,
+ )
+
+ # Assert we can log in with the old password
+ self.login("kermit", old_password)
+
+ # Assert we can't log in with the new password
+ self.attempt_wrong_password_login("kermit", new_password)
+
+ def _request_token(self, email, client_secret):
+ request, channel = self.make_request(
+ "POST",
+ b"account/password/email/requestToken",
+ {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ return channel.json_body["sid"]
+
+ def _validate_token(self, link):
+ # Remove the host
+ path = link.replace("https://example.com", "")
+
+ request, channel = self.make_request("GET", path, shorthand=False)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def _get_link_from_email(self):
+ assert self.email_attempts, "No emails have been sent"
+
+ raw_msg = self.email_attempts[-1].decode("UTF-8")
+ mail = Parser().parsestr(raw_msg)
+
+ text = None
+ for part in mail.walk():
+ if part.get_content_type() == "text/plain":
+ text = part.get_payload(decode=True).decode("UTF-8")
+ break
+
+ if not text:
+ self.fail("Could not find text portion of email to parse")
+
+ match = re.search(r"https://example.com\S+", text)
+ assert match, "Could not find link in email"
+
+ return match.group(0)
+
+ def _reset_password(
+ self, new_password, session_id, client_secret, expected_code=200
+ ):
+ request, channel = self.make_request(
+ "POST",
+ b"account/password",
+ {
+ "new_password": new_password,
+ "auth": {
+ "type": LoginType.EMAIL_IDENTITY,
+ "threepid_creds": {
+ "client_secret": client_secret,
+ "sid": session_id,
+ },
+ },
+ },
+ )
+ self.render(request)
+ self.assertEquals(expected_code, channel.code, channel.result)
+
+
+class DeactivateTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def test_deactivate_account(self):
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ request_data = json.dumps({
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ })
+ request, channel = self.make_request(
+ "POST",
+ "account/deactivate",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ store = self.hs.get_datastore()
+
+ # Check that the user has been marked as deactivated.
+ self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
+
+ # Check that this access token has been invalidated.
+ request, channel = self.make_request("GET", "account/whoami")
+ self.render(request)
+ self.assertEqual(request.code, 401)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 7fa120a10f..b9ef46e8fb 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -16,8 +16,8 @@
from twisted.internet.defer import succeed
+import synapse.rest.admin
from synapse.api.constants import LoginType
-from synapse.rest.client.v1 import admin
from synapse.rest.client.v2_alpha import auth, register
from tests import unittest
@@ -27,7 +27,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
auth.register_servlets,
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
register.register_servlets,
]
hijack_auth = False
@@ -36,9 +36,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
config = self.default_config()
- config.enable_registration_captcha = True
- config.recaptcha_public_key = "brokencake"
- config.registrations_require_3pid = []
+ config["enable_registration_captcha"] = True
+ config["recaptcha_public_key"] = "brokencake"
+ config["registrations_require_3pid"] = []
hs = self.setup_test_homeserver(config=config)
return hs
@@ -92,7 +92,14 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.recaptcha_attempts), 1)
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
- # Now we have fufilled the recaptcha fallback step, we can then send a
+ # also complete the dummy auth
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ )
+ self.render(request)
+
+ # Now we should have fufilled a complete auth flow, including
+ # the recaptcha fallback step, we can then send a
# request to the register API with the session in the authdict.
request, channel = self.make_request(
"POST", "register", {"auth": {"session": session}}
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index bbfc77e829..bce5b0cf4c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
-from synapse.rest.client.v1 import admin, login
+import synapse.rest.admin
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
from tests import unittest
@@ -23,7 +23,7 @@ from tests import unittest
class CapabilitiesTestCase(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
capabilities.register_servlets,
login.register_servlets,
]
@@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
+ self.config = hs.config
return hs
def test_check_auth_required(self):
@@ -51,8 +52,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
for room_version in capabilities['m.room_versions']['available'].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+
self.assertEqual(
- DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
+ self.config.default_room_version.identifier,
+ capabilities['m.room_versions']['default'],
)
def test_get_change_password_capabilities(self):
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d3611ed21f..b35b215446 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,11 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import datetime
import json
+import os
+
+import pkg_resources
+import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import admin, login
-from synapse.rest.client.v2_alpha import register, sync
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest
@@ -32,11 +53,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
as_token = "i_am_an_app_service"
appservice = ApplicationService(
- as_token, self.hs.config.server_name,
+ as_token,
+ self.hs.config.server_name,
id="1234",
- namespaces={
- "users": [{"regex": r"@as_user.*", "exclusive": True}],
- },
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
)
self.hs.get_datastore().services_cache.append(appservice)
@@ -48,10 +68,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
- det_data = {
- "user_id": user_id,
- "home_server": self.hs.hostname,
- }
+ det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_appservice_registration_invalid(self):
@@ -119,10 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
- det_data = {
- "home_server": self.hs.hostname,
- "device_id": "guest_device",
- }
+ det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
@@ -150,7 +164,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.)
+ self.reactor.advance(retry_after_ms / 1000.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -178,7 +192,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.)
+ self.reactor.advance(retry_after_ms / 1000.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -190,16 +204,20 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
register.register_servlets,
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
sync.register_servlets,
+ account_validity.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
- config.enable_registration = True
- config.account_validity.enabled = True
- config.account_validity.period = 604800000 # Time in ms for 1 week
+ # Test for account expiring after a week.
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@@ -210,21 +228,290 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
+ request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+ request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
+ )
+
+ def test_manual_renewal(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+ # If we register the admin user at the beginning of the test, it will
+ # expire at the same time as the normal user and the renewal request
+ # will be denied.
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {"user_id": user_id}
+ request_data = json.dumps(params)
request, channel = self.make_request(
- b"GET", "/sync", access_token=tok,
+ b"POST", url, request_data, access_token=admin_tok
)
self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ # The specific endpoint doesn't matter, all we need is an authenticated
+ # endpoint.
+ request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+ def test_manual_expire(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
request, channel = self.make_request(
- b"GET", "/sync", access_token=tok,
+ b"POST", url, request_data, access_token=admin_tok
)
self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ # The specific endpoint doesn't matter, all we need is an authenticated
+ # endpoint.
+ request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
- channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
+ channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
+
+
+class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ sync.register_servlets,
+ account_validity.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Test for account expiring after a week and renewal emails being sent 2
+ # days before expiry.
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ "renew_at": 172800000, # Time in ms for 2 days
+ "renew_by_email_enabled": True,
+ "renew_email_subject": "Renew your account",
+ }
+
+ # Email config.
+ self.email_attempts = []
+
+ def sendmail(*args, **kwargs):
+ self.email_attempts.append((args, kwargs))
+ return
+
+ config["email"] = {
+ "enable_notifs": True,
+ "template_dir": os.path.abspath(
+ pkg_resources.resource_filename('synapse', 'res/templates')
+ ),
+ "expiry_template_html": "notice_expiry.html",
+ "expiry_template_text": "notice_expiry.txt",
+ "notif_template_html": "notif_mail.html",
+ "notif_template_text": "notif_mail.txt",
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ }
+ config["public_baseurl"] = "aaa"
+
+ self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+
+ self.store = self.hs.get_datastore()
+
+ return self.hs
+
+ def test_renewal_email(self):
+ self.email_attempts = []
+
+ (user_id, tok) = self.create_user()
+
+ # Move 6 days forward. This should trigger a renewal email to be sent.
+ self.reactor.advance(datetime.timedelta(days=6).total_seconds())
+ self.assertEqual(len(self.email_attempts), 1)
+
+ # Retrieving the URL from the email is too much pain for now, so we
+ # retrieve the token from the DB.
+ renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
+ url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+ request, channel = self.make_request(b"GET", url)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Move 3 days forward. If the renewal failed, every authed request with
+ # our access token should be denied from now, otherwise they should
+ # succeed.
+ self.reactor.advance(datetime.timedelta(days=3).total_seconds())
+ request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_manual_email_send(self):
+ self.email_attempts = []
+
+ (user_id, tok) = self.create_user()
+ request, channel = self.make_request(
+ b"POST",
+ "/_matrix/client/unstable/account_validity/send_mail",
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.assertEqual(len(self.email_attempts), 1)
+
+ def test_deactivated_user(self):
+ self.email_attempts = []
+
+ (user_id, tok) = self.create_user()
+
+ request_data = json.dumps({
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "monkey",
+ },
+ "erase": False,
+ })
+ request, channel = self.make_request(
+ "POST",
+ "account/deactivate",
+ request_data,
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+ self.assertEqual(len(self.email_attempts), 0)
+
+ def create_user(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ # We need to manually add an email address otherwise the handler will do
+ # nothing.
+ now = self.hs.clock.time_msec()
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address="kermit@example.com",
+ validated_at=now,
+ added_at=now,
+ )
+ )
+ return (user_id, tok)
+
+ def test_manual_email_send_expired_account(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ # We need to manually add an email address otherwise the handler will do
+ # nothing.
+ now = self.hs.clock.time_msec()
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address="kermit@example.com",
+ validated_at=now,
+ added_at=now,
+ )
+ )
+
+ # Make the account expire.
+ self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+ # Ignore all emails sent by the automatic background task and only focus on the
+ # ones sent manually.
+ self.email_attempts = []
+
+ # Test that we're still able to manually trigger a mail to be sent.
+ request, channel = self.make_request(
+ b"POST",
+ "/_matrix/client/unstable/account_validity/send_mail",
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.assertEqual(len(self.email_attempts), 1)
+
+
+class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.validity_period = 10
+ self.max_delta = self.validity_period * 10. / 100.
+
+ config = self.default_config()
+
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": False,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ self.hs.config.account_validity.period = self.validity_period
+
+ self.store = self.hs.get_datastore()
+
+ return self.hs
+
+ def test_background_job(self):
+ """
+ Tests the same thing as test_background_job, except that it sets the
+ startup_job_max_delta parameter and checks that the expiration date is within the
+ allowed range.
+ """
+ user_id = self.register_user("kermit_delta", "user")
+
+ self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+
+ now_ms = self.hs.clock.time_msec()
+ self.get_success(self.store._set_expiration_date_when_missing())
+
+ res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+
+ self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
+ self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
new file mode 100644
index 0000000000..43b3049daa
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -0,0 +1,564 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import json
+
+import six
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import register, relations
+
+from tests import unittest
+
+
+class RelationsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ relations.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ ]
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ # We need to enable msc1849 support for aggregations
+ config = self.default_config()
+ config["experimental_msc1849_support_enabled"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id, self.user_token = self._create_user("alice")
+ self.user2_id, self.user2_token = self._create_user("bob")
+
+ self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
+ res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
+ self.parent_id = res["event_id"]
+
+ def test_send_relation(self):
+ """Tests that sending a relation using the new /send_relation works
+ creates the right shape of event.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assert_dict(
+ {
+ "type": "m.reaction",
+ "sender": self.user_id,
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": u"👍",
+ "rel_type": RelationTypes.ANNOTATION,
+ }
+ },
+ },
+ channel.json_body,
+ )
+
+ def test_deny_membership(self):
+ """Test that we deny relations on membership events
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_deny_double_react(self):
+ """Test that we deny relations on membership events
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_basic_paginate_relations(self):
+ """Tests that calling pagination API corectly the latest relations.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
+ channel.json_body["chunk"][0],
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), six.string_types, channel.json_body
+ )
+
+ def test_repeated_paginate_relations(self):
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = None
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
+ % (self.room, self.parent_id, from_token),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_aggregation_pagination_groups(self):
+ """Test that we can paginate annotation groups correctly.
+ """
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token = None
+ found_groups = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
+ % (self.room, self.parent_id, from_token),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEquals(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self):
+ """Test that we can paginate within an annotation group.
+ """
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=u"👍",
+ access_token=access_tokens[idx],
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ prev_token = None
+ found_event_ids = []
+ encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8"))
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s"
+ "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
+ % (
+ self.room,
+ self.parent_id,
+ RelationTypes.ANNOTATION,
+ encoded_key,
+ from_token,
+ ),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_aggregation(self):
+ """Test that annotations get correctly aggregated.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body,
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ )
+
+ def test_aggregation_redactions(self):
+ """Test that annotations get correctly aggregated after a redaction.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+ to_redact_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Now lets redact one of the 'a' reactions
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
+ access_token=self.user_token,
+ content={},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body,
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ )
+
+ def test_aggregation_must_be_annotation(self):
+ """Test that aggregations must be annotations.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
+ % (self.room, self.parent_id, RelationTypes.REPLACE),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_aggregation_get_event(self):
+ """Test that annotations and references get correctly bundled when
+ getting the parent event.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply_2 = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
+ },
+ },
+ )
+
+ def test_edit(self):
+ """Test that a simple edit works.
+ """
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ )
+
+ def test_multi_edit(self):
+ """Test that multiple edits, including attempts by people who
+ shouldn't be allowed, are correctly handled.
+ """
+
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message.WRONG_TYPE",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ )
+
+ def _send_relation(
+ self, relation_type, event_type, key=None, content={}, access_token=None
+ ):
+ """Helper function to send a relation pointing at `self.parent_id`
+
+ Args:
+ relation_type (str): One of `RelationTypes`
+ event_type (str): The type of the event to create
+ key (str|None): The aggregation key used for m.annotation relation
+ type.
+ content(dict|None): The content of the created event.
+ access_token (str|None): The access token used to send the relation,
+ defaults to `self.user_token`
+
+ Returns:
+ FakeChannel
+ """
+ if not access_token:
+ access_token = self.user_token
+
+ query = ""
+ if key:
+ query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
+ % (self.room, self.parent_id, relation_type, event_type, query),
+ json.dumps(content).encode("utf-8"),
+ access_token=access_token,
+ )
+ self.render(request)
+ return channel
+
+ def _create_user(self, localpart):
+ user_id = self.register_user(localpart, "abc123")
+ access_token = self.login(localpart, "abc123")
+
+ return user_id, access_token
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 99b716f00a..71895094bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,7 +15,8 @@
from mock import Mock
-from synapse.rest.client.v1 import admin, login, room
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
@@ -72,7 +73,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
class SyncTypingTests(unittest.HomeserverTestCase):
servlets = [
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
sync.register_servlets,
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index af8f74eb42..00688a7325 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -26,20 +26,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
b'inline; filename="aze%20rty"': u"aze%20rty",
b'inline; filename="aze\"rty"': u'aze"rty',
b'inline; filename="azer;ty"': u"azer;ty",
-
b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
}
def tests(self):
for hdr, expected in self.TEST_CASES.items():
- res = get_filename_from_headers(
- {
- b'Content-Disposition': [hdr],
- },
- )
+ res = get_filename_from_headers({b'Content-Disposition': [hdr]})
self.assertEqual(
- res, expected,
- "expected output for %s to be %s but was %s" % (
- hdr, expected, res,
- )
+ res,
+ expected,
+ "expected output for %s to be %s but was %s" % (hdr, expected, res),
)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index ad5e9a612f..1069a44145 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -25,13 +25,11 @@ from six.moves.urllib import parse
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
-from synapse.config.repository import MediaStorageProviderConfig
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.util.logcontext import make_deferred_yieldable
-from synapse.util.module_loader import load_module
from tests import unittest
@@ -120,12 +118,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
client.get_file = get_file
self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
config = self.default_config()
- config.media_store_path = self.storage_path
- config.thumbnail_requirements = {}
- config.max_image_pixels = 2000000
+ config["media_store_path"] = self.media_store_path
+ config["thumbnail_requirements"] = {}
+ config["max_image_pixels"] = 2000000
provider_config = {
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
@@ -134,12 +134,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"store_remote": True,
"config": {"directory": self.storage_path},
}
-
- loaded = list(load_module(provider_config)) + [
- MediaStorageProviderConfig(False, False, False)
- ]
-
- config.media_storage_providers = [loaded]
+ config["media_storage_providers"] = [provider_config]
hs = self.setup_test_homeserver(config=config, http_client=client)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 650ce95a6f..1ab0f7293a 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -16,7 +16,6 @@
import os
import attr
-from netaddr import IPSet
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
@@ -25,9 +24,6 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web._newclient import ResponseDone
-from synapse.config.repository import MediaStorageProviderConfig
-from synapse.util.module_loader import load_module
-
from tests import unittest
from tests.server import FakeTransport
@@ -67,23 +63,23 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- self.storage_path = self.mktemp()
- os.mkdir(self.storage_path)
-
config = self.default_config()
- config.url_preview_enabled = True
- config.max_spider_size = 9999999
- config.url_preview_ip_range_blacklist = IPSet(
- (
- "192.168.1.1",
- "1.0.0.0/8",
- "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
- "2001:800::/21",
- )
+ config["url_preview_enabled"] = True
+ config["max_spider_size"] = 9999999
+ config["url_preview_ip_range_blacklist"] = (
+ "192.168.1.1",
+ "1.0.0.0/8",
+ "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
+ "2001:800::/21",
)
- config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",))
- config.url_preview_url_blacklist = []
- config.media_store_path = self.storage_path
+ config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
+ config["url_preview_url_blacklist"] = []
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
@@ -93,11 +89,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path},
}
- loaded = list(load_module(provider_config)) + [
- MediaStorageProviderConfig(False, False, False)
- ]
-
- config.media_storage_providers = [loaded]
+ config["media_storage_providers"] = [provider_config]
hs = self.setup_test_homeserver(config=config)
@@ -297,12 +289,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
- self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
{
'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ 'error': 'DNS resolution failure during URL preview generation',
},
)
@@ -318,12 +310,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
request.render(self.preview_url)
self.pump()
- self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
{
'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ 'error': 'DNS resolution failure during URL preview generation',
},
)
@@ -339,7 +331,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
- self.assertEqual(channel.code, 403)
self.assertEqual(
channel.json_body,
{
@@ -347,6 +338,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
'error': 'IP address blocked by IP blacklist entry',
},
)
+ self.assertEqual(channel.code, 403)
def test_blacklisted_ip_range_direct(self):
"""
@@ -414,12 +406,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request.render(self.preview_url)
self.pump()
- self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
{
'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ 'error': 'DNS resolution failure during URL preview generation',
},
)
@@ -439,12 +431,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
- self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
{
'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ 'error': 'DNS resolution failure during URL preview generation',
},
)
@@ -460,11 +452,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
request.render(self.preview_url)
self.pump()
- self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
{
'errcode': 'M_UNKNOWN',
- 'error': 'IP address blocked by IP blacklist entry',
+ 'error': 'DNS resolution failure during URL preview generation',
},
)
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 8d8f03e005..b090bb974c 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -31,27 +31,24 @@ class WellKnownTests(unittest.HomeserverTestCase):
self.hs.config.default_identity_server = "https://testis"
request, channel = self.make_request(
- "GET",
- "/.well-known/matrix/client",
- shorthand=False,
+ "GET", "/.well-known/matrix/client", shorthand=False
)
self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(
- channel.json_body, {
+ channel.json_body,
+ {
"m.homeserver": {"base_url": "https://tesths"},
"m.identity_server": {"base_url": "https://testis"},
- }
+ },
)
def test_well_known_no_public_baseurl(self):
self.hs.config.public_baseurl = None
request, channel = self.make_request(
- "GET",
- "/.well-known/matrix/client",
- shorthand=False,
+ "GET", "/.well-known/matrix/client", shorthand=False
)
self.render(request)
diff --git a/tests/server.py b/tests/server.py
index 8f89f4a83d..c15a47f2a4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -182,7 +182,8 @@ def make_request(
if federation_auth_origin is not None:
req.requestHeaders.addRawHeader(
- b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
+ b"Authorization",
+ b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
)
if content:
@@ -226,6 +227,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
def __init__(self):
+ self.threadpool = ThreadPool(self)
+
self._udp = []
lookups = self.lookups = {}
@@ -233,7 +236,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class FakeResolver(object):
def getHostByName(self, name, timeout=None):
if name not in lookups:
- return fail(DNSLookupError("OH NO: unknown %s" % (name, )))
+ return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
@@ -254,6 +257,37 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self.callLater(0, d.callback, True)
return d
+ def getThreadPool(self):
+ return self.threadpool
+
+
+class ThreadPool:
+ """
+ Threadless thread pool.
+ """
+
+ def __init__(self, reactor):
+ self._reactor = reactor
+
+ def start(self):
+ pass
+
+ def stop(self):
+ pass
+
+ def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
+ def _(res):
+ if isinstance(res, Failure):
+ onResult(False, res)
+ else:
+ onResult(True, res)
+
+ d = Deferred()
+ d.addCallback(lambda x: function(*args, **kwargs))
+ d.addBoth(_)
+ self._reactor.callLater(0, d.callback, True)
+ return d
+
def setup_test_homeserver(cleanup_func, *args, **kwargs):
"""
@@ -289,36 +323,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
**kwargs
)
- class ThreadPool:
- """
- Threadless thread pool.
- """
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
- def _(res):
- if isinstance(res, Failure):
- onResult(False, res)
- else:
- onResult(True, res)
-
- d = Deferred()
- d.addCallback(lambda x: function(*args, **kwargs))
- d.addBoth(_)
- clock._reactor.callLater(0, d.callback, True)
- return d
-
- clock.threadpool = ThreadPool()
-
if pool:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
- pool.threadpool = ThreadPool()
+ pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
return d
@@ -454,6 +462,6 @@ class FakeTransport(object):
logger.warning("Exception writing to protocol: %s", e)
return
- self.buffer = self.buffer[len(to_write):]
+ self.buffer = self.buffer[len(to_write) :]
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 95badc985e..872039c8f1 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.client.v1 import admin, login, room
+import os
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
@@ -23,27 +26,34 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
servlets = [
sync.register_servlets,
- admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
+ tmpdir = self.mktemp()
+ os.mkdir(tmpdir)
self.consent_notice_message = "consent %(consent_uri)s"
config = self.default_config()
- config.user_consent_version = "1"
- config.user_consent_server_notice_content = {
- "msgtype": "m.text",
- "body": self.consent_notice_message,
+ config["user_consent"] = {
+ "version": "1",
+ "template_dir": tmpdir,
+ "server_notice_content": {
+ "msgtype": "m.text",
+ "body": self.consent_notice_message,
+ },
+ }
+ config["public_baseurl"] = "https://example.com/"
+ config["form_secret"] = "123abc"
+
+ config["server_notices"] = {
+ "system_mxid_localpart": "notices",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ "room_name": "Server Notices",
}
- config.public_baseurl = "https://example.com/"
- config.form_secret = "123abc"
-
- config.server_notices_mxid = "@notices:test"
- config.server_notices_mxid_display_name = "test display name"
- config.server_notices_mxid_avatar_url = None
- config.server_notices_room_name = "Server Notices"
hs = self.setup_test_homeserver(config=config)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index be73e718c2..739ee59ce4 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,10 +27,14 @@ from tests import unittest
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
-
def make_homeserver(self, reactor, clock):
hs_config = self.default_config("test")
- hs_config.server_notices_mxid = "@server:test"
+ hs_config["server_notices"] = {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ "room_name": "Server Notices",
+ }
hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
return hs
@@ -80,7 +84,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
# Test when mau limiting disabled
self.hs.config.hs_disabled = False
- self.hs.limit_usage_by_mau = False
+ self.hs.config.limit_usage_by_mau = False
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index f448b01326..9c5311d916 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -50,6 +50,7 @@ class FakeEvent(object):
refer to events. The event_id has node_id as localpart and example.com
as domain.
"""
+
def __init__(self, id, sender, type, state_key, content):
self.node_id = id
self.event_id = EventID(id, "example.com").to_string()
@@ -142,24 +143,14 @@ INITIAL_EVENTS = [
content=MEMBERSHIP_CONTENT_JOIN,
),
FakeEvent(
- id="START",
- sender=ZARA,
- type=EventTypes.Message,
- state_key=None,
- content={},
+ id="START", sender=ZARA, type=EventTypes.Message, state_key=None, content={}
),
FakeEvent(
- id="END",
- sender=ZARA,
- type=EventTypes.Message,
- state_key=None,
- content={},
+ id="END", sender=ZARA, type=EventTypes.Message, state_key=None, content={}
),
]
-INITIAL_EDGES = [
- "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
-]
+INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
class StateTestCase(unittest.TestCase):
@@ -170,12 +161,7 @@ class StateTestCase(unittest.TestCase):
sender=ALICE,
type=EventTypes.PowerLevels,
state_key="",
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- }
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
id="MA",
@@ -196,19 +182,11 @@ class StateTestCase(unittest.TestCase):
sender=BOB,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
]
- edges = [
- ["END", "MB", "MA", "PA", "START"],
- ["END", "PB", "PA"],
- ]
+ edges = [["END", "MB", "MA", "PA", "START"], ["END", "PB", "PA"]]
expected_state_ids = ["PA", "MA", "MB"]
@@ -232,10 +210,7 @@ class StateTestCase(unittest.TestCase):
),
]
- edges = [
- ["END", "JR", "START"],
- ["END", "ME", "START"],
- ]
+ edges = [["END", "JR", "START"], ["END", "ME", "START"]]
expected_state_ids = ["JR"]
@@ -248,45 +223,25 @@ class StateTestCase(unittest.TestCase):
sender=ALICE,
type=EventTypes.PowerLevels,
state_key="",
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- }
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
id="PB",
sender=BOB,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- CHARLIE: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}},
),
FakeEvent(
id="PC",
sender=CHARLIE,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- CHARLIE: 0,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}},
),
]
- edges = [
- ["END", "PC", "PB", "PA", "START"],
- ["END", "PA"],
- ]
+ edges = [["END", "PC", "PB", "PA", "START"], ["END", "PA"]]
expected_state_ids = ["PC"]
@@ -295,68 +250,38 @@ class StateTestCase(unittest.TestCase):
def test_topic_basic(self):
events = [
FakeEvent(
- id="T1",
- sender=ALICE,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="PA1",
sender=ALICE,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
- id="T2",
- sender=ALICE,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="PA2",
sender=ALICE,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 0,
- },
- },
+ content={"users": {ALICE: 100, BOB: 0}},
),
FakeEvent(
id="PB",
sender=BOB,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
- id="T3",
- sender=BOB,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
),
]
- edges = [
- ["END", "PA2", "T2", "PA1", "T1", "START"],
- ["END", "T3", "PB", "PA1"],
- ]
+ edges = [["END", "PA2", "T2", "PA1", "T1", "START"], ["END", "T3", "PB", "PA1"]]
expected_state_ids = ["PA2", "T2"]
@@ -365,30 +290,17 @@ class StateTestCase(unittest.TestCase):
def test_topic_reset(self):
events = [
FakeEvent(
- id="T1",
- sender=ALICE,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="PA",
sender=ALICE,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
- id="T2",
- sender=BOB,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T2", sender=BOB, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="MB",
@@ -399,10 +311,7 @@ class StateTestCase(unittest.TestCase):
),
]
- edges = [
- ["END", "MB", "T2", "PA", "T1", "START"],
- ["END", "T1"],
- ]
+ edges = [["END", "MB", "T2", "PA", "T1", "START"], ["END", "T1"]]
expected_state_ids = ["T1", "MB", "PA"]
@@ -411,61 +320,34 @@ class StateTestCase(unittest.TestCase):
def test_topic(self):
events = [
FakeEvent(
- id="T1",
- sender=ALICE,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="PA1",
sender=ALICE,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
- id="T2",
- sender=ALICE,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="PA2",
sender=ALICE,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 0,
- },
- },
+ content={"users": {ALICE: 100, BOB: 0}},
),
FakeEvent(
id="PB",
sender=BOB,
type=EventTypes.PowerLevels,
state_key='',
- content={
- "users": {
- ALICE: 100,
- BOB: 50,
- },
- },
+ content={"users": {ALICE: 100, BOB: 50}},
),
FakeEvent(
- id="T3",
- sender=BOB,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
),
FakeEvent(
id="MZ1",
@@ -475,11 +357,7 @@ class StateTestCase(unittest.TestCase):
content={},
),
FakeEvent(
- id="T4",
- sender=ALICE,
- type=EventTypes.Topic,
- state_key="",
- content={},
+ id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
),
]
@@ -587,13 +465,7 @@ class StateTestCase(unittest.TestCase):
class LexicographicalTestCase(unittest.TestCase):
def test_simple(self):
- graph = {
- "l": {"o"},
- "m": {"n", "o"},
- "n": {"o"},
- "o": set(),
- "p": {"o"},
- }
+ graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
@@ -680,7 +552,13 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.expected_combined_state = {
(e.type, e.state_key): e.event_id
- for e in [create_event, alice_member, join_rules, bob_member, charlie_member]
+ for e in [
+ create_event,
+ alice_member,
+ join_rules,
+ bob_member,
+ charlie_member,
+ ]
}
def test_event_map_none(self):
@@ -720,11 +598,7 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {
- eid: self.event_map[eid]
- for eid in event_ids
- if eid in self.event_map
- }
+ return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
def get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 3f0083831b..25a6c89ef5 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -340,7 +340,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store._get_events = Mock(return_value=events)
+ self.store.get_events_as_list = Mock(return_value=events)
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 5568a607c7..fbb9302694 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -9,9 +9,7 @@ from tests.utils import setup_test_homeserver
class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver(
- self.addCleanup
- )
+ hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index f18db8c384..c778de1f0c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -56,10 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
hs = TestHomeServer(
- "test",
- db_pool=self.db_pool,
- config=config,
- database_engine=fake_engine,
+ "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
)
self.datastore = SQLBaseStore(None, hs)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
new file mode 100644
index 0000000000..f4c81ef77d
--- /dev/null
+++ b/tests/storage/test_cleanup_extrems.py
@@ -0,0 +1,224 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+
+from synapse.storage import prepare_database
+from synapse.types import Requester, UserID
+
+from tests.unittest import HomeserverTestCase
+
+
+class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
+ """
+ Test the background update to clean forward extremities table.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+ self.room_creator = homeserver.get_room_creation_handler()
+
+ # Create a test user and room
+ self.user = UserID("alice", "test")
+ self.requester = Requester(self.user, None, False, None, None)
+ info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ self.room_id = info["room_id"]
+
+ def run_background_update(self):
+ """Re run the background update to clean up the extremities.
+ """
+ # Make sure we don't clash with in progress updates.
+ self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+
+ schema_path = os.path.join(
+ prepare_database.dir_path,
+ "schema",
+ "delta",
+ "54",
+ "delete_forward_extremities.sql",
+ )
+
+ def run_delta_file(txn):
+ prepare_database.executescript(txn, schema_path)
+
+ self.get_success(
+ self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+ )
+
+ # Ugh, have to reset this flag
+ self.store._all_done = False
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ def test_soft_failed_extremities_handled_correctly(self):
+ """Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like:
+
+ A <- SF1 <- SF2 <- B
+
+ Where SF* are soft failed.
+ """
+
+ # Create the room graph
+ event_id_1 = self.create_and_send_event(self.room_id, self.user)
+ event_id_2 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_1]
+ )
+ event_id_3 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_2]
+ )
+ event_id_4 = self.create_and_send_event(
+ self.room_id, self.user, False, [event_id_3]
+ )
+
+ # Check the latest events are as expected
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+
+ self.assertEqual(latest_event_ids, [event_id_4])
+
+ def test_basic_cleanup(self):
+ """Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like:
+
+ A <- SF1 <- B
+
+ Where SF* are soft failed, and with extremities of A and B
+ """
+ # Create the room graph
+ event_id_a = self.create_and_send_event(self.room_id, self.user)
+ event_id_sf1 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_a]
+ )
+ event_id_b = self.create_and_send_event(
+ self.room_id, self.user, False, [event_id_sf1]
+ )
+
+ # Add the new extremity and check the latest events are as expected
+ self.add_extremity(self.room_id, event_id_a)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+
+ # Run the background update and check it did the right thing
+ self.run_background_update()
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(latest_event_ids, [event_id_b])
+
+ def test_chain_of_fail_cleanup(self):
+ """Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like:
+
+ A <- SF1 <- SF2 <- B
+
+ Where SF* are soft failed, and with extremities of A and B
+ """
+ # Create the room graph
+ event_id_a = self.create_and_send_event(self.room_id, self.user)
+ event_id_sf1 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_a]
+ )
+ event_id_sf2 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_sf1]
+ )
+ event_id_b = self.create_and_send_event(
+ self.room_id, self.user, False, [event_id_sf2]
+ )
+
+ # Add the new extremity and check the latest events are as expected
+ self.add_extremity(self.room_id, event_id_a)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+
+ # Run the background update and check it did the right thing
+ self.run_background_update()
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(latest_event_ids, [event_id_b])
+
+ def test_forked_graph_cleanup(self):
+ r"""Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like, where time flows down the page:
+
+ A B
+ / \ /
+ / \ /
+ SF1 SF2
+ | |
+ SF3 |
+ / \ |
+ | \ |
+ C SF4
+
+ Where SF* are soft failed, and with them A, B and C marked as
+ extremities. This should resolve to B and C being marked as extremity.
+ """
+
+ # Create the room graph
+ event_id_a = self.create_and_send_event(self.room_id, self.user)
+ event_id_b = self.create_and_send_event(self.room_id, self.user)
+ event_id_sf1 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_a]
+ )
+ event_id_sf2 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_a, event_id_b]
+ )
+ event_id_sf3 = self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_sf1]
+ )
+ self.create_and_send_event(
+ self.room_id, self.user, True, [event_id_sf2, event_id_sf3]
+ ) # SF4
+ event_id_c = self.create_and_send_event(
+ self.room_id, self.user, False, [event_id_sf3]
+ )
+
+ # Add the new extremity and check the latest events are as expected
+ self.add_extremity(self.room_id, event_id_a)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(
+ set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
+ )
+
+ # Run the background update and check it did the right thing
+ self.run_background_update()
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 858efe4992..b62eae7abc 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -18,8 +18,9 @@ from mock import Mock
from twisted.internet import defer
+import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
-from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v1 import login
from tests import unittest
@@ -205,7 +206,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
- servlets = [admin.register_servlets, login.register_servlets]
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index aef4dfaf57..6396ccddb5 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,6 +72,75 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_get_devices_by_remote(self):
+ device_ids = ["device_id1", "device_id2"]
+
+ # Add two device updates with a single stream_id
+ yield self.store.add_device_change_to_streams(
+ "user_id", device_ids, ["somehost"],
+ )
+
+ # Get all device updates ever meant for this remote
+ now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ "somehost", -1, limit=100,
+ )
+
+ # Check original device_ids are contained within these updates
+ self._check_devices_in_updates(device_ids, device_updates)
+
+ @defer.inlineCallbacks
+ def test_get_devices_by_remote_limited(self):
+ # Test breaking the update limit in 1, 101, and 1 device_id segments
+
+ # first add one device
+ device_ids1 = ["device_id0"]
+ yield self.store.add_device_change_to_streams(
+ "user_id", device_ids1, ["someotherhost"],
+ )
+
+ # then add 101
+ device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
+ yield self.store.add_device_change_to_streams(
+ "user_id", device_ids2, ["someotherhost"],
+ )
+
+ # then one more
+ device_ids3 = ["newdevice"]
+ yield self.store.add_device_change_to_streams(
+ "user_id", device_ids3, ["someotherhost"],
+ )
+
+ #
+ # now read them back.
+ #
+
+ # first we should get a single update
+ now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ "someotherhost", -1, limit=100,
+ )
+ self._check_devices_in_updates(device_ids1, device_updates)
+
+ # Then we should get an empty list back as the 101 devices broke the limit
+ now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ "someotherhost", now_stream_id, limit=100,
+ )
+ self.assertEqual(len(device_updates), 0)
+
+ # The 101 devices should've been cleared, so we should now just get one device
+ # update
+ now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ "someotherhost", now_stream_id, limit=100,
+ )
+ self._check_devices_in_updates(device_ids3, device_updates)
+
+ def _check_devices_in_updates(self, expected_device_ids, device_updates):
+ """Check that an specific device ids exist in a list of device update EDUs"""
+ self.assertEqual(len(device_updates), len(expected_device_ids))
+
+ received_device_ids = {update["device_id"] for update in device_updates}
+ self.assertEqual(received_device_ids, set(expected_device_ids))
+
+ @defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device("user_id", "device_id", "display_name 1")
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 11fb8c0c19..cd2bcd4ca3 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -20,7 +20,6 @@ import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
new file mode 100644
index 0000000000..19f9ccf5e0
--- /dev/null
+++ b/tests/storage/test_event_metrics.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from prometheus_client.exposition import generate_latest
+
+from synapse.metrics import REGISTRY
+from synapse.types import Requester, UserID
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremStatisticsTestCase(HomeserverTestCase):
+ def test_exposed_to_prometheus(self):
+ """
+ Forward extremity counts are exposed via Prometheus.
+ """
+ room_creator = self.hs.get_room_creation_handler()
+
+ user = UserID("alice", "test")
+ requester = Requester(user, None, False, None, None)
+
+ # Real events, forward extremities
+ events = [(3, 2), (6, 2), (4, 6)]
+
+ for event_count, extrems in events:
+ info = self.get_success(room_creator.create_room(requester, {}))
+ room_id = info["room_id"]
+
+ last_event = None
+
+ # Make a real event chain
+ for i in range(event_count):
+ ev = self.create_and_send_event(room_id, user, False, last_event)
+ last_event = [ev]
+
+ # Sprinkle in some extremities
+ for i in range(extrems):
+ ev = self.create_and_send_event(room_id, user, False, last_event)
+
+ # Let it run for a while, then pull out the statistics from the
+ # Prometheus client registry
+ self.reactor.advance(60 * 60 * 1000)
+ self.pump(1)
+
+ items = set(
+ filter(
+ lambda x: b"synapse_forward_extremities_" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ )
+
+ expected = set([
+ b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
+ b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
+ b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
+ b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
+ b'synapse_forward_extremities_count 3.0',
+ b'synapse_forward_extremities_sum 10.0',
+ ])
+
+ self.assertEqual(items, expected)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 6bfaa00fe9..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
from twisted.internet.defer import Deferred
+from synapse.storage.keys import FetchKeyResult
+
import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()
- d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
- self.get_success(d)
- d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+ key_id_1 = "ed25519:key1"
+ key_id_2 = "ed25519:KEY_ID_2"
+ d = store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
self.get_success(d)
d = store.get_server_verify_keys(
- [
- ("server1", "ed25519:key1"),
- ("server1", "ed25519:key2"),
- ("server1", "ed25519:key3"),
- ]
+ [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
)
res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
- self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
- self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+ res1 = res[("server1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.verify_key.version, "key1")
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("server1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ # version comes from the ID it was stored with
+ self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+ self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
- self.get_success(d)
- d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+ d = store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+ self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
- d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+ d = store.store_server_verify_keys(
+ "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
+ )
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, new_key_2)
+ self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index d6569a82bb..f458c03054 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -56,8 +56,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.register(user_id=user1, token="123", password_hash=None)
self.store.register(user_id=user2, token="456", password_hash=None)
self.store.register(
- user_id=user3, token="789",
- password_hash=None, user_type=UserTypes.SUPPORT
+ user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT
)
self.pump()
@@ -173,9 +172,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock()
- self.store.is_trial_user = Mock(
- return_value=defer.succeed(False)
- )
+ self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
@@ -187,13 +184,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock()
- self.store.is_trial_user = Mock(
- return_value=defer.succeed(False)
- )
+ self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(
- self.hs.get_clock().time_msec()
- )
+ return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.store.populate_monthly_active_users('user_id')
self.pump()
@@ -243,7 +236,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
user_id=support_user_id,
token="123",
password_hash=None,
- user_type=UserTypes.SUPPORT
+ user_type=UserTypes.SUPPORT,
)
self.store.upsert_monthly_active_user(support_user_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 0fc5019e9f..4823d44dec 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -60,7 +60,7 @@ class RedactionTestCase(unittest.TestCase):
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": content,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
@@ -83,7 +83,7 @@ class RedactionTestCase(unittest.TestCase):
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
@@ -105,7 +105,7 @@ class RedactionTestCase(unittest.TestCase):
"room_id": room.to_string(),
"content": {"reason": reason},
"redacts": event_id,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index cb3cc4d2e5..c0e0155bb4 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -116,7 +116,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user_id=SUPPORT_USER,
token="456",
password_hash=None,
- user_type=UserTypes.SUPPORT
+ user_type=UserTypes.SUPPORT,
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 063387863e..73ed943f5a 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -58,7 +58,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 78e260a7fa..b6169436de 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(tests.unittest.TestCase):
-
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
@@ -57,7 +56,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
"state_key": state_key,
"room_id": room.to_string(),
"content": content,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
@@ -83,15 +82,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
+ state_group_map = yield self.store.get_state_groups_ids(
+ self.room, [e2.event_id]
+ )
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
self.assertDictEqual(
state_map,
- {
- (EventTypes.Create, ''): e1.event_id,
- (EventTypes.Name, ''): e2.event_id,
- },
+ {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id},
)
@defer.inlineCallbacks
@@ -103,15 +101,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups(
- self.room, [e2.event_id])
+ state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
- self.assertEqual(
- {ev.event_id for ev in state_list},
- {e1.event_id, e2.event_id},
- )
+ self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
@defer.inlineCallbacks
def test_get_state_for_event(self):
@@ -147,9 +141,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.store.get_state_for_event(
- e5.event_id,
- )
+ state = yield self.store.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4)
@@ -194,7 +186,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
- )
+ ),
)
self.assertStateMapEqual(
@@ -208,9 +200,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check that we can grab everything except members
state = yield self.store.get_state_for_event(
- e5.event_id, state_filter=StateFilter(
- types={EventTypes.Member: set()},
- include_others=True,
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()}, include_others=True
),
)
@@ -229,10 +221,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache, group,
+ self.store._state_group_cache,
+ group,
state_filter=StateFilter(
- types={EventTypes.Member: set()},
- include_others=True,
+ types={EventTypes.Member: set()}, include_others=True
),
)
@@ -249,8 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()},
- include_others=True,
+ types={EventTypes.Member: set()}, include_others=True
),
)
@@ -263,8 +254,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None},
- include_others=True,
+ types={EventTypes.Member: None}, include_others=True
),
)
@@ -281,8 +271,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None},
- include_others=True,
+ types={EventTypes.Member: None}, include_others=True
),
)
@@ -302,8 +291,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=True,
+ types={EventTypes.Member: {e5.state_key}}, include_others=True
),
)
@@ -320,8 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=True,
+ types={EventTypes.Member: {e5.state_key}}, include_others=True
),
)
@@ -334,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=False,
+ types={EventTypes.Member: {e5.state_key}}, include_others=False
),
)
@@ -384,10 +370,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
# with types=[]
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache, group,
+ self.store._state_group_cache,
+ group,
state_filter=StateFilter(
- types={EventTypes.Member: set()},
- include_others=True,
+ types={EventTypes.Member: set()}, include_others=True
),
)
@@ -399,8 +385,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()},
- include_others=True,
+ types={EventTypes.Member: set()}, include_others=True
),
)
@@ -413,8 +398,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None},
- include_others=True,
+ types={EventTypes.Member: None}, include_others=True
),
)
@@ -425,8 +409,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None},
- include_others=True,
+ types={EventTypes.Member: None}, include_others=True
),
)
@@ -445,8 +428,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=True,
+ types={EventTypes.Member: {e5.state_key}}, include_others=True
),
)
@@ -457,8 +439,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=True,
+ types={EventTypes.Member: {e5.state_key}}, include_others=True
),
)
@@ -471,8 +452,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=False,
+ types={EventTypes.Member: {e5.state_key}}, include_others=False
),
)
@@ -483,8 +463,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}},
- include_others=False,
+ types={EventTypes.Member: {e5.state_key}}, include_others=False
),
)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index fd3361404f..d7d244ce97 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -36,9 +36,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
yield self.store.update_profile_in_user_dir(BOB, "bob", None)
yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- yield self.store.add_users_in_public_rooms(
- "!room:id", (ALICE, BOB)
- )
+ yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
@defer.inlineCallbacks
def test_search_user_dir(self):
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 4c8f87e958..8b2741d277 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -37,7 +37,9 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send state
event_auth.check(
- RoomVersions.V1.identifier, _random_state_event(creator), auth_events,
+ RoomVersions.V1.identifier,
+ _random_state_event(creator),
+ auth_events,
do_sig_check=False,
)
@@ -82,7 +84,9 @@ class EventAuthTestCase(unittest.TestCase):
# king should be able to send state
event_auth.check(
- RoomVersions.V1.identifier, _random_state_event(king), auth_events,
+ RoomVersions.V1.identifier,
+ _random_state_event(king),
+ auth_events,
do_sig_check=False,
)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 1a5dc32c88..6a8339b561 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,4 +1,3 @@
-
from mock import Mock
from twisted.internet.defer import maybeDeferred, succeed
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 00be1a8c21..1fbe0d51ff 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -33,9 +33,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
+ "red", http_client=None, federation_client=Mock()
)
self.store = self.hs.get_datastore()
@@ -210,9 +208,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
def do_sync_for_user(self, token):
- request, channel = self.make_request(
- "GET", "/sync", access_token=token
- )
+ request, channel = self.make_request("GET", "/sync", access_token=token)
self.render(request)
if channel.code != 200:
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 0ff6d0e283..2edbae5c6d 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -44,9 +44,7 @@ def get_sample_labels_value(sample):
class TestMauLimit(unittest.TestCase):
def test_basic(self):
gauge = InFlightGauge(
- "test1", "",
- labels=["test_label"],
- sub_metrics=["foo", "bar"],
+ "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"]
)
def handle1(metrics):
@@ -59,37 +57,49 @@ class TestMauLimit(unittest.TestCase):
gauge.register(("key1",), handle1)
- self.assert_dict({
- "test1_total": {("key1",): 1},
- "test1_foo": {("key1",): 2},
- "test1_bar": {("key1",): 5},
- }, self.get_metrics_from_gauge(gauge))
+ self.assert_dict(
+ {
+ "test1_total": {("key1",): 1},
+ "test1_foo": {("key1",): 2},
+ "test1_bar": {("key1",): 5},
+ },
+ self.get_metrics_from_gauge(gauge),
+ )
gauge.unregister(("key1",), handle1)
- self.assert_dict({
- "test1_total": {("key1",): 0},
- "test1_foo": {("key1",): 0},
- "test1_bar": {("key1",): 0},
- }, self.get_metrics_from_gauge(gauge))
+ self.assert_dict(
+ {
+ "test1_total": {("key1",): 0},
+ "test1_foo": {("key1",): 0},
+ "test1_bar": {("key1",): 0},
+ },
+ self.get_metrics_from_gauge(gauge),
+ )
gauge.register(("key1",), handle1)
gauge.register(("key2",), handle2)
- self.assert_dict({
- "test1_total": {("key1",): 1, ("key2",): 1},
- "test1_foo": {("key1",): 2, ("key2",): 3},
- "test1_bar": {("key1",): 5, ("key2",): 7},
- }, self.get_metrics_from_gauge(gauge))
+ self.assert_dict(
+ {
+ "test1_total": {("key1",): 1, ("key2",): 1},
+ "test1_foo": {("key1",): 2, ("key2",): 3},
+ "test1_bar": {("key1",): 5, ("key2",): 7},
+ },
+ self.get_metrics_from_gauge(gauge),
+ )
gauge.unregister(("key2",), handle2)
gauge.register(("key1",), handle2)
- self.assert_dict({
- "test1_total": {("key1",): 2, ("key2",): 0},
- "test1_foo": {("key1",): 5, ("key2",): 0},
- "test1_bar": {("key1",): 7, ("key2",): 0},
- }, self.get_metrics_from_gauge(gauge))
+ self.assert_dict(
+ {
+ "test1_total": {("key1",): 2, ("key2",): 0},
+ "test1_foo": {("key1",): 5, ("key2",): 0},
+ "test1_bar": {("key1",): 7, ("key2",): 0},
+ },
+ self.get_metrics_from_gauge(gauge),
+ )
def get_metrics_from_gauge(self, gauge):
results = {}
diff --git a/tests/test_state.py b/tests/test_state.py
index 5bcc6aaa18..6491a7105a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -168,7 +168,7 @@ class StateTestCase(unittest.TestCase):
"get_state_resolution_handler",
]
)
- hs.config = default_config("tesths")
+ hs.config = default_config("tesths", True)
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 0968e86a7b..52739fbabc 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -59,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
for flow in channel.json_body["flows"]:
self.assertIsInstance(flow["stages"], list)
self.assertTrue(len(flow["stages"]) > 0)
- self.assertEquals(flow["stages"][-1], "m.login.terms")
+ self.assertTrue("m.login.terms" in flow["stages"])
expected_params = {
"m.login.terms": {
@@ -69,10 +69,10 @@ class TermsTestCase(unittest.HomeserverTestCase):
"name": "My Cool Privacy Policy",
"url": "https://example.org/_matrix/consent?v=1.0",
},
- "version": "1.0"
- },
- },
- },
+ "version": "1.0",
+ }
+ }
+ }
}
self.assertIsInstance(channel.json_body["params"], dict)
self.assertDictContainsSubset(channel.json_body["params"], expected_params)
diff --git a/tests/test_types.py b/tests/test_types.py
index d314a7ff58..d83c36559f 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -94,8 +94,7 @@ class MapUsernameTestCase(unittest.TestCase):
def testSymbols(self):
self.assertEqual(
- map_username_to_mxid_localpart("test=$?_1234"),
- "test=3d=24=3f_1234",
+ map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
def testLeadingUnderscore(self):
@@ -105,6 +104,5 @@ class MapUsernameTestCase(unittest.TestCase):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
self.assertEqual(
- map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
- "t=c3=aast",
+ map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast"
)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index d0bc8e2112..fde0baee8e 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -22,6 +22,7 @@ from synapse.util.logcontext import LoggingContextFilter
class ToTwistedHandler(logging.Handler):
"""logging handler which sends the logs to the twisted log"""
+
tx_log = twisted.logger.Logger()
def emit(self, record):
@@ -41,7 +42,8 @@ def setup_logging():
root_logger = logging.getLogger()
log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
+ "%(asctime)s - %(name)s - %(lineno)d - "
+ "%(levelname)s - %(request)s - %(message)s"
)
handler = ToTwistedHandler()
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 3bdb500514..6a180ddc32 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -132,7 +132,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
"state_key": "",
"room_id": TEST_ROOM_ID,
"content": content,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
@@ -153,7 +153,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
"state_key": user_id,
"room_id": TEST_ROOM_ID,
"content": content,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
@@ -174,7 +174,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
"sender": user_id,
"room_id": TEST_ROOM_ID,
"content": content,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
diff --git a/tests/unittest.py b/tests/unittest.py
index 8c65736a51..b6dc7932ce 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -27,10 +27,12 @@ import twisted.logger
from twisted.internet.defer import Deferred
from twisted.trial import unittest
+from synapse.api.constants import EventTypes
+from synapse.config.homeserver import HomeServerConfig
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
from synapse.util.logcontext import LoggingContext
from tests.server import get_clock, make_request, render, setup_test_homeserver
@@ -84,9 +86,8 @@ class TestCase(unittest.TestCase):
# all future bets are off.
if LoggingContext.current_context() is not LoggingContext.sentinel:
self.fail(
- "Test starting with non-sentinel logging context %s" % (
- LoggingContext.current_context(),
- )
+ "Test starting with non-sentinel logging context %s"
+ % (LoggingContext.current_context(),)
)
old_level = logging.getLogger().level
@@ -181,10 +182,7 @@ class HomeserverTestCase(TestCase):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
# Register the resources
- self.resource = JsonResource(self.hs)
-
- for servlet in self.servlets:
- servlet(self.hs, self.resource)
+ self.resource = self.create_test_json_resource()
from tests.rest.client.v1.utils import RestHelper
@@ -230,9 +228,26 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver()
return hs
+ def create_test_json_resource(self):
+ """
+ Create a test JsonResource, with the relevant servlets registerd to it
+
+ The default implementation calls each function in `servlets` to do the
+ registration.
+
+ Returns:
+ JsonResource:
+ """
+ resource = JsonResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, resource)
+
+ return resource
+
def default_config(self, name="test"):
"""
- Get a default HomeServer config object.
+ Get a default HomeServer config dict.
Args:
name (str): The homeserver name/domain.
@@ -286,7 +301,13 @@ class HomeserverTestCase(TestCase):
content = json.dumps(content).encode('utf8')
return make_request(
- self.reactor, method, path, content, access_token, request, shorthand,
+ self.reactor,
+ method,
+ path,
+ content,
+ access_token,
+ request,
+ shorthand,
federation_auth_origin,
)
@@ -316,7 +337,14 @@ class HomeserverTestCase(TestCase):
kwargs.update(self._hs_args)
if "config" not in kwargs:
config = self.default_config()
- kwargs["config"] = config
+ else:
+ config = kwargs["config"]
+
+ # Parse the config from a config dict into a HomeServerConfig
+ config_obj = HomeServerConfig()
+ config_obj.parse_config_dict(config)
+ kwargs["config"] = config_obj
+
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
@@ -414,3 +442,73 @@ class HomeserverTestCase(TestCase):
access_token = channel.json_body["access_token"]
return access_token
+
+ def create_and_send_event(
+ self, room_id, user, soft_failed=False, prev_event_ids=None
+ ):
+ """
+ Create and send an event.
+
+ Args:
+ soft_failed (bool): Whether to create a soft failed event or not
+ prev_event_ids (list[str]|None): Explicitly set the prev events,
+ or if None just use the default
+
+ Returns:
+ str: The new event's ID.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+ secrets = self.hs.get_secrets()
+ requester = Requester(user, None, False, None, None)
+
+ prev_events_and_hashes = None
+ if prev_event_ids:
+ prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
+
+ event, context = self.get_success(
+ event_creator.create_event(
+ requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": room_id,
+ "sender": user.to_string(),
+ "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
+ },
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
+ )
+
+ if soft_failed:
+ event.internal_metadata.soft_failed = True
+
+ self.get_success(
+ event_creator.send_nonmember_event(requester, event, context)
+ )
+
+ return event.event_id
+
+ def add_extremity(self, room_id, event_id):
+ """
+ Add the given event as an extremity to the room.
+ """
+ self.get_success(
+ self.hs.get_datastore()._simple_insert(
+ table="event_forward_extremities",
+ values={"room_id": room_id, "event_id": event_id},
+ desc="test_add_extremity",
+ )
+ )
+
+ self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,))
+
+ def attempt_wrong_password_login(self, username, password):
+ """Attempts to login as the user with the given password, asserting
+ that the attempt *fails*.
+ """
+ body = {"type": "m.login.password", "user": username, "password": password}
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result)
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
index 84dd71e47a..bf85d3b8ec 100644
--- a/tests/util/test_async_utils.py
+++ b/tests/util/test_async_utils.py
@@ -42,10 +42,10 @@ class TimeoutDeferredTest(TestCase):
self.assertNoResult(timing_out_d)
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
- self.clock.pump((1.0, ))
+ self.clock.pump((1.0,))
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
- self.failureResultOf(timing_out_d, defer.TimeoutError, )
+ self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_times_out_when_canceller_throws(self):
"""Test that we have successfully worked around
@@ -59,9 +59,9 @@ class TimeoutDeferredTest(TestCase):
self.assertNoResult(timing_out_d)
- self.clock.pump((1.0, ))
+ self.clock.pump((1.0,))
- self.failureResultOf(timing_out_d, defer.TimeoutError, )
+ self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_logcontext_is_preserved_on_cancellation(self):
blocking_was_cancelled = [False]
@@ -80,10 +80,10 @@ class TimeoutDeferredTest(TestCase):
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
self.assertIs(
- LoggingContext.current_context(), context_one,
- "errback %s run in unexpected logcontext %s" % (
- deferred_name, LoggingContext.current_context(),
- )
+ LoggingContext.current_context(),
+ context_one,
+ "errback %s run in unexpected logcontext %s"
+ % (deferred_name, LoggingContext.current_context()),
)
return res
@@ -94,11 +94,10 @@ class TimeoutDeferredTest(TestCase):
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
timing_out_d.addErrback(errback, "timingout")
- self.clock.pump((1.0, ))
+ self.clock.pump((1.0,))
self.assertTrue(
- blocking_was_cancelled[0],
- "non-completing deferred was not cancelled",
+ blocking_was_cancelled[0], "non-completing deferred was not cancelled"
)
- self.failureResultOf(timing_out_d, defer.TimeoutError, )
+ self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(LoggingContext.current_context(), context_one)
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 61a55b461b..ec7ba9719c 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd.
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/utils.py b/tests/utils.py
index cb75514851..f8c7ad2604 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -31,6 +31,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
@@ -68,7 +69,9 @@ def setupdb():
# connect to postgres to create the base database.
db_conn = db_engine.module.connect(
- user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
@@ -94,7 +97,9 @@ def setupdb():
def _cleanup():
db_conn = db_engine.module.connect(
- user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
@@ -106,7 +111,7 @@ def setupdb():
atexit.register(_cleanup)
-def default_config(name):
+def default_config(name, parse=False):
"""
Create a reasonable test config.
"""
@@ -114,79 +119,73 @@ def default_config(name):
"server_name": name,
"media_store_path": "media",
"uploads_path": "uploads",
-
# the test signing key is just an arbitrary ed25519 key to keep the config
# parser happy
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
+ "event_cache_size": 1,
+ "enable_registration": True,
+ "enable_registration_captcha": False,
+ "macaroon_secret_key": "not even a little secret",
+ "expire_access_token": False,
+ "trusted_third_party_id_servers": [],
+ "room_invite_state_types": [],
+ "password_providers": [],
+ "worker_replication_url": "",
+ "worker_app": None,
+ "block_non_admin_invites": False,
+ "federation_domain_whitelist": None,
+ "filter_timeline_limit": 5000,
+ "user_directory_search_all_users": False,
+ "user_consent_server_notice_content": None,
+ "block_events_without_consent_error": None,
+ "user_consent_at_registration": False,
+ "user_consent_policy_name": "Privacy Policy",
+ "media_storage_providers": [],
+ "autocreate_auto_join_rooms": True,
+ "auto_join_rooms": [],
+ "limit_usage_by_mau": False,
+ "hs_disabled": False,
+ "hs_disabled_message": "",
+ "hs_disabled_limit_type": "",
+ "max_mau_value": 50,
+ "mau_trial_days": 0,
+ "mau_stats_only": False,
+ "mau_limits_reserved_threepids": [],
+ "admin_contact": None,
+ "rc_federation": {
+ "reject_limit": 10,
+ "sleep_limit": 10,
+ "sleep_delay": 10,
+ "concurrent": 10,
+ },
+ "rc_message": {"per_second": 10000, "burst_count": 10000},
+ "rc_registration": {"per_second": 10000, "burst_count": 10000},
+ "rc_login": {
+ "address": {"per_second": 10000, "burst_count": 10000},
+ "account": {"per_second": 10000, "burst_count": 10000},
+ "failed_attempts": {"per_second": 10000, "burst_count": 10000},
+ },
+ "saml2_enabled": False,
+ "public_baseurl": None,
+ "default_identity_server": None,
+ "key_refresh_interval": 24 * 60 * 60 * 1000,
+ "old_signing_keys": {},
+ "tls_fingerprints": [],
+ "use_frozen_dicts": False,
+ # We need a sane default_room_version, otherwise attempts to create
+ # rooms will fail.
+ "default_room_version": DEFAULT_ROOM_VERSION,
+ # disable user directory updates, because they get done in the
+ # background, which upsets the test runner.
+ "update_user_directory": False,
}
- config = HomeServerConfig()
- config.parse_config_dict(config_dict)
-
- # TODO: move this stuff into config_dict or get rid of it
- config.event_cache_size = 1
- config.enable_registration = True
- config.enable_registration_captcha = False
- config.macaroon_secret_key = "not even a little secret"
- config.expire_access_token = False
- config.trusted_third_party_id_servers = []
- config.room_invite_state_types = []
- config.password_providers = []
- config.worker_replication_url = ""
- config.worker_app = None
- config.email_enable_notifs = False
- config.block_non_admin_invites = False
- config.federation_domain_whitelist = None
- config.federation_rc_reject_limit = 10
- config.federation_rc_sleep_limit = 10
- config.federation_rc_sleep_delay = 100
- config.federation_rc_concurrent = 10
- config.filter_timeline_limit = 5000
- config.user_directory_search_all_users = False
- config.user_consent_server_notice_content = None
- config.block_events_without_consent_error = None
- config.user_consent_at_registration = False
- config.user_consent_policy_name = "Privacy Policy"
- config.media_storage_providers = []
- config.autocreate_auto_join_rooms = True
- config.auto_join_rooms = []
- config.limit_usage_by_mau = False
- config.hs_disabled = False
- config.hs_disabled_message = ""
- config.hs_disabled_limit_type = ""
- config.max_mau_value = 50
- config.mau_trial_days = 0
- config.mau_stats_only = False
- config.mau_limits_reserved_threepids = []
- config.admin_contact = None
- config.rc_messages_per_second = 10000
- config.rc_message_burst_count = 10000
- config.rc_registration.per_second = 10000
- config.rc_registration.burst_count = 10000
- config.rc_login_address.per_second = 10000
- config.rc_login_address.burst_count = 10000
- config.rc_login_account.per_second = 10000
- config.rc_login_account.burst_count = 10000
- config.rc_login_failed_attempts.per_second = 10000
- config.rc_login_failed_attempts.burst_count = 10000
- config.saml2_enabled = False
- config.public_baseurl = None
- config.default_identity_server = None
- config.key_refresh_interval = 24 * 60 * 60 * 1000
- config.old_signing_keys = {}
- config.tls_fingerprints = []
-
- config.use_frozen_dicts = False
-
- # we need a sane default_room_version, otherwise attempts to create rooms will
- # fail.
- config.default_room_version = "1"
-
- # disable user directory updates, because they get done in the
- # background, which upsets the test runner.
- config.update_user_directory = False
-
- return config
+ if parse:
+ config = HomeServerConfig()
+ config.parse_config_dict(config_dict)
+ return config
+
+ return config_dict
class TestHomeServer(HomeServer):
@@ -220,7 +219,7 @@ def setup_test_homeserver(
from twisted.internet import reactor
if config is None:
- config = default_config(name)
+ config = default_config(name, parse=True)
config.ldap_enabled = False
@@ -377,12 +376,7 @@ def register_federation_servlets(hs, resource):
resource=resource,
authenticator=federation_server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
- hs.get_clock(),
- window_size=hs.config.federation_rc_window_size,
- sleep_limit=hs.config.federation_rc_sleep_limit,
- sleep_msec=hs.config.federation_rc_sleep_delay,
- reject_limit=hs.config.federation_rc_reject_limit,
- concurrent_requests=hs.config.federation_rc_concurrent,
+ hs.get_clock(), config=hs.config.rc_federation
),
)
|