summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py133
-rw-r--r--tests/rest/client/test_identity.py65
-rw-r--r--tests/rest/client/v2_alpha/test_register.py51
-rw-r--r--tests/storage/test_keys.py83
-rw-r--r--tests/test_state.py4
-rw-r--r--tests/unittest.py2
-rw-r--r--tests/util/test_linearizer.py2
7 files changed, 314 insertions, 26 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c30a1a69e7..f5bd7a1aa1 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ import time
 
 from mock import Mock
 
+import canonicaljson
 import signedjson.key
 import signedjson.sign
 
@@ -23,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
+from synapse.crypto.keyring import KeyLookupError
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
 
@@ -48,6 +50,9 @@ class MockPerspectiveServer(object):
                 key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
             },
         }
+        return self.get_signed_response(res)
+
+    def get_signed_response(self, res):
         signedjson.sign.sign_json(res, self.server_name, self.key)
         return res
 
@@ -202,6 +207,132 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.assertFalse(d.called)
         self.get_success(d)
 
+    def test_get_keys_from_server(self):
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        SERVER_NAME = "server2"
+        kr = keyring.Keyring(self.hs)
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 1000
+
+        # valid response
+        response = {
+            "server_name": SERVER_NAME,
+            "old_verify_keys": {},
+            "valid_until_ts": VALID_UNTIL_TS,
+            "verify_keys": {
+                testverifykey_id: {
+                    "key": signedjson.key.encode_verify_key_base64(testverifykey)
+                }
+            },
+        }
+        signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+
+        def get_json(destination, path, **kwargs):
+            self.assertEqual(destination, SERVER_NAME)
+            self.assertEqual(path, "/_matrix/key/v2/server/key1")
+            return response
+
+        self.http_client.get_json.side_effect = get_json
+
+        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+        keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+        k = keys[SERVER_NAME][testverifykey_id]
+        self.assertEqual(k, testverifykey)
+        self.assertEqual(k.alg, "ed25519")
+        self.assertEqual(k.version, "ver1")
+
+        # check that the perspectives store is correctly updated
+        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+        key_json = self.get_success(
+            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+        )
+        res = key_json[lookup_triplet]
+        self.assertEqual(len(res), 1)
+        res = res[0]
+        self.assertEqual(res["key_id"], testverifykey_id)
+        self.assertEqual(res["from_server"], SERVER_NAME)
+        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+        # we expect it to be encoded as canonical json *before* it hits the db
+        self.assertEqual(
+            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+        )
+
+        # change the server name: it should cause a rejection
+        response["server_name"] = "OTHER_SERVER"
+        self.get_failure(
+            kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+        )
+
+    def test_get_keys_from_perspectives(self):
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        SERVER_NAME = "server2"
+        kr = keyring.Keyring(self.hs)
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 200 * 1000
+
+        # valid response
+        response = {
+            "server_name": SERVER_NAME,
+            "old_verify_keys": {},
+            "valid_until_ts": VALID_UNTIL_TS,
+            "verify_keys": {
+                testverifykey_id: {
+                    "key": signedjson.key.encode_verify_key_base64(testverifykey)
+                }
+            },
+        }
+
+        persp_resp = {
+            "server_keys": [self.mock_perspective_server.get_signed_response(response)]
+        }
+
+        def post_json(destination, path, data, **kwargs):
+            self.assertEqual(destination, self.mock_perspective_server.server_name)
+            self.assertEqual(path, "/_matrix/key/v2/query")
+
+            # check that the request is for the expected key
+            q = data["server_keys"]
+            self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+            return persp_resp
+
+        self.http_client.post_json.side_effect = post_json
+
+        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+        keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+        self.assertIn(SERVER_NAME, keys)
+        k = keys[SERVER_NAME][testverifykey_id]
+        self.assertEqual(k, testverifykey)
+        self.assertEqual(k.alg, "ed25519")
+        self.assertEqual(k.version, "ver1")
+
+        # check that the perspectives store is correctly updated
+        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+        key_json = self.get_success(
+            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+        )
+        res = key_json[lookup_triplet]
+        self.assertEqual(len(res), 1)
+        res = res[0]
+        self.assertEqual(res["key_id"], testverifykey_id)
+        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+        self.assertEqual(
+            bytes(res["key_json"]),
+            canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+        )
+
 
 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
new file mode 100644
index 0000000000..ca63b2e6ed
--- /dev/null
+++ b/tests/rest/client/test_identity.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.rest.client.v1 import admin, login, room
+
+from tests import unittest
+
+
+class IdentityTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+
+        config = self.default_config()
+        config.enable_3pid_lookup = False
+        self.hs = self.setup_test_homeserver(config=config)
+
+        return self.hs
+
+    def test_3pid_lookup_disabled(self):
+        self.hs.config.enable_3pid_lookup = False
+
+        self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        request, channel = self.make_request(
+            b"POST", "/createRoom", b"{}", access_token=tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        room_id = channel.json_body["room_id"]
+
+        params = {
+            "id_server": "testis",
+            "medium": "email",
+            "address": "test@example.com",
+        }
+        request_data = json.dumps(params)
+        request_url = (
+            "/rooms/%s/invite" % (room_id)
+        ).encode('ascii')
+        request, channel = self.make_request(
+            b"POST", request_url, request_data, access_token=tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index a45e6e5e1f..d3611ed21f 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,15 +1,18 @@
+import datetime
 import json
 
 from synapse.api.constants import LoginType
+from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
-from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v2_alpha import register, sync
 
 from tests import unittest
 
 
 class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
-    servlets = [register_servlets]
+    servlets = [register.register_servlets]
 
     def make_homeserver(self, reactor, clock):
 
@@ -181,3 +184,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.render(request)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
+
+
+class AccountValidityTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        register.register_servlets,
+        admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config.enable_registration = True
+        config.account_validity.enabled = True
+        config.account_validity.period = 604800000  # Time in ms for 1 week
+        self.hs = self.setup_test_homeserver(config=config)
+
+        return self.hs
+
+    def test_validity_period(self):
+        self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        # The specific endpoint doesn't matter, all we need is an authenticated
+        # endpoint.
+        request, channel = self.make_request(
+            b"GET", "/sync", access_token=tok,
+        )
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+        request, channel = self.make_request(
+            b"GET", "/sync", access_token=tok,
+        )
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEquals(
+            channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
+        )
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 0d2dc9f325..6bfaa00fe9 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -15,34 +15,77 @@
 
 import signedjson.key
 
-from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 import tests.unittest
-import tests.utils
 
+KEY_1 = signedjson.key.decode_verify_key_base64(
+    "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+)
+KEY_2 = signedjson.key.decode_verify_key_base64(
+    "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+)
 
-class KeyStoreTestCase(tests.unittest.TestCase):
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-        self.store = hs.get_datastore()
-
-    @defer.inlineCallbacks
+class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
     def test_get_server_verify_keys(self):
-        key1 = signedjson.key.decode_verify_key_base64(
-            "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
-        )
-        key2 = signedjson.key.decode_verify_key_base64(
-            "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+        store = self.hs.get_datastore()
+
+        d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
+        self.get_success(d)
+        d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+        self.get_success(d)
+
+        d = store.get_server_verify_keys(
+            [
+                ("server1", "ed25519:key1"),
+                ("server1", "ed25519:key2"),
+                ("server1", "ed25519:key3"),
+            ]
         )
-        yield self.store.store_server_verify_key("server1", "from_server", 0, key1)
-        yield self.store.store_server_verify_key("server1", "from_server", 0, key2)
+        res = self.get_success(d)
+
+        self.assertEqual(len(res.keys()), 3)
+        self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
+        self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+
+        # non-existent result gives None
+        self.assertIsNone(res[("server1", "ed25519:key3")])
+
+    def test_cache(self):
+        """Check that updates correctly invalidate the cache."""
+
+        store = self.hs.get_datastore()
+
+        key_id_1 = "ed25519:key1"
+        key_id_2 = "ed25519:key2"
+
+        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
+        self.get_success(d)
+        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+        self.get_success(d)
+
+        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+        res = self.get_success(d)
+        self.assertEqual(len(res.keys()), 2)
+        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+        # we should be able to look up the same thing again without a db hit
+        res = store.get_server_verify_keys([("srv1", key_id_1)])
+        if isinstance(res, Deferred):
+            res = self.successResultOf(res)
+        self.assertEqual(len(res.keys()), 1)
+        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
 
-        res = yield self.store.get_server_verify_keys(
-            "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]
+        new_key_2 = signedjson.key.get_verify_key(
+            signedjson.key.generate_signing_key("key2")
         )
+        d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+        self.get_success(d)
 
+        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+        res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res["ed25519:key1"].version, "key1")
-        self.assertEqual(res["ed25519:key2"].version, "key2")
+        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_2)], new_key_2)
diff --git a/tests/test_state.py b/tests/test_state.py
index 03e4810c2e..5bcc6aaa18 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -25,7 +25,7 @@ from synapse.state import StateHandler, StateResolutionHandler
 
 from tests import unittest
 
-from .utils import MockClock
+from .utils import MockClock, default_config
 
 _next_event_id = 1000
 
@@ -160,6 +160,7 @@ class StateTestCase(unittest.TestCase):
         self.store = StateGroupStore()
         hs = Mock(
             spec_set=[
+                "config",
                 "get_datastore",
                 "get_auth",
                 "get_state_handler",
@@ -167,6 +168,7 @@ class StateTestCase(unittest.TestCase):
                 "get_state_resolution_handler",
             ]
         )
+        hs.config = default_config("tesths")
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
diff --git a/tests/unittest.py b/tests/unittest.py
index 27403de908..8c65736a51 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -410,7 +410,7 @@ class HomeserverTestCase(TestCase):
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
         )
         self.render(request)
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.result)
 
         access_token = channel.json_body["access_token"]
         return access_token
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 61a55b461b..ec7ba9719c 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd.
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.