diff options
author | Amber H. Brown <hawkowl@atleastfornow.net> | 2019-06-18 21:20:13 +1000 |
---|---|---|
committer | Amber H. Brown <hawkowl@atleastfornow.net> | 2019-06-18 21:20:13 +1000 |
commit | e60aab14b4203b91bad23417b1664289b4b558b6 (patch) | |
tree | 1ea1d01dd14ad85d5eeda1c460282f7c615dd55c /tests | |
parent | Merge remote-tracking branch 'origin/master' into shhs (diff) | |
parent | Fix seven contrib files with Python syntax errors (#5446) (diff) | |
download | synapse-e60aab14b4203b91bad23417b1664289b4b558b6.tar.xz |
Merge remote-tracking branch 'origin/develop' into shhs
Diffstat (limited to 'tests')
25 files changed, 1597 insertions, 325 deletions
diff --git a/tests/__init__.py b/tests/__init__.py index d3181f9403..f7fc502f01 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -21,4 +21,4 @@ import tests.patch_inline_callbacks # attempt to do the patch before we load any synapse code tests.patch_inline_callbacks.do_patch() -util.DEFAULT_TIMEOUT_DURATION = 10 +util.DEFAULT_TIMEOUT_DURATION = 20 diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 3c79d4afe7..5a355f00cc 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,12 +19,18 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.crypto.keyring import KeyLookupError +from synapse.crypto.keyring import ( + PerspectivesKeyFetcher, + ServerKeyFetcher, + StoreKeyFetcher, +) +from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -38,7 +44,7 @@ class MockPerspectiveServer(object): def get_verify_keys(self): vk = signedjson.key.get_verify_key(self.key) - return {"%s:%s" % (vk.alg, vk.version): vk} + return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)} def get_signed_key(self, server_name, verify_key): key_id = "%s:%s" % (verify_key.alg, verify_key.version) @@ -46,25 +52,31 @@ class MockPerspectiveServer(object): "server_name": server_name, "old_verify_keys": {}, "valid_until_ts": time.time() * 1000 + 3600, - "verify_keys": { - key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)} - }, + "verify_keys": {key_id: {"key": encode_verify_key_base64(verify_key)}}, } - return self.get_signed_response(res) + self.sign_response(res) + return res - def get_signed_response(self, res): + def sign_response(self, res): signedjson.sign.sign_json(res, self.server_name, self.key) - return res class KeyringTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() - hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) - keys = self.mock_perspective_server.get_verify_keys() - hs.config.perspectives = {self.mock_perspective_server.server_name: keys} - return hs + + config = self.default_config() + config["trusted_key_servers"] = [ + { + "server_name": self.mock_perspective_server.server_name, + "verify_keys": self.mock_perspective_server.get_verify_keys(), + } + ] + + return self.setup_test_homeserver( + handlers=None, http_client=self.http_client, config=config + ) def check_context(self, _, expected): self.assertEquals( @@ -80,7 +92,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # we run the lookup in a logcontext so that the patched inlineCallbacks can check # it is doing the right thing with logcontexts. wait_1_deferred = run_in_context( - kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred} + kr.wait_for_previous_lookups, {"server1": lookup_1_deferred} ) # there were no previous lookups, so the deferred should be ready @@ -89,7 +101,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # set off another wait. It should block because the first lookup # hasn't yet completed. wait_2_deferred = run_in_context( - kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred} + kr.wait_for_previous_lookups, {"server1": lookup_2_deferred} ) self.assertFalse(wait_2_deferred.called) @@ -132,7 +144,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): context_11.request = "11" res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1), ("server11", {})] + [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] ) # the unsigned json should be rejected pretty quickly @@ -169,7 +181,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1)] + [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) @@ -192,31 +204,169 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_key( - "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1) + r = self.hs.datastore.store_server_verify_keys( + "server9", + time.time() * 1000, + [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], + ) + self.get_success(r) + + json1 = {} + signedjson.sign.sign_json(json1, "server9", key1) + + # should fail immediately on an unsigned object + d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") + self.failureResultOf(d, SynapseError) + + # should suceed on a signed object + d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") + # self.assertFalse(d.called) + self.get_success(d) + + def test_verify_json_for_server_with_null_valid_until_ms(self): + """Tests that we correctly handle key requests for keys we've stored + with a null `ts_valid_until_ms` + """ + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) + + kr = keyring.Keyring( + self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) + ) + + key1 = signedjson.key.generate_signing_key(1) + r = self.hs.datastore.store_server_verify_keys( + "server9", + time.time() * 1000, + [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], ) self.get_success(r) + json1 = {} signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}) + d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") self.failureResultOf(d, SynapseError) - d = _verify_json_for_server(kr, "server9", json1) - self.assertFalse(d.called) + # should fail on a signed object with a non-zero minimum_valid_until_ms, + # as it tries to refetch the keys and fails. + d = _verify_json_for_server( + kr, "server9", json1, 500, "test signed non-zero min" + ) + self.get_failure(d, SynapseError) + + # We expect the keyring tried to refetch the key once. + mock_fetcher.get_keys.assert_called_once_with( + {"server9": {get_key_id(key1): 500}} + ) + + # should succeed on a signed object with a 0 minimum_valid_until_ms + d = _verify_json_for_server( + kr, "server9", json1, 0, "test signed with zero min" + ) self.get_success(d) + def test_verify_json_dedupes_key_requests(self): + """Two requests for the same key should be deduped.""" + key1 = signedjson.key.generate_signing_key(1) + + def get_keys(keys_to_fetch): + # there should only be one request object (with the max validity) + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) + } + } + ) + + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock(side_effect=get_keys) + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) + + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + # the first request should succeed; the second should fail because the key + # has expired + results = kr.verify_json_objects_for_server( + [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")] + ) + self.assertEqual(len(results), 2) + self.get_success(results[0]) + e = self.get_failure(results[1], SynapseError).value + self.assertEqual(e.errcode, "M_UNAUTHORIZED") + self.assertEqual(e.code, 401) + + # there should have been a single call to the fetcher + mock_fetcher.get_keys.assert_called_once() + + def test_verify_json_falls_back_to_other_fetchers(self): + """If the first fetcher cannot provide a recent enough key, we fall back""" + key1 = signedjson.key.generate_signing_key(1) + + def get_keys1(keys_to_fetch): + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) + } + } + ) + + def get_keys2(keys_to_fetch): + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) + } + } + ) + + mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1.get_keys = Mock(side_effect=get_keys1) + mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2.get_keys = Mock(side_effect=get_keys2) + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) + + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + results = kr.verify_json_objects_for_server( + [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")] + ) + self.assertEqual(len(results), 2) + self.get_success(results[0]) + e = self.get_failure(results[1], SynapseError).value + self.assertEqual(e.errcode, "M_UNAUTHORIZED") + self.assertEqual(e.code, 401) + + # there should have been a single call to each fetcher + mock_fetcher1.get_keys.assert_called_once() + mock_fetcher2.get_keys.assert_called_once() + + +class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + return hs + def test_get_keys_from_server(self): # arbitrarily advance the clock a bit self.reactor.advance(100) SERVER_NAME = "server2" - kr = keyring.Keyring(self.hs) + fetcher = ServerKeyFetcher(self.hs) testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" - VALID_UNTIL_TS = 1000 + VALID_UNTIL_TS = 200 * 1000 # valid response response = { @@ -238,12 +388,13 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.get_json.side_effect = get_json - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) k = keys[SERVER_NAME][testverifykey_id] - self.assertEqual(k, testverifykey) - self.assertEqual(k.alg, "ed25519") - self.assertEqual(k.version, "ver1") + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) @@ -263,18 +414,37 @@ class KeyringTestCase(unittest.HomeserverTestCase): bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) - # change the server name: it should cause a rejection + # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" - self.get_failure( - kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError + + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertEqual(keys, {}) + + +class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.mock_perspective_server = MockPerspectiveServer() + self.http_client = Mock() + + config = self.default_config() + config["trusted_key_servers"] = [ + { + "server_name": self.mock_perspective_server.server_name, + "verify_keys": self.mock_perspective_server.get_verify_keys(), + } + ] + + return self.setup_test_homeserver( + handlers=None, http_client=self.http_client, config=config ) def test_get_keys_from_perspectives(self): # arbitrarily advance the clock a bit self.reactor.advance(100) + fetcher = PerspectivesKeyFetcher(self.hs) + SERVER_NAME = "server2" - kr = keyring.Keyring(self.hs) testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" @@ -292,9 +462,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): }, } - persp_resp = { - "server_keys": [self.mock_perspective_server.get_signed_response(response)] - } + # the response must be signed by both the origin server and the perspectives + # server. + signedjson.sign.sign_json(response, SERVER_NAME, testkey) + self.mock_perspective_server.sign_response(response) def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) @@ -303,17 +474,18 @@ class KeyringTestCase(unittest.HomeserverTestCase): # check that the request is for the expected key q = data["server_keys"] self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) - return persp_resp + return {"server_keys": [response]} self.http_client.post_json.side_effect = post_json - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] - self.assertEqual(k, testverifykey) - self.assertEqual(k.alg, "ed25519") - self.assertEqual(k.version, "ver1") + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) @@ -329,26 +501,96 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) self.assertEqual( - bytes(res["key_json"]), - canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]), + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) + def test_invalid_perspectives_responses(self): + """Check that invalid responses from the perspectives server are rejected""" + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + SERVER_NAME = "server2" + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + def build_response(): + # valid response + response = { + "server_name": SERVER_NAME, + "old_verify_keys": {}, + "valid_until_ts": VALID_UNTIL_TS, + "verify_keys": { + testverifykey_id: { + "key": signedjson.key.encode_verify_key_base64(testverifykey) + } + }, + } + + # the response must be signed by both the origin server and the perspectives + # server. + signedjson.sign.sign_json(response, SERVER_NAME, testkey) + self.mock_perspective_server.sign_response(response) + return response + + def get_key_from_perspectives(response): + fetcher = PerspectivesKeyFetcher(self.hs) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + + def post_json(destination, path, data, **kwargs): + self.assertEqual(destination, self.mock_perspective_server.server_name) + self.assertEqual(path, "/_matrix/key/v2/query") + return {"server_keys": [response]} + + self.http_client.post_json.side_effect = post_json + + return self.get_success(fetcher.get_keys(keys_to_fetch)) + + # start with a valid response so we can check we are testing the right thing + response = build_response() + keys = get_key_from_perspectives(response) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k.verify_key, testverifykey) + + # remove the perspectives server's signature + response = build_response() + del response["signatures"][self.mock_perspective_server.server_name] + self.http_client.post_json.return_value = {"server_keys": [response]} + keys = get_key_from_perspectives(response) + self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") + + # remove the origin server's signature + response = build_response() + del response["signatures"][SERVER_NAME] + self.http_client.post_json.return_value = {"server_keys": [response]} + keys = get_key_from_perspectives(response) + self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") + + +def get_key_id(key): + """Get the matrix ID tag for a given SigningKey or VerifyKey""" + return "%s:%s" % (key.alg, key.version) + @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx"): + with LoggingContext("testctx") as ctx: + # we set the "request" prop to make it easier to follow what's going on in the + # logs. + ctx.request = "testctx" rv = yield f(*args, **kwargs) defer.returnValue(rv) -def _verify_json_for_server(keyring, server_name, json_object): +def _verify_json_for_server(kr, *args): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ @defer.inlineCallbacks def v(): - rv1 = yield keyring.verify_json_for_server(server_name, json_object) + rv1 = yield kr.verify_json_for_server(*args) defer.returnValue(rv1) return run_in_context(v) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index c499d9d9d6..3072fdaf9d 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -13,11 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock - from twisted.internet import defer -from synapse.api.errors import Codes, SynapseError from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server from synapse.rest import admin diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 249aba3d59..2710c991cf 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -204,7 +204,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): "a2": {"membership": "not a real thing"}, } - def get_event(event_id): + def get_event(event_id, allow_none=True): m = Mock() m.content = events[event_id] d = defer.Deferred() @@ -224,7 +224,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): "room_id": "room", "event_id": "a1", "prev_event_id": "a2", - "stream_id": "bleb", + "stream_id": 60, } ] @@ -241,7 +241,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): "room_id": "room", "event_id": "a2", "prev_event_id": "a1", - "stream_id": "bleb", + "stream_id": 100, } ] @@ -249,3 +249,59 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual( f.value.args[0], "'not a real thing' is not a valid membership" ) + + def test_redacted_prev_event(self): + """ + If the prev_event does not exist, then it is assumed to be a LEAVE. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + events = { + "a1": None, + "a2": {"membership": Membership.JOIN}, + } + + def get_event(event_id, allow_none=True): + if events.get(event_id): + m = Mock() + m.content = events[event_id] + else: + m = None + d = defer.Deferred() + self.reactor.callLater(0.0, d.callback, m) + return d + + def get_received_ts(event_id): + return defer.succeed(1) + + self.store.get_received_ts = get_received_ts + self.store.get_event = get_event + + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user:test", + "room_id": room_1, + "event_id": "a2", + "prev_event_id": "a1", + "stream_id": 100, + } + ] + + # Handle our fake deltas, which has a user going from LEAVE -> JOIN. + self.get_success(self.handler._handle_deltas(deltas)) + + # One delta, with two joined members -- the room creator, and our fake + # user. + r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0]["joined_members"], 2) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 851fc0eb33..2d5dba6464 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -13,28 +13,122 @@ # See the License for the specific language governing permissions and # limitations under the License. import os.path +import subprocess + +from zope.interface import implementer from OpenSSL import SSL +from OpenSSL.SSL import Connection +from twisted.internet.interfaces import IOpenSSLServerConnectionCreator + + +def get_test_ca_cert_file(): + """Get the path to the test CA cert + + The keypair is generated with: + + openssl genrsa -out ca.key 2048 + openssl req -new -x509 -key ca.key -days 3650 -out ca.crt \ + -subj '/CN=synapse test CA' + """ + return os.path.join(os.path.dirname(__file__), "ca.crt") + + +def get_test_key_file(): + """get the path to the test key + + The key file is made with: + + openssl genrsa -out server.key 2048 + """ + return os.path.join(os.path.dirname(__file__), "server.key") + + +cert_file_count = 0 + +CONFIG_TEMPLATE = b"""\ +[default] +basicConstraints = CA:FALSE +keyUsage=nonRepudiation, digitalSignature, keyEncipherment +subjectAltName = %(sanentries)s +""" + + +def create_test_cert_file(sanlist): + """build an x509 certificate file + + Args: + sanlist: list[bytes]: a list of subjectAltName values for the cert + + Returns: + str: the path to the file + """ + global cert_file_count + csr_filename = "server.csr" + cnf_filename = "server.%i.cnf" % (cert_file_count,) + cert_filename = "server.%i.crt" % (cert_file_count,) + cert_file_count += 1 + + # first build a CSR + subprocess.check_call( + [ + "openssl", + "req", + "-new", + "-key", + get_test_key_file(), + "-subj", + "/", + "-out", + csr_filename, + ] + ) + # now a config file describing the right SAN entries + sanentries = b",".join(sanlist) + with open(cnf_filename, "wb") as f: + f.write(CONFIG_TEMPLATE % {b"sanentries": sanentries}) -def get_test_cert_file(): - """get the path to the test cert""" + # finally the cert + ca_key_filename = os.path.join(os.path.dirname(__file__), "ca.key") + ca_cert_filename = get_test_ca_cert_file() + subprocess.check_call( + [ + "openssl", + "x509", + "-req", + "-in", + csr_filename, + "-CA", + ca_cert_filename, + "-CAkey", + ca_key_filename, + "-set_serial", + "1", + "-extfile", + cnf_filename, + "-out", + cert_filename, + ] + ) - # the cert file itself is made with: - # - # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \ - # -nodes -subj '/CN=testserv' - return os.path.join(os.path.dirname(__file__), 'server.pem') + return cert_filename -class ServerTLSContext(object): - """A TLS Context which presents our test cert.""" +@implementer(IOpenSSLServerConnectionCreator) +class TestServerTLSConnectionFactory(object): + """An SSL connection creator which returns connections which present a certificate + signed by our test CA.""" - def __init__(self): - self.filename = get_test_cert_file() + def __init__(self, sanlist): + """ + Args: + sanlist: list[bytes]: a list of subjectAltName values for the cert + """ + self._cert_file = create_test_cert_file(sanlist) - def getContext(self): + def serverConnectionForTLS(self, tlsProtocol): ctx = SSL.Context(SSL.TLSv1_METHOD) - ctx.use_certificate_file(self.filename) - ctx.use_privatekey_file(self.filename) - return ctx + ctx.use_certificate_file(self._cert_file) + ctx.use_privatekey_file(get_test_key_file()) + return Connection(ctx, None) diff --git a/tests/http/ca.crt b/tests/http/ca.crt new file mode 100644 index 0000000000..730f81e99c --- /dev/null +++ b/tests/http/ca.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCjCCAfKgAwIBAgIJAPwHIHgH/jtjMA0GCSqGSIb3DQEBCwUAMBoxGDAWBgNV +BAMMD3N5bmFwc2UgdGVzdCBDQTAeFw0xOTA2MTAxMTI2NDdaFw0yOTA2MDcxMTI2 +NDdaMBoxGDAWBgNVBAMMD3N5bmFwc2UgdGVzdCBDQTCCASIwDQYJKoZIhvcNAQEB +BQADggEPADCCAQoCggEBAOZOXCKuylf9jHzJXpU2nS+XEKrnGPgs2SAhQKrzBxg3 +/d8KT2Zsfsj1i3G7oGu7B0ZKO6qG5AxOPCmSMf9/aiSHFilfSh+r8rCpJyWMev2c +/w/xmhoFHgn+H90NnqlXvWb5y1YZCE3gWaituQSaa93GPKacRqXCgIrzjPUuhfeT +uwFQt4iyUhMNBYEy3aw4IuIHdyBqi4noUhR2ZeuflLJ6PswdJ8mEiAvxCbBGPerq +idhWcZwlo0fKu4u1uu5B8TnTsMg2fJgL6c5olBG90Urt22gA6anfP5W/U1ZdVhmB +T3Rv5SJMkGyMGE6sEUetLFyb2GJpgGD7ePkUCZr+IMMCAwEAAaNTMFEwHQYDVR0O +BBYEFLg7nTCYsvQXWTyS6upLc0YTlIwRMB8GA1UdIwQYMBaAFLg7nTCYsvQXWTyS +6upLc0YTlIwRMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADqx +GX4Ul5OGQlcG+xTt4u3vMCeqGo8mh1AnJ7zQbyRmwjJiNxJVX+/EcqFSTsmkBNoe +xdYITI7Z6dyoiKw99yCZDE7gALcyACEU7r0XY7VY/hebAaX6uLaw1sZKKAIC04lD +KgCu82tG85n60Qyud5SiZZF0q1XVq7lbvOYVdzVZ7k8Vssy5p9XnaLJLMggYeOiX +psHIQjvYGnTTEBZZHzWOrc0WGThd69wxTOOkAbCsoTPEwZL8BGUsdtLWtvhp452O +npvaUBzKg39R5X3KTdhB68XptiQfzbQkd3FtrwNuYPUywlsg55Bxkv85n57+xDO3 +D9YkgUqEp0RGUXQgCsQ= +-----END CERTIFICATE----- diff --git a/tests/http/ca.key b/tests/http/ca.key new file mode 100644 index 0000000000..5c99cae186 --- /dev/null +++ b/tests/http/ca.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpgIBAAKCAQEA5k5cIq7KV/2MfMlelTadL5cQqucY+CzZICFAqvMHGDf93wpP +Zmx+yPWLcbuga7sHRko7qobkDE48KZIx/39qJIcWKV9KH6vysKknJYx6/Zz/D/Ga +GgUeCf4f3Q2eqVe9ZvnLVhkITeBZqK25BJpr3cY8ppxGpcKAivOM9S6F95O7AVC3 +iLJSEw0FgTLdrDgi4gd3IGqLiehSFHZl65+Usno+zB0nyYSIC/EJsEY96uqJ2FZx +nCWjR8q7i7W67kHxOdOwyDZ8mAvpzmiUEb3RSu3baADpqd8/lb9TVl1WGYFPdG/l +IkyQbIwYTqwRR60sXJvYYmmAYPt4+RQJmv4gwwIDAQABAoIBAQCFuFG+wYYy+MCt +Y65LLN6vVyMSWAQjdMbM5QHLQDiKU1hQPIhFjBFBVXCVpL9MTde3dDqYlKGsk3BT +ItNs6eoTM2wmsXE0Wn4bHNvh7WMsBhACjeFP4lDCtI6DpvjMkmkidT8eyoIL1Yu5 +aMTYa2Dd79AfXPWYIQrJowfhBBY83KuW5fmYnKKDVLqkT9nf2dgmmQz85RgtNiZC +zFkIsNmPqH1zRbcw0wORfOBrLFvsMc4Tt8EY5Wz3NnH8Zfgf8Q3MgARH1yspz3Vp +B+EYHbsK17xZ+P59KPiX3yefvyYWEUjFF7ymVsVnDxLugYl4pXwWUpm19GxeDvFk +cgBUD5OBAoGBAP7lBdCp6lx6fYtxdxUm3n4MMQmYcac4qZdeBIrvpFMnvOBBuixl +eavcfFmFdwgAr8HyVYiu9ynac504IYvmtYlcpUmiRBbmMHbvLQEYHl7FYFKNz9ej +2ue4oJE3RsPdLsD3xIlc+xN8oT1j0knyorwsHdj0Sv77eZzZS9XZZfJzAoGBAOdO +CibYmoNqK/mqDHkp6PgsnbQGD5/CvPF/BLUWV1QpHxLzUQQeoBOQW5FatHe1H5zi +mbq3emBefVmsCLrRIJ4GQu4vsTMfjcpGLwviWmaK6pHbGPt8IYeEQ2MNyv59EtA2 +pQy4dX7/Oe6NLAR1UEQjXmCuXf+rxnxF3VJd1nRxAoGBANb9eusl9fusgSnVOTjJ +AQ7V36KVRv9hZoG6liBNwo80zDVmms4JhRd1MBkd3mkMkzIF4SkZUnWlwLBSANGM +dX/3eZ5i1AVwgF5Am/f5TNxopDbdT/o1RVT/P8dcFT7s1xuBn+6wU0F7dFBgWqVu +lt4aY85zNrJcj5XBHhqwdDGLAoGBAIksPNUAy9F3m5C6ih8o/aKAQx5KIeXrBUZq +v43tK+kbYfRJHBjHWMOBbuxq0G/VmGPf9q9GtGqGXuxZG+w+rYtJx1OeMQZShjIZ +ITl5CYeahrXtK4mo+fF2PMh3m5UE861LWuKKWhPwpJiWXC5grDNcjlHj1pcTdeip +PjHkuJPhAoGBAIh35DptqqdicOd3dr/+/m2YQywY8aSpMrR0bC06aAkscD7oq4tt +s/jwl0UlHIrEm/aMN7OnGIbpfkVdExfGKYaa5NRlgOwQpShwLufIo/c8fErd2zb8 +K3ptlwBxMrayMXpS3DP78r83Z0B8/FSK2guelzdRJ3ftipZ9io1Gss1C +-----END RSA PRIVATE KEY----- diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index ed0ca079d9..ecce473b01 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -17,16 +17,19 @@ import logging from mock import Mock import treq +from service_identity import VerificationError from zope.interface import implementer from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web._newclient import ResponseNeverReceived from twisted.web.http import HTTPChannel from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS +from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import ClientTLSOptionsFactory from synapse.http.federation.matrix_federation_agent import ( MatrixFederationAgent, @@ -36,13 +39,29 @@ from synapse.http.federation.srv_resolver import Server from synapse.util.caches.ttlcache import TTLCache from synapse.util.logcontext import LoggingContext -from tests.http import ServerTLSContext +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase from tests.utils import default_config logger = logging.getLogger(__name__) +test_server_connection_factory = None + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory(sanlist=[ + b'DNS:testserv', + b'DNS:target-server', + b'DNS:xn--bcher-kva.com', + b'IP:1.2.3.4', + b'IP:::1', + ]) + return test_server_connection_factory + class MatrixFederationAgentTests(TestCase): def setUp(self): @@ -52,11 +71,16 @@ class MatrixFederationAgentTests(TestCase): self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + config_dict = default_config("test", parse=False) + config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] + # config_dict["trusted_key_servers"] = [] + + self._config = config = HomeServerConfig() + config.parse_config_dict(config_dict) + self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory( - default_config("test", parse=True) - ), + tls_client_options_factory=ClientTLSOptionsFactory(config), _well_known_tls_policy=TrustingTLSPolicyForHTTPS(), _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, @@ -70,7 +94,7 @@ class MatrixFederationAgentTests(TestCase): """ # build the test server - server_tls_protocol = _build_test_server() + server_tls_protocol = _build_test_server(get_connection_factory()) # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an @@ -321,6 +345,88 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) + def test_get_hostname_bad_cert(self): + """ + Test the behaviour when the certificate on the server doesn't match the hostname + """ + self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.reactor.lookups["testserv1"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv1/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # No SRV record lookup yet + self.mock_resolver.resolve_service.assert_not_called() + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 443) + + # fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the + # .well-known request fails. + self.reactor.pump((0.4,)) + + # now there should be a SRV lookup + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv1" + ) + + # we should fall back to a direct connection + self.assertEqual(len(clients), 2) + (host, port, client_factory, _timeout, _bindAddress) = clients[1] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b'testserv1') + + # there should be no requests + self.assertEqual(len(http_server.requests), 0) + + # ... and the request should have failed + e = self.failureResultOf(test_d, ResponseNeverReceived) + failure_reason = e.value.reasons[0] + self.assertIsInstance(failure_reason.value, VerificationError) + + def test_get_ip_address_bad_cert(self): + """ + Test the behaviour when the server name contains an explicit IP, but + the server cert doesn't cover it + """ + # there will be a getaddrinfo on the IP + self.reactor.lookups["1.2.3.5"] = "1.2.3.5" + + test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.5') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=None) + + # there should be no requests + self.assertEqual(len(http_server.requests), 0) + + # ... and the request should have failed + e = self.failureResultOf(test_d, ResponseNeverReceived) + failure_reason = e.value.reasons[0] + self.assertIsInstance(failure_reason.value, VerificationError) + def test_get_no_srv_no_well_known(self): """ Test the behaviour when the server name has no port, no SRV, and no well-known @@ -578,6 +684,49 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) + def test_get_well_known_unsigned_cert(self): + """Test the behaviour when the .well-known server presents a cert + not signed by a CA + """ + + # we use the same test server as the other tests, but use an agent + # with _well_known_tls_policy left to the default, which will not + # trust it (since the presented cert is signed by a test CA) + + self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.reactor.lookups["testserv"] = "1.2.3.4" + + agent = MatrixFederationAgent( + reactor=self.reactor, + tls_client_options_factory=ClientTLSOptionsFactory(self._config), + _srv_resolver=self.mock_resolver, + _well_known_cache=self.well_known_cache, + ) + + test_d = agent.request(b"GET", b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 443) + + http_proto = self._make_connection( + client_factory, expected_sni=b"testserv", + ) + + # there should be no requests + self.assertEqual(len(http_proto.requests), 0) + + # and there should be a SRV lookup instead + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv" + ) + def test_get_hostname_srv(self): """ Test the behaviour when there is a single SRV record @@ -911,11 +1060,17 @@ def _check_logcontext(context): raise AssertionError("Expected logcontext %s but was %s" % (context, current)) -def _build_test_server(): +def _build_test_server(connection_creator): """Construct a test server This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + Returns: TLSMemoryBIOProtocol """ @@ -924,7 +1079,7 @@ def _build_test_server(): server_factory.log = _log_request server_tls_factory = TLSMemoryBIOFactory( - ServerTLSContext(), isClient=False, wrappedFactory=server_factory + connection_creator, isClient=False, wrappedFactory=server_factory ) return server_tls_factory.buildProtocol(None) @@ -937,7 +1092,8 @@ def _log_request(request): @implementer(IPolicyForHTTPS) class TrustingTLSPolicyForHTTPS(object): - """An IPolicyForHTTPS which doesn't do any certificate verification""" + """An IPolicyForHTTPS which checks that the certificate belongs to the + right server, but doesn't check the certificate chain.""" def creatorForNetloc(self, hostname, port): certificateOptions = OpenSSLCertificateOptions() diff --git a/tests/http/server.key b/tests/http/server.key new file mode 100644 index 0000000000..c53ee02b21 --- /dev/null +++ b/tests/http/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAvUAWLOE6TEp3FYSfEnJMwYtJg3KIW5BjiAOOvFVOVQfJ5eEa +vzyJ1Z+8DUgLznFnUkAeD9GjPvP7awl3NPJKLQSMkV5Tp+ea4YyV+Aa4R7flROEa +zCGvmleydZw0VqN1atVZ0ikEoglM/APJQd70ec7KSR3QoxaV2/VNCHmyAPdP+0WI +llV54VXX1CZrWSHaCSn1gzo3WjnGbxTOCQE5Z4k5hqJAwLWWhxDv+FX/jD38Sq3H +gMFNpXJv6FYwwaKU8awghHdSY/qlBPE/1rU83vIBFJ3jW6I1WnQDfCQ69of5vshK +N4v4hok56ScwdUnk8lw6xvJx1Uav/XQB9qGh4QIDAQABAoIBAQCHLO5p8hotAgdb +JFZm26N9nxrMPBOvq0ucjEX4ucnwrFaGzynGrNwa7TRqHCrqs0/EjS2ryOacgbL0 +eldeRy26SASLlN+WD7UuI7e+6DXabDzj3RHB+tGuIbPDk+ZCeBDXVTsKBOhdQN1v +KNkpJrJjCtSsMxKiWvCBow353srJKqCDZcF5NIBYBeDBPMoMbfYn5dJ9JhEf+2h4 +0iwpnWDX1Vqf46pCRa0hwEyMXycGeV2CnfJSyV7z52ZHQrvkz8QspSnPpnlCnbOE +UAvc8kZ5e8oZE7W+JfkK38vHbEGM1FCrBmrC/46uUGMRpZfDferGs91RwQVq/F0n +JN9hLzsBAoGBAPh2pm9Xt7a4fWSkX0cDgjI7PT2BvLUjbRwKLV+459uDa7+qRoGE +sSwb2QBqmQ1kbr9JyTS+Ld8dyUTsGHZK+YbTieAxI3FBdKsuFtcYJO/REN0vik+6 +fMaBHPvDHSU2ioq7spZ4JBFskzqs38FvZ0lX7aa3fguMk8GMLnofQ8QxAoGBAML9 +o5sJLN9Tk9bv2aFgnERgfRfNjjV4Wd99TsktnCD04D1GrP2eDSLfpwFlCnguck6b +jxikqcolsNhZH4dgYHqRNj+IljSdl+sYZiygO6Ld0XU+dEFO86N3E9NzZhKcQ1at +85VdwNPCS7JM2fIxEvS9xfbVnsmK6/37ZZ5iI7yxAoGBALw2vRtJGmy60pojfd1A +hibhAyINnlKlFGkSOI7zdgeuRTf6l9BTIRclvTt4hJpFgzM6hMWEbyE94hJoupsZ +bm443o/LCWsox2VI05p6urhD6f9znNWKkiyY78izY+elqksvpjgfqEresaTYAeP5 +LQe9KNSK2VuMUP1j4G04M9BxAoGAWe8ITZJuytZOgrz/YIohqPvj1l2tcIYA1a6C +7xEFSMIIxtpZIWSLZIFJEsCakpHBkPX4iwIveZfmt/JrM1JFTWK6ZZVGyh/BmOIZ +Bg4lU1oBqJTUo+aZQtTCJS29b2n5OPpkNYkXTdP4e9UsVKNDvfPlYZJneUeEzxDr +bqCPIRECgYA544KMwrWxDQZg1dsKWgdVVKx80wEFZAiQr9+0KF6ch6Iu7lwGJHFY +iI6O85paX41qeC/Fo+feIWJVJU2GvG6eBsbO4bmq+KSg4NkABJSYxodgBp9ftNeD +jo1tfw+gudlNe5jXHu7oSX93tqGjR4Cnlgan/KtfkB96yHOumGmOhQ== +-----END RSA PRIVATE KEY----- diff --git a/tests/http/server.pem b/tests/http/server.pem deleted file mode 100644 index 0584cf1a80..0000000000 --- a/tests/http/server.pem +++ /dev/null @@ -1,81 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCgF43/3lAgJ+p0 -x7Rn8UcL8a4fctvdkikvZrCngw96LkB34Evfq8YGWlOVjU+f9naUJLAKMatmAfEN -r+rMX4VOXmpTwuu6iLtqwreUrRFMESyrmvQxa15p+y85gkY0CFmXMblv6ORbxHTG -ncBGwST4WK4Poewcgt6jcISFCESTUKu1zc3cw1ANIDRyDLB5K44KwIe36dcKckyN -Kdtv4BJ+3fcIZIkPJH62zqCypgFF1oiFt40uJzClxgHdJZlKYpgkfnDTckw4Y/Mx -9k8BbE310KAzUNMV9H7I1eEolzrNr66FQj1eN64X/dqO8lTbwCqAd4diCT4sIUk0 -0SVsAUjNd3g8j651hx+Qb1t8fuOjrny8dmeMxtUgIBHoQcpcj76R55Fs7KZ9uar0 -8OFTyGIze51W1jG2K/7/5M1zxIqrA+7lsXu5OR81s7I+Ng/UUAhiHA/z+42/aiNa -qEuk6tqj3rHfLctnCbtZ+JrRNqSSwEi8F0lMA021ivEd2eJV+284OyJjhXOmKHrX -QADHrmS7Sh4syTZvRNm9n+qWID0KdDr2Sji/KnS3Enp44HDQ4xriT6/xhwEGsyuX -oH5aAkdLznulbWkHBbyx1SUQSTLpOqzaioF9m1vRrLsFvrkrY3D253mPJ5eU9HM/ -dilduFcUgj4rz+6cdXUAh+KK/v95zwIDAQABAoICAFG5tJPaOa0ws0/KYx5s3YgL -aIhFalhCNSQtmCDrlwsYcXDA3/rfBchYdDL0YKGYgBBAal3J3WXFt/j0xThvyu2m -5UC9UPl4s7RckrsjXqEmY1d3UxGnbhtMT19cUdpeKN42VCP9EBaIw9Rg07dLAkSF -gNYaIx6q8F0fI4eGIPvTQtUcqur4CfWpaxyNvckdovV6M85/YXfDwbCOnacPDGIX -jfSK3i0MxGMuOHr6o8uzKR6aBUh6WStHWcw7VXXTvzdiFNbckmx3Gb93rf1b/LBw -QFfx+tBKcC62gKroCOzXso/0sL9YTVeSD/DJZOiJwSiz3Dj/3u1IUMbVvfTU8wSi -CYS7Z+jHxwSOCSSNTXm1wO/MtDsNKbI1+R0cohr/J9pOMQvrVh1+2zSDOFvXAQ1S -yvjn+uqdmijRoV2VEGVHd+34C+ci7eJGAhL/f92PohuuFR2shUETgGWzpACZSJwg -j1d90Hs81hj07vWRb+xCeDh00vimQngz9AD8vYvv/S4mqRGQ6TZdfjLoUwSTg0JD -6sQgRXX026gQhLhn687vLKZfHwzQPZkpQdxOR0dTZ/ho/RyGGRJXH4kN4cA2tPr+ -AKYQ29YXGlEzGG7OqikaZcprNWG6UFgEpuXyBxCgp9r4ladZo3J+1Rhgus8ZYatd -uO98q3WEBmP6CZ2n32mBAoIBAQDS/c/ybFTos0YpGHakwdmSfj5OOQJto2y8ywfG -qDHwO0ebcpNnS1+MA+7XbKUQb/3Iq7iJljkkzJG2DIJ6rpKynYts1ViYpM7M/t0T -W3V1gvUcUL62iqkgws4pnpWmubFkqV31cPSHcfIIclnzeQ1aOEGsGHNAvhty0ciC -DnkJACbqApvopFLOR5f6UFTtKExE+hDH0WqgpsCAKJ1L4g6pBzZatI32/CN9JEVU -tDbxLV75hHlFFjUrG7nT1rPyr/gI8Ceh9/2xeXPfjJUR0PrG3U1nwLqUCZkvFzO6 -XpN2+A+/v4v5xqMjKDKDFy1oq6SCMomwv/viw6wl/84TMbolAoIBAQDCPiMecnR8 -REik6tqVzQO/uSe9ZHjz6J15t5xdwaI6HpSwLlIkQPkLTjyXtFpemK5DOYRxrJvQ -remfrZrN2qtLlb/DKpuGPWRsPOvWCrSuNEp48ivUehtclljrzxAFfy0sM+fWeJ48 -nTnR+td9KNhjNtZixzWdAy/mE+jdaMsXVnk66L73Uz+2WsnvVMW2R6cpCR0F2eP/ -B4zDWRqlT2w47sePAB81mFYSQLvPC6Xcgg1OqMubfiizJI49c8DO6Jt+FFYdsxhd -kG52Eqa/Net6rN3ueiS6yXL5TU3Y6g96bPA2KyNCypucGcddcBfqaiVx/o4AH6yT -NrdsrYtyvk/jAoIBAQDHUwKVeeRJJbvdbQAArCV4MI155n+1xhMe1AuXkCQFWGtQ -nlBE4D72jmyf1UKnIbW2Uwv15xY6/ouVWYIWlj9+QDmMaozVP7Uiko+WDuwLRNl8 -k4dn+dzHV2HejbPBG2JLv3lFOx23q1zEwArcaXrExaq9Ayg2fKJ/uVHcFAIiD6Oz -pR1XDY4w1A/uaN+iYFSVQUyDCQLbnEz1hej73CaPZoHh9Pq83vxD5/UbjVjuRTeZ -L55FNzKpc/r89rNvTPBcuUwnxplDhYKDKVNWzn9rSXwrzTY2Tk8J3rh+k4RqevSd -6D47jH1n5Dy7/TRn0ueKHGZZtTUnyEUkbOJo3ayFAoIBAHKDyZaQqaX9Z8p6fwWj -yVsFoK0ih8BcWkLBAdmwZ6DWGJjJpjmjaG/G3ygc9s4gO1R8m12dAnuDnGE8KzDD -gwtbrKM2Alyg4wyA2hTlWOH/CAzH0RlCJ9Fs/d1/xJVJBeuyajLiB3/6vXTS6qnq -I7BSSxAPG8eGcn21LSsjNeB7ZZtaTgNnu/8ZBUYo9yrgkWc67TZe3/ChldYxOOlO -qqHh/BqNWtjxB4VZTp/g4RbgQVInZ2ozdXEv0v/dt0UEk29ANAjsZif7F3RayJ2f -/0TilzCaJ/9K9pKNhaClVRy7Dt8QjYg6BIWCGSw4ApF7pLnQ9gySn95mersCkVzD -YDsCggEAb0E/TORjQhKfNQvahyLfQFm151e+HIoqBqa4WFyfFxe/IJUaLH/JSSFw -VohbQqPdCmaAeuQ8ERL564DdkcY5BgKcax79fLLCOYP5bT11aQx6uFpfl2Dcm6Z9 -QdCRI4jzPftsd5fxLNH1XtGyC4t6vTic4Pji2O71WgWzx0j5v4aeDY4sZQeFxqCV -/q7Ee8hem1Rn5RFHu14FV45RS4LAWl6wvf5pQtneSKzx8YL0GZIRRytOzdEfnGKr -FeUlAj5uL+5/p0ZEgM7gPsEBwdm8scF79qSUn8UWSoXNeIauF9D4BDg8RZcFFxka -KILVFsq3cQC+bEnoM4eVbjEQkGs1RQ== ------END PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIIE/jCCAuagAwIBAgIJANFtVaGvJWZlMA0GCSqGSIb3DQEBCwUAMBMxETAPBgNV -BAMMCHRlc3RzZXJ2MCAXDTE5MDEyNzIyMDIzNloYDzIxMTkwMTAzMjIwMjM2WjAT -MREwDwYDVQQDDAh0ZXN0c2VydjCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC -ggIBAKAXjf/eUCAn6nTHtGfxRwvxrh9y292SKS9msKeDD3ouQHfgS9+rxgZaU5WN -T5/2dpQksAoxq2YB8Q2v6sxfhU5ealPC67qIu2rCt5StEUwRLKua9DFrXmn7LzmC -RjQIWZcxuW/o5FvEdMadwEbBJPhYrg+h7ByC3qNwhIUIRJNQq7XNzdzDUA0gNHIM -sHkrjgrAh7fp1wpyTI0p22/gEn7d9whkiQ8kfrbOoLKmAUXWiIW3jS4nMKXGAd0l -mUpimCR+cNNyTDhj8zH2TwFsTfXQoDNQ0xX0fsjV4SiXOs2vroVCPV43rhf92o7y -VNvAKoB3h2IJPiwhSTTRJWwBSM13eDyPrnWHH5BvW3x+46OufLx2Z4zG1SAgEehB -ylyPvpHnkWzspn25qvTw4VPIYjN7nVbWMbYr/v/kzXPEiqsD7uWxe7k5HzWzsj42 -D9RQCGIcD/P7jb9qI1qoS6Tq2qPesd8ty2cJu1n4mtE2pJLASLwXSUwDTbWK8R3Z -4lX7bzg7ImOFc6YoetdAAMeuZLtKHizJNm9E2b2f6pYgPQp0OvZKOL8qdLcSenjg -cNDjGuJPr/GHAQazK5egfloCR0vOe6VtaQcFvLHVJRBJMuk6rNqKgX2bW9GsuwW+ -uStjcPbneY8nl5T0cz92KV24VxSCPivP7px1dQCH4or+/3nPAgMBAAGjUzBRMB0G -A1UdDgQWBBQcQZpzLzTk5KdS/Iz7sGCV7gTd/zAfBgNVHSMEGDAWgBQcQZpzLzTk -5KdS/Iz7sGCV7gTd/zAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC -AQAr/Pgha57jqYsDDX1LyRrVdqoVBpLBeB7x/p9dKYm7S6tBTDFNMZ0SZyQP8VEG -7UoC9/OQ9nCdEMoR7ZKpQsmipwcIqpXHS6l4YOkf5EEq5jpMgvlEesHmBJJeJew/ -FEPDl1bl8d0tSrmWaL3qepmwzA+2lwAAouWk2n+rLiP8CZ3jZeoTXFqYYrUlEqO9 -fHMvuWqTV4KCSyNY+GWCrnHetulgKHlg+W2J1mZnrCKcBhWf9C2DesTJO+JldIeM -ornTFquSt21hZi+k3aySuMn2N3MWiNL8XsZVsAnPSs0zA+2fxjJkShls8Gc7cCvd -a6XrNC+PY6pONguo7rEU4HiwbvnawSTngFFglmH/ImdA/HkaAekW6o82aI8/UxFx -V9fFMO3iKDQdOrg77hI1bx9RlzKNZZinE2/Pu26fWd5d2zqDWCjl8ykGQRAfXgYN -H3BjgyXLl+ao5/pOUYYtzm3ruTXTgRcy5hhL6hVTYhSrf9vYh4LNIeXNKnZ78tyG -TX77/kU2qXhBGCFEUUMqUNV/+ITir2lmoxVjknt19M07aGr8C7SgYt6Rs+qDpMiy -JurgvRh8LpVq4pHx1efxzxCFmo58DMrG40I0+CF3y/niNpOb1gp2wAqByRiORkds -f0ytW6qZ0TpHbD6gOtQLYDnhx3ISuX+QYSekVwQUpffeWQ== ------END CERTIFICATE----- diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 9cdde1a9bd..72760a0733 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -15,6 +15,7 @@ import os +import attr import pkg_resources from twisted.internet.defer import Deferred @@ -24,15 +25,16 @@ from synapse.rest.client.v1 import login, room from tests.unittest import HomeserverTestCase -try: - from synapse.push.mailer import load_jinja2_templates -except Exception: - load_jinja2_templates = None + +@attr.s +class _User(object): + "Helper wrapper for user ID and access token" + id = attr.ib() + token = attr.ib() class EmailPusherTests(HomeserverTestCase): - skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -77,25 +79,32 @@ class EmailPusherTests(HomeserverTestCase): return hs - def test_sends_email(self): - + def prepare(self, reactor, clock, hs): # Register the user who gets notified - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # Register the user who sends the message - other_user_id = self.register_user("otheruser", "pass") - other_access_token = self.login("otheruser", "pass") + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register other users + self.others = [ + _User( + id=self.register_user("otheruser1", "pass"), + token=self.login("otheruser1", "pass"), + ), + _User( + id=self.register_user("otheruser2", "pass"), + token=self.login("otheruser2", "pass"), + ), + ] # Register the pusher user_tuple = self.get_success( - self.hs.get_datastore().get_user_by_access_token(access_token) + self.hs.get_datastore().get_user_by_access_token(self.access_token) ) token_id = user_tuple["token_id"] - self.get_success( + self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( - user_id=user_id, + user_id=self.user_id, access_token=token_id, kind="email", app_id="m.email", @@ -107,22 +116,54 @@ class EmailPusherTests(HomeserverTestCase): ) ) - # Create a room - room = self.helper.create_room_as(user_id, tok=access_token) + def test_simple_sends_email(self): + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id, + ) + self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=self.others[0].token) + self.helper.send(room, body="There!", tok=self.others[0].token) - # The other user joins - self.helper.join(room=room, user=other_user_id, tok=other_access_token) + # We should get emailed about that message + self._check_for_mail() - # The other user sends some messages - self.helper.send(room, body="Hi!", tok=other_access_token) - self.helper.send(room, body="There!", tok=other_access_token) + def test_multiple_members_email(self): + # We want to test multiple notifications, so we pause processing of push + # while we send messages. + self.pusher._pause_processing() + + # Create a simple room with multiple other users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + for other in self.others: + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=other.id, + ) + self.helper.join(room=room, user=other.id, tok=other.token) + + # The other users send some messages + self.helper.send(room, body="Hi!", tok=self.others[0].token) + self.helper.send(room, body="There!", tok=self.others[1].token) + self.helper.send(room, body="There!", tok=self.others[1].token) + + # Nothing should have happened yet, as we're paused. + assert not self.email_attempts + + self.pusher._resume_processing() + + # We should get emailed about those messages + self._check_for_mail() + + def _check_for_mail(self): + "Check that the user receives an email notification" # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -132,7 +173,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -149,7 +190,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index aba618b2be..22c3f73ef3 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -23,15 +23,9 @@ from synapse.util.logcontext import make_deferred_yieldable from tests.unittest import HomeserverTestCase -try: - from synapse.push.mailer import load_jinja2_templates -except Exception: - load_jinja2_templates = None - class HTTPPusherTests(HomeserverTestCase): - skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index ee5f09041f..e5fc2fcd15 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -408,7 +408,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) - @unittest.DEBUG def test_shutdown_room_block_peek(self): """Test that a world_readable room can no longer be peeked into after it has been shut down. diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 88f8f1abdc..efc5a99db3 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -23,14 +23,8 @@ from synapse.rest.consent import consent_resource from tests import unittest from tests.server import render -try: - from synapse.push.mailer import load_jinja2_templates -except Exception: - load_jinja2_templates = None - class ConsentResourceTestCase(unittest.HomeserverTestCase): - skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py new file mode 100644 index 0000000000..7167fc56b6 --- /dev/null +++ b/tests/rest/client/third_party_rules.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class ThirdPartyRulesTestModule(object): + def __init__(self, config): + pass + + def check_event_allowed(self, event, context): + if event.type == "foo.bar.forbidden": + return False + else: + return True + + @staticmethod + def parse_config(config): + return config + + +class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["third_party_event_rules"] = { + "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule", + "config": {}, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def test_third_party_rules(self): + """Tests that a forbidden event is forbidden from being sent, but an allowed one + can be sent. + """ + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id, + {}, + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id, + {}, + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 769c37ce52..72c7ed93cb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,6 +14,8 @@ # limitations under the License. """Tests REST events for /profile paths.""" +import json + from mock import Mock from twisted.internet import defer @@ -28,11 +30,14 @@ from tests import unittest from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/api/v1" +PATH_PREFIX = "/_matrix/client/r0" + +class MockHandlerProfileTestCase(unittest.TestCase): + """ Tests rest layer of profile management. -class ProfileTestCase(unittest.TestCase): - """ Tests profile management. """ + Todo: move these into ProfileTestCase + """ @defer.inlineCallbacks def setUp(self): @@ -159,6 +164,59 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") +class ProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def prepare(self, reactor, clock, hs): + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + + def test_set_displayname(self): + request, channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner, ), + content=json.dumps({"displayname": "test"}), + access_token=self.owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + res = self.get_displayname() + self.assertEqual(res, "test") + + def test_set_displayname_too_long(self): + """Attempts to set a stupid displayname should get a 400""" + request, channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner, ), + content=json.dumps({"displayname": "test" * 100}), + access_token=self.owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, 400, channel.result) + + res = self.get_displayname() + self.assertEqual(res, "owner") + + def get_displayname(self): + request, channel = self.make_request( + "GET", + "/profile/%s/displayname" % (self.owner, ), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["displayname"] + + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py new file mode 100644 index 0000000000..a60a4a3b87 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_account.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re +from email.parser import Parser + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import LoginType +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, register + +from tests import unittest + + +class PasswordResetTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + return + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + def test_basic_password_reset(self): + """Test basic password reset flow + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + def test_cant_reset_password_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to reset password without clicking the link + self._reset_password( + new_password, session_id, client_secret, expected_code=401, + ) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = "weasle" + + # Attempt to reset password without even requesting an email + self._reset_password( + new_password, session_id, client_secret, expected_code=401, + ) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/password/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _reset_password( + self, new_password, session_id, client_secret, expected_code=200 + ): + request, channel = self.make_request( + "POST", + b"account/password", + { + "new_password": new_password, + "auth": { + "type": LoginType.EMAIL_IDENTITY, + "threepid_creds": { + "client_secret": client_secret, + "sid": session_id, + }, + }, + }, + ) + self.render(request) + self.assertEquals(expected_code, channel.code, channel.result) + + +class DeactivateTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def test_deactivate_account(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + request_data = json.dumps({ + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + }) + request, channel = self.make_request( + "POST", + "account/deactivate", + request_data, + access_token=tok, + ) + self.render(request) + self.assertEqual(request.code, 200) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + request, channel = self.make_request("GET", "account/whoami") + self.render(request) + self.assertEqual(request.code, 401) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index f3ef977404..bce5b0cf4c 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import synapse.rest.admin -from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities @@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/capabilities" hs = self.setup_test_homeserver() self.store = hs.get_datastore() + self.config = hs.config return hs def test_check_auth_required(self): @@ -51,8 +52,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) for room_version in capabilities['m.room_versions']['available'].keys(): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) + self.assertEqual( - DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default'] + self.config.default_room_version.identifier, + capabilities['m.room_versions']['default'], ) def test_get_change_password_capabilities(self): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d4a1d4d50c..b35b215446 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -26,15 +26,10 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account_validity, register, sync +from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest -try: - from synapse.push.mailer import load_jinja2_templates -except ImportError: - load_jinja2_templates = None - class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -307,13 +302,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, account_validity.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -364,20 +359,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): def test_renewal_email(self): self.email_attempts = [] - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.clock.time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) + (user_id, tok) = self.create_user() # Move 6 days forward. This should trigger a renewal email to be sent. self.reactor.advance(datetime.timedelta(days=6).total_seconds()) @@ -402,8 +384,64 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): def test_manual_email_send(self): self.email_attempts = [] + (user_id, tok) = self.create_user() + request, channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + def test_deactivated_user(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + request_data = json.dumps({ + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + }) + request, channel = self.make_request( + "POST", + "account/deactivate", + request_data, + access_token=tok, + ) + self.render(request) + self.assertEqual(request.code, 200) + + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + self.assertEqual(len(self.email_attempts), 0) + + def create_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.clock.time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + return (user_id, tok) + + def test_manual_email_send_expired_account(self): user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() @@ -417,6 +455,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): ) ) + # Make the account expire. + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + # Ignore all emails sent by the automatic background task and only focus on the + # ones sent manually. + self.email_attempts = [] + + # Test that we're still able to manually trigger a mail to be sent. request, channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", @@ -436,6 +482,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.validity_period = 10 + self.max_delta = self.validity_period * 10. / 100. config = self.default_config() @@ -453,14 +500,18 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): def test_background_job(self): """ - Tests whether the account validity startup background job does the right thing, - which is sticking an expiration date to every account that doesn't already have - one. + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. """ - user_id = self.register_user("kermit", "user") + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta now_ms = self.hs.clock.time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - self.assertEqual(res, now_ms + self.validity_period) + + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 6dda66ecd3..f4c81ef77d 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -15,7 +15,6 @@ import os.path -from synapse.api.constants import EventTypes from synapse.storage import prepare_database from synapse.types import Requester, UserID @@ -23,12 +22,12 @@ from tests.unittest import HomeserverTestCase class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): - """Test the background update to clean forward extremities table. + """ + Test the background update to clean forward extremities table. """ def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() - self.event_creator = homeserver.get_event_creation_handler() self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -37,56 +36,6 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): info = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] - def create_and_send_event(self, soft_failed=False, prev_event_ids=None): - """Create and send an event. - - Args: - soft_failed (bool): Whether to create a soft failed event or not - prev_event_ids (list[str]|None): Explicitly set the prev events, - or if None just use the default - - Returns: - str: The new event's ID. - """ - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - - event, context = self.get_success( - self.event_creator.create_event( - self.requester, - { - "type": EventTypes.Message, - "room_id": self.room_id, - "sender": self.user.to_string(), - "content": {"body": "", "msgtype": "m.text"}, - }, - prev_events_and_hashes=prev_events_and_hashes, - ) - ) - - if soft_failed: - event.internal_metadata.soft_failed = True - - self.get_success( - self.event_creator.send_nonmember_event(self.requester, event, context) - ) - - return event.event_id - - def add_extremity(self, event_id): - """Add the given event as an extremity to the room. - """ - self.get_success( - self.store._simple_insert( - table="event_forward_extremities", - values={"room_id": self.room_id, "event_id": event_id}, - desc="test_add_extremity", - ) - ) - - self.store.get_latest_event_ids_in_room.invalidate((self.room_id,)) - def run_background_update(self): """Re run the background update to clean up the extremities. """ @@ -126,10 +75,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Create the room graph - event_id_1 = self.create_and_send_event() - event_id_2 = self.create_and_send_event(True, [event_id_1]) - event_id_3 = self.create_and_send_event(True, [event_id_2]) - event_id_4 = self.create_and_send_event(False, [event_id_3]) + event_id_1 = self.create_and_send_event(self.room_id, self.user) + event_id_2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_1] + ) + event_id_3 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_2] + ) + event_id_4 = self.create_and_send_event( + self.room_id, self.user, False, [event_id_3] + ) # Check the latest events are as expected latest_event_ids = self.get_success( @@ -149,12 +104,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with extremities of A and B """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_b = self.create_and_send_event(False, [event_id_sf1]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_b = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf1] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) @@ -180,13 +139,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with extremities of A and B """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_sf2 = self.create_and_send_event(True, [event_id_sf1]) - event_id_b = self.create_and_send_event(False, [event_id_sf2]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_sf2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf1] + ) + event_id_b = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf2] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) @@ -220,17 +185,28 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with them A, B and C marked as extremities. This should resolve to B and C being marked as extremity. """ + # Create the room graph - event_id_a = self.create_and_send_event() - event_id_b = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b]) - event_id_sf3 = self.create_and_send_event(True, [event_id_sf1]) - self.create_and_send_event(True, [event_id_sf2, event_id_sf3]) # SF4 - event_id_c = self.create_and_send_event(False, [event_id_sf3]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_b = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_sf2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a, event_id_b] + ) + event_id_sf3 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf1] + ) + self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf2, event_id_sf3] + ) # SF4 + event_id_c = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf3] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index aef4dfaf57..6396ccddb5 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -72,6 +72,75 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks + def test_get_devices_by_remote(self): + device_ids = ["device_id1", "device_id2"] + + # Add two device updates with a single stream_id + yield self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], + ) + + # Get all device updates ever meant for this remote + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "somehost", -1, limit=100, + ) + + # Check original device_ids are contained within these updates + self._check_devices_in_updates(device_ids, device_updates) + + @defer.inlineCallbacks + def test_get_devices_by_remote_limited(self): + # Test breaking the update limit in 1, 101, and 1 device_id segments + + # first add one device + device_ids1 = ["device_id0"] + yield self.store.add_device_change_to_streams( + "user_id", device_ids1, ["someotherhost"], + ) + + # then add 101 + device_ids2 = ["device_id" + str(i + 1) for i in range(101)] + yield self.store.add_device_change_to_streams( + "user_id", device_ids2, ["someotherhost"], + ) + + # then one more + device_ids3 = ["newdevice"] + yield self.store.add_device_change_to_streams( + "user_id", device_ids3, ["someotherhost"], + ) + + # + # now read them back. + # + + # first we should get a single update + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", -1, limit=100, + ) + self._check_devices_in_updates(device_ids1, device_updates) + + # Then we should get an empty list back as the 101 devices broke the limit + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", now_stream_id, limit=100, + ) + self.assertEqual(len(device_updates), 0) + + # The 101 devices should've been cleared, so we should now just get one device + # update + now_stream_id, device_updates = yield self.store.get_devices_by_remote( + "someotherhost", now_stream_id, limit=100, + ) + self._check_devices_in_updates(device_ids3, device_updates) + + def _check_devices_in_updates(self, expected_device_ids, device_updates): + """Check that an specific device ids exist in a list of device update EDUs""" + self.assertEqual(len(device_updates), len(expected_device_ids)) + + received_device_ids = {update["device_id"] for update in device_updates} + self.assertEqual(received_device_ids, set(expected_device_ids)) + + @defer.inlineCallbacks def test_update_device(self): yield self.store.store_device("user_id", "device_id", "display_name 1") diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py new file mode 100644 index 0000000000..19f9ccf5e0 --- /dev/null +++ b/tests/storage/test_event_metrics.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from prometheus_client.exposition import generate_latest + +from synapse.metrics import REGISTRY +from synapse.types import Requester, UserID + +from tests.unittest import HomeserverTestCase + + +class ExtremStatisticsTestCase(HomeserverTestCase): + def test_exposed_to_prometheus(self): + """ + Forward extremity counts are exposed via Prometheus. + """ + room_creator = self.hs.get_room_creation_handler() + + user = UserID("alice", "test") + requester = Requester(user, None, False, None, None) + + # Real events, forward extremities + events = [(3, 2), (6, 2), (4, 6)] + + for event_count, extrems in events: + info = self.get_success(room_creator.create_room(requester, {})) + room_id = info["room_id"] + + last_event = None + + # Make a real event chain + for i in range(event_count): + ev = self.create_and_send_event(room_id, user, False, last_event) + last_event = [ev] + + # Sprinkle in some extremities + for i in range(extrems): + ev = self.create_and_send_event(room_id, user, False, last_event) + + # Let it run for a while, then pull out the statistics from the + # Prometheus client registry + self.reactor.advance(60 * 60 * 1000) + self.pump(1) + + items = set( + filter( + lambda x: b"synapse_forward_extremities_" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + ) + + expected = set([ + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b'synapse_forward_extremities_count 3.0', + b'synapse_forward_extremities_sum 10.0', + ]) + + self.assertEqual(items, expected) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 6bfaa00fe9..e07ff01201 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -17,6 +17,8 @@ import signedjson.key from twisted.internet.defer import Deferred +from synapse.storage.keys import FetchKeyResult + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): store = self.hs.get_datastore() - d = store.store_server_verify_key("server1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("server1", "from_server", 0, KEY_2) + key_id_1 = "ed25519:key1" + key_id_2 = "ed25519:KEY_ID_2" + d = store.store_server_verify_keys( + "from_server", + 10, + [ + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) self.get_success(d) d = store.get_server_verify_keys( - [ - ("server1", "ed25519:key1"), - ("server1", "ed25519:key2"), - ("server1", "ed25519:key3"), - ] + [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")] ) res = self.get_success(d) self.assertEqual(len(res.keys()), 3) - self.assertEqual(res[("server1", "ed25519:key1")].version, "key1") - self.assertEqual(res[("server1", "ed25519:key2")].version, "key2") + res1 = res[("server1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.verify_key.version, "key1") + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("server1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + # version comes from the ID it was stored with + self.assertEqual(res2.verify_key.version, "KEY_ID_2") + self.assertEqual(res2.valid_until_ts, 200) # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2) + d = store.store_server_verify_keys( + "from_server", + 0, + [ + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit res = store.get_server_verify_keys([("srv1", key_id_1)]) if isinstance(res, Deferred): res = self.successResultOf(res) self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) - d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2) + d = store.store_server_verify_keys( + "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], new_key_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, new_key_2) + self.assertEqual(res2.valid_until_ts, 300) diff --git a/tests/unittest.py b/tests/unittest.py index ea2b7f272e..1f8b1ad64c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -25,11 +25,12 @@ from canonicaljson import json from twisted.internet.defer import Deferred from twisted.trial import unittest +from synapse.api.constants import EventTypes from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from synapse.util.logcontext import LoggingContext from tests.server import get_clock, make_request, render, setup_test_homeserver @@ -435,3 +436,73 @@ class HomeserverTestCase(TestCase): access_token = channel.json_body["access_token"] return access_token + + def create_and_send_event( + self, room_id, user, soft_failed=False, prev_event_ids=None + ): + """ + Create and send an event. + + Args: + soft_failed (bool): Whether to create a soft failed event or not + prev_event_ids (list[str]|None): Explicitly set the prev events, + or if None just use the default + + Returns: + str: The new event's ID. + """ + event_creator = self.hs.get_event_creation_handler() + secrets = self.hs.get_secrets() + requester = Requester(user, None, False, None, None) + + prev_events_and_hashes = None + if prev_event_ids: + prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] + + event, context = self.get_success( + event_creator.create_event( + requester, + { + "type": EventTypes.Message, + "room_id": room_id, + "sender": user.to_string(), + "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + ) + + if soft_failed: + event.internal_metadata.soft_failed = True + + self.get_success( + event_creator.send_nonmember_event(requester, event, context) + ) + + return event.event_id + + def add_extremity(self, room_id, event_id): + """ + Add the given event as an extremity to the room. + """ + self.get_success( + self.hs.get_datastore()._simple_insert( + table="event_forward_extremities", + values={"room_id": room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + + def attempt_wrong_password_login(self, username, password): + """Attempts to login as the user with the given password, asserting + that the attempt *fails*. + """ + body = {"type": "m.login.password", "user": username, "password": password} + + request, channel = self.make_request( + "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + ) + self.render(request) + self.assertEqual(channel.code, 403, channel.result) diff --git a/tests/utils.py b/tests/utils.py index 200c1ceabe..f8c7ad2604 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -31,6 +31,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer from synapse.server import HomeServer @@ -131,7 +132,6 @@ def default_config(name, parse=False): "password_providers": [], "worker_replication_url": "", "worker_app": None, - "email_enable_notifs": False, "block_non_admin_invites": False, "federation_domain_whitelist": None, "filter_timeline_limit": 5000, @@ -174,7 +174,7 @@ def default_config(name, parse=False): "use_frozen_dicts": False, # We need a sane default_room_version, otherwise attempts to create # rooms will fail. - "default_room_version": "1", + "default_room_version": DEFAULT_ROOM_VERSION, # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, |