summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/slave/storage/test_events.py8
-rw-r--r--tests/rest/client/v1/test_admin.py305
-rw-r--r--tests/test_federation.py3
-rw-r--r--tests/test_state.py47
-rw-r--r--tests/util/test_limiter.py70
-rw-r--r--tests/util/test_linearizer.py78
-rw-r--r--tests/utils.py3
7 files changed, 427 insertions, 87 deletions
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index cea01d93eb..f5b47f5ec0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             state_ids = {
                 key: e.event_id for key, e in state.items()
             }
-            context = EventContext()
-            context.current_state_ids = state_ids
-            context.prev_state_ids = state_ids
+            context = EventContext.with_state(
+                state_group=None,
+                current_state_ids=state_ids,
+                prev_state_ids=state_ids
+            )
         else:
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
new file mode 100644
index 0000000000..8c90145601
--- /dev/null
+++ b/tests/rest/client/v1/test_admin.py
@@ -0,0 +1,305 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import hmac
+import json
+
+from mock import Mock
+
+from synapse.http.server import JsonResource
+from synapse.rest.client.v1.admin import register_servlets
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import (
+    ThreadedMemoryReactorClock,
+    make_request,
+    render,
+    setup_test_homeserver,
+)
+
+
+class UserRegisterTestCase(unittest.TestCase):
+    def setUp(self):
+
+        self.clock = ThreadedMemoryReactorClock()
+        self.hs_clock = Clock(self.clock)
+        self.url = "/_matrix/client/r0/admin/register"
+
+        self.registration_handler = Mock()
+        self.identity_handler = Mock()
+        self.login_handler = Mock()
+        self.device_handler = Mock()
+        self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+        self.datastore = Mock(return_value=Mock())
+        self.datastore.get_current_state_deltas = Mock(return_value=[])
+
+        self.secrets = Mock()
+
+        self.hs = setup_test_homeserver(
+            http_client=None, clock=self.hs_clock, reactor=self.clock
+        )
+
+        self.hs.config.registration_shared_secret = u"shared"
+
+        self.hs.get_media_repository = Mock()
+        self.hs.get_deactivate_account_handler = Mock()
+
+        self.resource = JsonResource(self.hs)
+        register_servlets(self.hs, self.resource)
+
+    def test_disabled(self):
+        """
+        If there is no shared secret, registration through this method will be
+        prevented.
+        """
+        self.hs.config.registration_shared_secret = None
+
+        request, channel = make_request("POST", self.url, b'{}')
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            'Shared secret registration is not enabled', channel.json_body["error"]
+        )
+
+    def test_get_nonce(self):
+        """
+        Calling GET on the endpoint will return a randomised nonce, using the
+        homeserver's secrets provider.
+        """
+        secrets = Mock()
+        secrets.token_hex = Mock(return_value="abcd")
+
+        self.hs.get_secrets = Mock(return_value=secrets)
+
+        request, channel = make_request("GET", self.url)
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(channel.json_body, {"nonce": "abcd"})
+
+    def test_expired_nonce(self):
+        """
+        Calling GET on the endpoint will return a randomised nonce, which will
+        only last for SALT_TIMEOUT (60s).
+        """
+        request, channel = make_request("GET", self.url)
+        render(request, self.resource, self.clock)
+        nonce = channel.json_body["nonce"]
+
+        # 59 seconds
+        self.clock.advance(59)
+
+        body = json.dumps({"nonce": nonce})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('username must be specified', channel.json_body["error"])
+
+        # 61 seconds
+        self.clock.advance(2)
+
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('unrecognised nonce', channel.json_body["error"])
+
+    def test_register_incorrect_nonce(self):
+        """
+        Only the provided nonce can be used, as it's checked in the MAC.
+        """
+        request, channel = make_request("GET", self.url)
+        render(request, self.resource, self.clock)
+        nonce = channel.json_body["nonce"]
+
+        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+        want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
+        want_mac = want_mac.hexdigest()
+
+        body = json.dumps(
+            {
+                "nonce": nonce,
+                "username": "bob",
+                "password": "abc123",
+                "admin": True,
+                "mac": want_mac,
+            }
+        ).encode('utf8')
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("HMAC incorrect", channel.json_body["error"])
+
+    def test_register_correct_nonce(self):
+        """
+        When the correct nonce is provided, and the right key is provided, the
+        user is registered.
+        """
+        request, channel = make_request("GET", self.url)
+        render(request, self.resource, self.clock)
+        nonce = channel.json_body["nonce"]
+
+        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+        want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+        want_mac = want_mac.hexdigest()
+
+        body = json.dumps(
+            {
+                "nonce": nonce,
+                "username": "bob",
+                "password": "abc123",
+                "admin": True,
+                "mac": want_mac,
+            }
+        ).encode('utf8')
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+    def test_nonce_reuse(self):
+        """
+        A valid unrecognised nonce.
+        """
+        request, channel = make_request("GET", self.url)
+        render(request, self.resource, self.clock)
+        nonce = channel.json_body["nonce"]
+
+        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+        want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+        want_mac = want_mac.hexdigest()
+
+        body = json.dumps(
+            {
+                "nonce": nonce,
+                "username": "bob",
+                "password": "abc123",
+                "admin": True,
+                "mac": want_mac,
+            }
+        ).encode('utf8')
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+        # Now, try and reuse it
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('unrecognised nonce', channel.json_body["error"])
+
+    def test_missing_parts(self):
+        """
+        Synapse will complain if you don't give nonce, username, password, and
+        mac.  Admin is optional.  Additional checks are done for length and
+        type.
+        """
+        def nonce():
+            request, channel = make_request("GET", self.url)
+            render(request, self.resource, self.clock)
+            return channel.json_body["nonce"]
+
+        #
+        # Nonce check
+        #
+
+        # Must be present
+        body = json.dumps({})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('nonce must be specified', channel.json_body["error"])
+
+        #
+        # Username checks
+        #
+
+        # Must be present
+        body = json.dumps({"nonce": nonce()})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('username must be specified', channel.json_body["error"])
+
+        # Must be a string
+        body = json.dumps({"nonce": nonce(), "username": 1234})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid username', channel.json_body["error"])
+
+        # Must not have null bytes
+        body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid username', channel.json_body["error"])
+
+        # Must not have null bytes
+        body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid username', channel.json_body["error"])
+
+        #
+        # Username checks
+        #
+
+        # Must be present
+        body = json.dumps({"nonce": nonce(), "username": "a"})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('password must be specified', channel.json_body["error"])
+
+        # Must be a string
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid password', channel.json_body["error"])
+
+        # Must not have null bytes
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid password', channel.json_body["error"])
+
+        # Super long
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
+        request, channel = make_request("POST", self.url, body.encode('utf8'))
+        render(request, self.resource, self.clock)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual('Invalid password', channel.json_body["error"])
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 159a136971..f40ff29b52 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
         )
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
 
-    @unittest.DEBUG
     def test_cant_hide_past_history(self):
         """
         If you send a message, you must be able to provide the direct
@@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
                             for x, y in d.items()
                             if x == ("m.room.member", "@us:test")
                         ],
-                        "auth_chain_ids": d.values(),
+                        "auth_chain_ids": list(d.values()),
                     }
                 )
 
diff --git a/tests/test_state.py b/tests/test_state.py
index c0f2d1152d..429a18cbf7 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].prev_state_ids))
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        self.assertEqual(2, len(prev_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
             {"START", "A", "C"},
-            {e_id for e_id in context_store["D"].prev_state_ids.values()}
+            {e_id for e_id in prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
             {"START", "A", "B", "C"},
-            {e for e in context_store["E"].prev_state_ids.values()}
+            {e for e in prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"},
-            {e for e in context_store["D"].prev_state_ids.values()}
+            {e for e in prev_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.current_state_ids.values())
+            set(e.event_id for e in old_state), set(current_state_ids.values())
         )
 
         self.assertIsNotNone(context.state_group)
@@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.prev_state_ids.values())
+            set(e.event_id for e in old_state), set(prev_state_ids.values())
         )
 
     @defer.inlineCallbacks
@@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set(context.current_state_ids.values())
+            set(current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set(context.prev_state_ids.values())
+            set(prev_state_ids.values())
         )
 
         self.assertIsNotNone(context.state_group)
@@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
-        self.assertEqual(len(context.current_state_ids), 6)
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(len(current_state_ids), 6)
 
         self.assertIsNotNone(context.state_group)
 
@@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
-        self.assertEqual(len(context.current_state_ids), 6)
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(len(current_state_ids), 6)
 
         self.assertIsNotNone(context.state_group)
 
@@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
+            old_state_2[3].event_id, current_state_ids[("test1", "1")]
         )
 
         # Reverse the depth to make sure we are actually using the depths
@@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
+            old_state_1[3].event_id, current_state_ids[("test1", "1")]
         )
 
     def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py
deleted file mode 100644
index a5a767b1ff..0000000000
--- a/tests/util/test_limiter.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.util.async import Limiter
-
-from tests import unittest
-
-
-class LimiterTestCase(unittest.TestCase):
-
-    @defer.inlineCallbacks
-    def test_limiter(self):
-        limiter = Limiter(3)
-
-        key = object()
-
-        d1 = limiter.queue(key)
-        cm1 = yield d1
-
-        d2 = limiter.queue(key)
-        cm2 = yield d2
-
-        d3 = limiter.queue(key)
-        cm3 = yield d3
-
-        d4 = limiter.queue(key)
-        self.assertFalse(d4.called)
-
-        d5 = limiter.queue(key)
-        self.assertFalse(d5.called)
-
-        with cm1:
-            self.assertFalse(d4.called)
-            self.assertFalse(d5.called)
-
-        self.assertTrue(d4.called)
-        self.assertFalse(d5.called)
-
-        with cm3:
-            self.assertFalse(d5.called)
-
-        self.assertTrue(d5.called)
-
-        with cm2:
-            pass
-
-        with (yield d4):
-            pass
-
-        with (yield d5):
-            pass
-
-        d6 = limiter.queue(key)
-        with (yield d6):
-            pass
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index c95907b32c..4729bd5a0a 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
 from six.moves import range
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
 
 from synapse.util import Clock, logcontext
 from synapse.util.async import Linearizer
@@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
             func(i)
 
         return func(1000)
+
+    @defer.inlineCallbacks
+    def test_multiple_entries(self):
+        limiter = Linearizer(max_count=3)
+
+        key = object()
+
+        d1 = limiter.queue(key)
+        cm1 = yield d1
+
+        d2 = limiter.queue(key)
+        cm2 = yield d2
+
+        d3 = limiter.queue(key)
+        cm3 = yield d3
+
+        d4 = limiter.queue(key)
+        self.assertFalse(d4.called)
+
+        d5 = limiter.queue(key)
+        self.assertFalse(d5.called)
+
+        with cm1:
+            self.assertFalse(d4.called)
+            self.assertFalse(d5.called)
+
+        cm4 = yield d4
+        self.assertFalse(d5.called)
+
+        with cm3:
+            self.assertFalse(d5.called)
+
+        cm5 = yield d5
+
+        with cm2:
+            pass
+
+        with cm4:
+            pass
+
+        with cm5:
+            pass
+
+        d6 = limiter.queue(key)
+        with (yield d6):
+            pass
+
+    @defer.inlineCallbacks
+    def test_cancellation(self):
+        linearizer = Linearizer()
+
+        key = object()
+
+        d1 = linearizer.queue(key)
+        cm1 = yield d1
+
+        d2 = linearizer.queue(key)
+        self.assertFalse(d2.called)
+
+        d3 = linearizer.queue(key)
+        self.assertFalse(d3.called)
+
+        d2.cancel()
+
+        with cm1:
+            pass
+
+        self.assertTrue(d2.called)
+        try:
+            yield d2
+            self.fail("Expected d2 to raise CancelledError")
+        except CancelledError:
+            pass
+
+        with (yield d3):
+            pass
diff --git a/tests/utils.py b/tests/utils.py
index e488238bb3..c3dbff8507 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
         config.user_directory_search_all_users = False
         config.user_consent_server_notice_content = None
         config.block_events_without_consent_error = None
+        config.media_storage_providers = []
+        config.auto_join_rooms = []
 
         # disable user directory updates, because they get done in the
         # background, which upsets the test runner.
@@ -136,6 +138,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
             database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
+            reactor=reactor,
             **kargs
         )