summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/crypto/test_keyring.py229
-rw-r--r--tests/handlers/test_device.py3
-rw-r--r--tests/storage/test_client_ips.py5
-rw-r--r--tests/test_dns.py26
-rw-r--r--tests/utils.py1
5 files changed, 254 insertions, 10 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
new file mode 100644
index 0000000000..570312da84
--- /dev/null
+++ b/tests/crypto/test_keyring.py
@@ -0,0 +1,229 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+import signedjson.key
+import signedjson.sign
+from mock import Mock
+from synapse.api.errors import SynapseError
+from synapse.crypto import keyring
+from synapse.util import async, logcontext
+from synapse.util.logcontext import LoggingContext
+from tests import unittest, utils
+from twisted.internet import defer
+
+
+class MockPerspectiveServer(object):
+    def __init__(self):
+        self.server_name = "mock_server"
+        self.key = signedjson.key.generate_signing_key(0)
+
+    def get_verify_keys(self):
+        vk = signedjson.key.get_verify_key(self.key)
+        return {
+            "%s:%s" % (vk.alg, vk.version): vk,
+        }
+
+    def get_signed_key(self, server_name, verify_key):
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+        res = {
+            "server_name": server_name,
+            "old_verify_keys": {},
+            "valid_until_ts": time.time() * 1000 + 3600,
+            "verify_keys": {
+                key_id: {
+                    "key": signedjson.key.encode_verify_key_base64(verify_key)
+                }
+            }
+        }
+        signedjson.sign.sign_json(res, self.server_name, self.key)
+        return res
+
+
+class KeyringTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.mock_perspective_server = MockPerspectiveServer()
+        self.http_client = Mock()
+        self.hs = yield utils.setup_test_homeserver(
+            handlers=None,
+            http_client=self.http_client,
+        )
+        self.hs.config.perspectives = {
+            self.mock_perspective_server.server_name:
+                self.mock_perspective_server.get_verify_keys()
+        }
+
+    def check_context(self, _, expected):
+        self.assertEquals(
+            getattr(LoggingContext.current_context(), "test_key", None),
+            expected
+        )
+
+    @defer.inlineCallbacks
+    def test_wait_for_previous_lookups(self):
+        sentinel_context = LoggingContext.current_context()
+
+        kr = keyring.Keyring(self.hs)
+
+        lookup_1_deferred = defer.Deferred()
+        lookup_2_deferred = defer.Deferred()
+
+        with LoggingContext("one") as context_one:
+            context_one.test_key = "one"
+
+            wait_1_deferred = kr.wait_for_previous_lookups(
+                ["server1"],
+                {"server1": lookup_1_deferred},
+            )
+
+            # there were no previous lookups, so the deferred should be ready
+            self.assertTrue(wait_1_deferred.called)
+            # ... so we should have preserved the LoggingContext.
+            self.assertIs(LoggingContext.current_context(), context_one)
+            wait_1_deferred.addBoth(self.check_context, "one")
+
+        with LoggingContext("two") as context_two:
+            context_two.test_key = "two"
+
+            # set off another wait. It should block because the first lookup
+            # hasn't yet completed.
+            wait_2_deferred = kr.wait_for_previous_lookups(
+                ["server1"],
+                {"server1": lookup_2_deferred},
+            )
+            self.assertFalse(wait_2_deferred.called)
+            # ... so we should have reset the LoggingContext.
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+            wait_2_deferred.addBoth(self.check_context, "two")
+
+            # let the first lookup complete (in the sentinel context)
+            lookup_1_deferred.callback(None)
+
+            # now the second wait should complete and restore our
+            # loggingcontext.
+            yield wait_2_deferred
+
+    @defer.inlineCallbacks
+    def test_verify_json_objects_for_server_awaits_previous_requests(self):
+        key1 = signedjson.key.generate_signing_key(1)
+
+        kr = keyring.Keyring(self.hs)
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server10", key1)
+
+        persp_resp = {
+            "server_keys": [
+                self.mock_perspective_server.get_signed_key(
+                    "server10",
+                    signedjson.key.get_verify_key(key1)
+                ),
+            ]
+        }
+        persp_deferred = defer.Deferred()
+
+        @defer.inlineCallbacks
+        def get_perspectives(**kwargs):
+            self.assertEquals(
+                LoggingContext.current_context().test_key, "11",
+            )
+            with logcontext.PreserveLoggingContext():
+                yield persp_deferred
+            defer.returnValue(persp_resp)
+        self.http_client.post_json.side_effect = get_perspectives
+
+        with LoggingContext("11") as context_11:
+            context_11.test_key = "11"
+
+            # start off a first set of lookups
+            res_deferreds = kr.verify_json_objects_for_server(
+                [("server10", json1),
+                 ("server11", {})
+                 ]
+            )
+
+            # the unsigned json should be rejected pretty quickly
+            self.assertTrue(res_deferreds[1].called)
+            try:
+                yield res_deferreds[1]
+                self.assertFalse("unsigned json didn't cause a failure")
+            except SynapseError:
+                pass
+
+            self.assertFalse(res_deferreds[0].called)
+            res_deferreds[0].addBoth(self.check_context, None)
+
+            # wait a tick for it to send the request to the perspectives server
+            # (it first tries the datastore)
+            yield async.sleep(0.005)
+            self.http_client.post_json.assert_called_once()
+
+            self.assertIs(LoggingContext.current_context(), context_11)
+
+            context_12 = LoggingContext("12")
+            context_12.test_key = "12"
+            with logcontext.PreserveLoggingContext(context_12):
+                # a second request for a server with outstanding requests
+                # should block rather than start a second call
+                self.http_client.post_json.reset_mock()
+                self.http_client.post_json.return_value = defer.Deferred()
+
+                res_deferreds_2 = kr.verify_json_objects_for_server(
+                    [("server10", json1)],
+                )
+                yield async.sleep(0.005)
+                self.http_client.post_json.assert_not_called()
+                res_deferreds_2[0].addBoth(self.check_context, None)
+
+            # complete the first request
+            with logcontext.PreserveLoggingContext():
+                persp_deferred.callback(persp_resp)
+            self.assertIs(LoggingContext.current_context(), context_11)
+
+            with logcontext.PreserveLoggingContext():
+                yield res_deferreds[0]
+                yield res_deferreds_2[0]
+
+    @defer.inlineCallbacks
+    def test_verify_json_for_server(self):
+        kr = keyring.Keyring(self.hs)
+
+        key1 = signedjson.key.generate_signing_key(1)
+        yield self.hs.datastore.store_server_verify_key(
+            "server9", "", time.time() * 1000,
+            signedjson.key.get_verify_key(key1),
+        )
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server9", key1)
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext("one") as context_one:
+            context_one.test_key = "one"
+
+            defer = kr.verify_json_for_server("server9", {})
+            try:
+                yield defer
+                self.fail("should fail on unsigned json")
+            except SynapseError:
+                pass
+            self.assertIs(LoggingContext.current_context(), context_one)
+
+            defer = kr.verify_json_for_server("server9", json1)
+            self.assertFalse(defer.called)
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+            yield defer
+
+            self.assertIs(LoggingContext.current_context(), context_one)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 2eaaa8253c..778ff2f6e9 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,6 @@ import synapse.api.errors
 import synapse.handlers.device
 
 import synapse.storage
-from synapse import types
 from tests import unittest, utils
 
 user1 = "@boris:aaa"
@@ -179,6 +178,6 @@ class DeviceTestCase(unittest.TestCase):
 
         if ip is not None:
             yield self.store.insert_client_ip(
-                types.UserID.from_string(user_id),
+                user_id,
                 access_token, ip, "user_agent", device_id)
             self.clock.advance_time(1000)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 03df697575..bd6fda6cb1 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -15,9 +15,6 @@
 
 from twisted.internet import defer
 
-import synapse.server
-import synapse.storage
-import synapse.types
 import tests.unittest
 import tests.utils
 
@@ -39,7 +36,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
         self.clock.now = 12345678
         user_id = "@user:id"
         yield self.store.insert_client_ip(
-            synapse.types.UserID.from_string(user_id),
+            user_id,
             "access_token", "ip", "user_agent", "device_id",
         )
 
diff --git a/tests/test_dns.py b/tests/test_dns.py
index c394c57ee7..d08b0f4333 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
 from tests.utils import MockClock
 
 
+@unittest.DEBUG
 class DnsTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve(self):
         dns_client_mock = Mock()
 
-        service_name = "test_service.examle.com"
+        service_name = "test_service.example.com"
         host_name = "example.com"
         ip_address = "127.0.0.1"
+        ip6_address = "::1"
 
         answer_srv = dns.RRHeader(
             type=dns.SRV,
@@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
             )
         )
 
-        dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
-        dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
+        answer_aaaa = dns.RRHeader(
+            type=dns.AAAA,
+            payload=dns.Record_AAAA(
+                address=ip6_address,
+            )
+        )
+
+        dns_client_mock.lookupService.return_value = defer.succeed(
+            ([answer_srv], None, None),
+        )
+        dns_client_mock.lookupAddress.return_value = defer.succeed(
+            ([answer_a], None, None),
+        )
+        dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
+            ([answer_aaaa], None, None),
+        )
 
         cache = {}
 
@@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
         dns_client_mock.lookupAddress.assert_called_once_with(host_name)
+        dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
 
-        self.assertEquals(len(servers), 1)
+        self.assertEquals(len(servers), 2)
         self.assertEquals(servers, cache[service_name])
         self.assertEquals(servers[0].host, ip_address)
+        self.assertEquals(servers[1].host, ip6_address)
 
     @defer.inlineCallbacks
     def test_from_cache_expired_and_dns_fail(self):
diff --git a/tests/utils.py b/tests/utils.py
index 4f7e32b3ab..3c81a3e16d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.worker_replication_url = ""
         config.worker_app = None
         config.email_enable_notifs = False
+        config.block_non_admin_invites = False
 
     config.use_frozen_dicts = True
     config.database_config = {"name": "sqlite3"}