diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/crypto/test_keyring.py | 229 | ||||
-rw-r--r-- | tests/handlers/test_device.py | 3 | ||||
-rw-r--r-- | tests/storage/test_client_ips.py | 5 | ||||
-rw-r--r-- | tests/test_dns.py | 26 | ||||
-rw-r--r-- | tests/utils.py | 1 |
5 files changed, 254 insertions, 10 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py new file mode 100644 index 0000000000..570312da84 --- /dev/null +++ b/tests/crypto/test_keyring.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +import signedjson.key +import signedjson.sign +from mock import Mock +from synapse.api.errors import SynapseError +from synapse.crypto import keyring +from synapse.util import async, logcontext +from synapse.util.logcontext import LoggingContext +from tests import unittest, utils +from twisted.internet import defer + + +class MockPerspectiveServer(object): + def __init__(self): + self.server_name = "mock_server" + self.key = signedjson.key.generate_signing_key(0) + + def get_verify_keys(self): + vk = signedjson.key.get_verify_key(self.key) + return { + "%s:%s" % (vk.alg, vk.version): vk, + } + + def get_signed_key(self, server_name, verify_key): + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + res = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": time.time() * 1000 + 3600, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) + } + } + } + signedjson.sign.sign_json(res, self.server_name, self.key) + return res + + +class KeyringTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.mock_perspective_server = MockPerspectiveServer() + self.http_client = Mock() + self.hs = yield utils.setup_test_homeserver( + handlers=None, + http_client=self.http_client, + ) + self.hs.config.perspectives = { + self.mock_perspective_server.server_name: + self.mock_perspective_server.get_verify_keys() + } + + def check_context(self, _, expected): + self.assertEquals( + getattr(LoggingContext.current_context(), "test_key", None), + expected + ) + + @defer.inlineCallbacks + def test_wait_for_previous_lookups(self): + sentinel_context = LoggingContext.current_context() + + kr = keyring.Keyring(self.hs) + + lookup_1_deferred = defer.Deferred() + lookup_2_deferred = defer.Deferred() + + with LoggingContext("one") as context_one: + context_one.test_key = "one" + + wait_1_deferred = kr.wait_for_previous_lookups( + ["server1"], + {"server1": lookup_1_deferred}, + ) + + # there were no previous lookups, so the deferred should be ready + self.assertTrue(wait_1_deferred.called) + # ... so we should have preserved the LoggingContext. + self.assertIs(LoggingContext.current_context(), context_one) + wait_1_deferred.addBoth(self.check_context, "one") + + with LoggingContext("two") as context_two: + context_two.test_key = "two" + + # set off another wait. It should block because the first lookup + # hasn't yet completed. + wait_2_deferred = kr.wait_for_previous_lookups( + ["server1"], + {"server1": lookup_2_deferred}, + ) + self.assertFalse(wait_2_deferred.called) + # ... so we should have reset the LoggingContext. + self.assertIs(LoggingContext.current_context(), sentinel_context) + wait_2_deferred.addBoth(self.check_context, "two") + + # let the first lookup complete (in the sentinel context) + lookup_1_deferred.callback(None) + + # now the second wait should complete and restore our + # loggingcontext. + yield wait_2_deferred + + @defer.inlineCallbacks + def test_verify_json_objects_for_server_awaits_previous_requests(self): + key1 = signedjson.key.generate_signing_key(1) + + kr = keyring.Keyring(self.hs) + json1 = {} + signedjson.sign.sign_json(json1, "server10", key1) + + persp_resp = { + "server_keys": [ + self.mock_perspective_server.get_signed_key( + "server10", + signedjson.key.get_verify_key(key1) + ), + ] + } + persp_deferred = defer.Deferred() + + @defer.inlineCallbacks + def get_perspectives(**kwargs): + self.assertEquals( + LoggingContext.current_context().test_key, "11", + ) + with logcontext.PreserveLoggingContext(): + yield persp_deferred + defer.returnValue(persp_resp) + self.http_client.post_json.side_effect = get_perspectives + + with LoggingContext("11") as context_11: + context_11.test_key = "11" + + # start off a first set of lookups + res_deferreds = kr.verify_json_objects_for_server( + [("server10", json1), + ("server11", {}) + ] + ) + + # the unsigned json should be rejected pretty quickly + self.assertTrue(res_deferreds[1].called) + try: + yield res_deferreds[1] + self.assertFalse("unsigned json didn't cause a failure") + except SynapseError: + pass + + self.assertFalse(res_deferreds[0].called) + res_deferreds[0].addBoth(self.check_context, None) + + # wait a tick for it to send the request to the perspectives server + # (it first tries the datastore) + yield async.sleep(0.005) + self.http_client.post_json.assert_called_once() + + self.assertIs(LoggingContext.current_context(), context_11) + + context_12 = LoggingContext("12") + context_12.test_key = "12" + with logcontext.PreserveLoggingContext(context_12): + # a second request for a server with outstanding requests + # should block rather than start a second call + self.http_client.post_json.reset_mock() + self.http_client.post_json.return_value = defer.Deferred() + + res_deferreds_2 = kr.verify_json_objects_for_server( + [("server10", json1)], + ) + yield async.sleep(0.005) + self.http_client.post_json.assert_not_called() + res_deferreds_2[0].addBoth(self.check_context, None) + + # complete the first request + with logcontext.PreserveLoggingContext(): + persp_deferred.callback(persp_resp) + self.assertIs(LoggingContext.current_context(), context_11) + + with logcontext.PreserveLoggingContext(): + yield res_deferreds[0] + yield res_deferreds_2[0] + + @defer.inlineCallbacks + def test_verify_json_for_server(self): + kr = keyring.Keyring(self.hs) + + key1 = signedjson.key.generate_signing_key(1) + yield self.hs.datastore.store_server_verify_key( + "server9", "", time.time() * 1000, + signedjson.key.get_verify_key(key1), + ) + json1 = {} + signedjson.sign.sign_json(json1, "server9", key1) + + sentinel_context = LoggingContext.current_context() + + with LoggingContext("one") as context_one: + context_one.test_key = "one" + + defer = kr.verify_json_for_server("server9", {}) + try: + yield defer + self.fail("should fail on unsigned json") + except SynapseError: + pass + self.assertIs(LoggingContext.current_context(), context_one) + + defer = kr.verify_json_for_server("server9", json1) + self.assertFalse(defer.called) + self.assertIs(LoggingContext.current_context(), sentinel_context) + yield defer + + self.assertIs(LoggingContext.current_context(), context_one) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 2eaaa8253c..778ff2f6e9 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,6 @@ import synapse.api.errors import synapse.handlers.device import synapse.storage -from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -179,6 +178,6 @@ class DeviceTestCase(unittest.TestCase): if ip is not None: yield self.store.insert_client_ip( - types.UserID.from_string(user_id), + user_id, access_token, ip, "user_agent", device_id) self.clock.advance_time(1000) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 03df697575..bd6fda6cb1 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -15,9 +15,6 @@ from twisted.internet import defer -import synapse.server -import synapse.storage -import synapse.types import tests.unittest import tests.utils @@ -39,7 +36,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.clock.now = 12345678 user_id = "@user:id" yield self.store.insert_client_ip( - synapse.types.UserID.from_string(user_id), + user_id, "access_token", "ip", "user_agent", "device_id", ) diff --git a/tests/test_dns.py b/tests/test_dns.py index c394c57ee7..d08b0f4333 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service from tests.utils import MockClock +@unittest.DEBUG class DnsTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve(self): dns_client_mock = Mock() - service_name = "test_service.examle.com" + service_name = "test_service.example.com" host_name = "example.com" ip_address = "127.0.0.1" + ip6_address = "::1" answer_srv = dns.RRHeader( type=dns.SRV, @@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase): ) ) - dns_client_mock.lookupService.return_value = ([answer_srv], None, None) - dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) + answer_aaaa = dns.RRHeader( + type=dns.AAAA, + payload=dns.Record_AAAA( + address=ip6_address, + ) + ) + + dns_client_mock.lookupService.return_value = defer.succeed( + ([answer_srv], None, None), + ) + dns_client_mock.lookupAddress.return_value = defer.succeed( + ([answer_a], None, None), + ) + dns_client_mock.lookupIPV6Address.return_value = defer.succeed( + ([answer_aaaa], None, None), + ) cache = {} @@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupAddress.assert_called_once_with(host_name) + dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name) - self.assertEquals(len(servers), 1) + self.assertEquals(len(servers), 2) self.assertEquals(servers, cache[service_name]) self.assertEquals(servers[0].host, ip_address) + self.assertEquals(servers[1].host, ip6_address) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): diff --git a/tests/utils.py b/tests/utils.py index 4f7e32b3ab..3c81a3e16d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_replication_url = "" config.worker_app = None config.email_enable_notifs = False + config.block_non_admin_invites = False config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} |