summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_typing.py20
-rw-r--r--tests/rest/client/v1/test_events.py12
-rw-r--r--tests/rest/client/v1/test_presence.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py135
-rw-r--r--tests/storage/test__base.py23
-rw-r--r--tests/test_distributor.py4
-rw-r--r--tests/test_state.py2
-rw-r--r--tests/util/test_dict_cache.py101
-rw-r--r--tests/util/test_lrucache.py4
9 files changed, 276 insertions, 29 deletions
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7ccbe2ea9c..41bb08b7ca 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.mock_federation_resource = MockHttpResource()
 
-        mock_notifier = Mock(spec=["on_new_user_event"])
-        self.on_new_user_event = mock_notifier.on_new_user_event
+        mock_notifier = Mock(spec=["on_new_event"])
+        self.on_new_event = mock_notifier.on_new_event
 
         self.auth = Mock(spec=[])
 
@@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             timeout=20000,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
 
@@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             )
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
 
@@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             room_id=self.room_id,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
 
@@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
             timeout=10000,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
-        self.on_new_user_event.reset_mock()
+        self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
@@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.clock.advance_time(11)
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 2, rooms=[self.room_id]),
         ])
 
@@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
             timeout=10000,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 3, rooms=[self.room_id]),
         ])
-        self.on_new_user_event.reset_mock()
+        self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
         events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 445272e323..ac3b0b58ac 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
         )
         self.assertEquals(200, code, msg=str(response))
 
-        self.assertEquals(0, len(response["chunk"]))
+        # We may get a presence event for ourselves down
+        self.assertEquals(
+            0,
+            len([
+                c for c in response["chunk"]
+                if not (
+                    c.get("type") == "m.presence"
+                    and c["content"].get("user_id") == self.user_id
+                )
+            ])
+        )
 
         # joined room (expect all content for room)
         yield self.join(room=room_id, user=self.user_id, tok=self.token)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 4b32c7a203..089a71568c 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
         # all be ours
 
         # I'll already get my own presence state change
-        self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []},
+        self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
             response
         )
 
@@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
                 "/events?from=s0_1_0&timeout=0", None)
 
         self.assertEquals(200, code)
-        self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
+        self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
             {"type": "m.presence",
              "content": {
                  "user_id": "@banana:test",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
new file mode 100644
index 0000000000..f9a2b22485
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -0,0 +1,135 @@
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.api.errors import SynapseError
+from twisted.internet import defer
+from mock import Mock, MagicMock
+from tests import unittest
+import json
+
+
+class RegisterRestServletTestCase(unittest.TestCase):
+
+    def setUp(self):
+        # do the dance to hook up request data to self.request_data
+        self.request_data = ""
+        self.request = Mock(
+            content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+            path='/_matrix/api/v2_alpha/register'
+        )
+        self.request.args = {}
+
+        self.appservice = None
+        self.auth = Mock(get_appservice_by_req=Mock(
+            side_effect=lambda x: defer.succeed(self.appservice))
+        )
+
+        self.auth_result = (False, None, None)
+        self.auth_handler = Mock(
+            check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
+        )
+        self.registration_handler = Mock()
+        self.identity_handler = Mock()
+        self.login_handler = Mock()
+
+        # do the dance to hook it up to the hs global
+        self.handlers = Mock(
+            auth_handler=self.auth_handler,
+            registration_handler=self.registration_handler,
+            identity_handler=self.identity_handler,
+            login_handler=self.login_handler
+        )
+        self.hs = Mock()
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_auth = Mock(return_value=self.auth)
+        self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.config.disable_registration = False
+
+        # init the thing we're testing
+        self.servlet = RegisterRestServlet(self.hs)
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = {
+            "id": "1234"
+        }
+        self.registration_handler.appservice_register = Mock(
+            return_value=(user_id, token)
+        )
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_invalid(self):
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = None  # no application service exists
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (401, None))
+
+    def test_POST_bad_password(self):
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": 666
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    def test_POST_bad_username(self):
+        self.request_data = json.dumps({
+            "username": 777,
+            "password": "monkey"
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    @defer.inlineCallbacks
+    def test_POST_user_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=(user_id, token))
+
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    def test_POST_disabled_registration(self):
+        self.hs.config.disable_registration = True
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=("@user:id", "t"))
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8c3d2952bd..e72cace8ff 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,7 +17,9 @@
 from tests import unittest
 from twisted.internet import defer
 
-from synapse.storage._base import Cache, cached
+from synapse.util.async import ObservableDeferred
+
+from synapse.util.caches.descriptors import Cache, cached
 
 
 class CacheTestCase(unittest.TestCase):
@@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)
 
     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))
 
         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
 
@@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 1)
 
-        a.func.invalidate("foo")
+        a.func.invalidate(("foo",))
 
         yield a.func("foo")
 
@@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             def func(self, key):
                 return key
 
-        A().func.invalidate("what")
+        A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
     def test_max_entries(self):
@@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))
 
-    @defer.inlineCallbacks
     def test_prefill(self):
         callcount = [0]
 
+        d = defer.succeed(123)
+
         class A(object):
             @cached()
             def func(self, key):
                 callcount[0] += 1
-                return key
+                return d
 
         a = A()
 
-        a.func.prefill("foo", 123)
+        a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals((yield a.func("foo")), 123)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 6a0095d850..8ed48cfb70 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
             yield d
             self.assertTrue(d.called)
 
-            observers[0].assert_called_once("Go")
-            observers[1].assert_called_once("Go")
+            observers[0].assert_called_once_with("Go")
+            observers[1].assert_called_once_with("Go")
 
             self.assertEquals(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0],
diff --git a/tests/test_state.py b/tests/test_state.py
index fea25f7021..5845358754 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -69,7 +69,7 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups(self, event_ids):
+    def get_state_groups(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
new file mode 100644
index 0000000000..54ff26cd97
--- /dev/null
+++ b/tests/util/test_dict_cache.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+from tests import unittest
+
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+
+class DictCacheTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.cache = DictionaryCache("foobar")
+
+    def test_simple_cache_hit_full(self):
+        key = "test_simple_cache_hit_full"
+
+        v = self.cache.get(key)
+        self.assertEqual((False, {}), v)
+
+        seq = self.cache.sequence
+        test_value = {"test": "test_simple_cache_hit_full"}
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key)
+        self.assertEqual(test_value, c.value)
+
+    def test_simple_cache_hit_partial(self):
+        key = "test_simple_cache_hit_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_hit_partial"
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test"])
+        self.assertEqual(test_value, c.value)
+
+    def test_simple_cache_miss_partial(self):
+        key = "test_simple_cache_miss_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_miss_partial"
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test2"])
+        self.assertEqual({}, c.value)
+
+    def test_simple_cache_hit_miss_partial(self):
+        key = "test_simple_cache_hit_miss_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_hit_miss_partial",
+            "test2": "test_simple_cache_hit_miss_partial2",
+            "test3": "test_simple_cache_hit_miss_partial3",
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test2"])
+        self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
+
+    def test_multi_insert(self):
+        key = "test_simple_cache_hit_miss_partial"
+
+        seq = self.cache.sequence
+        test_value_1 = {
+            "test": "test_simple_cache_hit_miss_partial",
+        }
+        self.cache.update(seq, key, test_value_1, full=False)
+
+        seq = self.cache.sequence
+        test_value_2 = {
+            "test2": "test_simple_cache_hit_miss_partial2",
+        }
+        self.cache.update(seq, key, test_value_2, full=False)
+
+        c = self.cache.get(key)
+        self.assertEqual(
+            {
+                "test": "test_simple_cache_hit_miss_partial",
+                "test2": "test_simple_cache_hit_miss_partial2",
+            },
+            c.value
+        )
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index ab934bf928..fc5a904323 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -16,7 +16,7 @@
 
 from .. import unittest
 
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.lrucache import LruCache
 
 class LruCacheTestCase(unittest.TestCase):
 
@@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase):
         cache["key"] = 1
         self.assertEquals(cache.pop("key"), 1)
         self.assertEquals(cache.pop("key"), None)
-
-