summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_appservice.py43
-rw-r--r--tests/handlers/test_federation.py4
-rw-r--r--tests/handlers/test_room.py2
-rw-r--r--tests/handlers/test_typing.py20
-rw-r--r--tests/rest/client/v1/test_events.py12
-rw-r--r--tests/rest/client/v1/test_presence.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py134
-rw-r--r--tests/storage/test__base.py95
-rw-r--r--tests/storage/test_registration.py2
-rw-r--r--tests/test_distributor.py4
-rw-r--r--tests/utils.py2
11 files changed, 265 insertions, 57 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 06cb1dd4cf..9e95d1e532 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -58,6 +58,49 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_query_user_exists_unknown_user(self):
+        user_id = "@someone:anywhere"
+        services = [self._mkservice(is_interested=True)]
+        services[0].is_interested_in_user = Mock(return_value=True)
+        self.mock_store.get_app_services = Mock(return_value=services)
+        self.mock_store.get_user_by_id = Mock(return_value=None)
+
+        event = Mock(
+            sender=user_id,
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.mock_as_api.push = Mock()
+        self.mock_as_api.query_user = Mock()
+        yield self.handler.notify_interested_services(event)
+        self.mock_as_api.query_user.assert_called_once_with(
+            services[0], user_id
+        )
+
+    @defer.inlineCallbacks
+    def test_query_user_exists_known_user(self):
+        user_id = "@someone:anywhere"
+        services = [self._mkservice(is_interested=True)]
+        services[0].is_interested_in_user = Mock(return_value=True)
+        self.mock_store.get_app_services = Mock(return_value=services)
+        self.mock_store.get_user_by_id = Mock(return_value={
+            "name": user_id
+        })
+
+        event = Mock(
+            sender=user_id,
+            type="m.room.message",
+            room_id="!foo:bar"
+        )
+        self.mock_as_api.push = Mock()
+        self.mock_as_api.query_user = Mock()
+        yield self.handler.notify_interested_services(event)
+        self.assertFalse(
+            self.mock_as_api.query_user.called,
+            "query_user called when it shouldn't have been."
+        )
+
+    @defer.inlineCallbacks
     def test_query_room_alias_exists(self):
         room_alias_str = "#foo:bar"
         room_alias = Mock()
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index f3821242bc..d392c23015 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
             return defer.succeed({})
         self.datastore.have_events.side_effect = have_events
 
-        def annotate(ev, old_state=None):
+        def annotate(ev, old_state=None, outlier=False):
             context = Mock()
             context.current_state = {}
             context.auth_events = {}
@@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
         )
 
         self.state_handler.compute_event_context.assert_called_once_with(
-            ANY, old_state=None,
+            ANY, old_state=None, outlier=False
         )
 
         self.auth.check.assert_called_once_with(ANY, auth_events={})
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index a2d7635995..2a7553f982 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
                 "get_room",
                 "store_room",
                 "get_latest_events_in_room",
+                "add_event_hashes",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),
@@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         self.ratelimiter.send_message.return_value = (True, 0)
 
         self.datastore.persist_event.return_value = (1,1)
+        self.datastore.add_event_hashes.return_value = []
 
     @defer.inlineCallbacks
     def test_invite(self):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7ccbe2ea9c..41bb08b7ca 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.mock_federation_resource = MockHttpResource()
 
-        mock_notifier = Mock(spec=["on_new_user_event"])
-        self.on_new_user_event = mock_notifier.on_new_user_event
+        mock_notifier = Mock(spec=["on_new_event"])
+        self.on_new_event = mock_notifier.on_new_event
 
         self.auth = Mock(spec=[])
 
@@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             timeout=20000,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
 
@@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             )
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
 
@@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
             room_id=self.room_id,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
 
@@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
             timeout=10000,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 1, rooms=[self.room_id]),
         ])
-        self.on_new_user_event.reset_mock()
+        self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
@@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
 
         self.clock.advance_time(11)
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 2, rooms=[self.room_id]),
         ])
 
@@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
             timeout=10000,
         )
 
-        self.on_new_user_event.assert_has_calls([
+        self.on_new_event.assert_has_calls([
             call('typing_key', 3, rooms=[self.room_id]),
         ])
-        self.on_new_user_event.reset_mock()
+        self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
         events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 445272e323..ac3b0b58ac 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
         )
         self.assertEquals(200, code, msg=str(response))
 
-        self.assertEquals(0, len(response["chunk"]))
+        # We may get a presence event for ourselves down
+        self.assertEquals(
+            0,
+            len([
+                c for c in response["chunk"]
+                if not (
+                    c.get("type") == "m.presence"
+                    and c["content"].get("user_id") == self.user_id
+                )
+            ])
+        )
 
         # joined room (expect all content for room)
         yield self.join(room=room_id, user=self.user_id, tok=self.token)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 4b32c7a203..089a71568c 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
         # all be ours
 
         # I'll already get my own presence state change
-        self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []},
+        self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
             response
         )
 
@@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
                 "/events?from=s0_1_0&timeout=0", None)
 
         self.assertEquals(200, code)
-        self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
+        self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
             {"type": "m.presence",
              "content": {
                  "user_id": "@banana:test",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
new file mode 100644
index 0000000000..66fd25964d
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -0,0 +1,134 @@
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.api.errors import SynapseError
+from twisted.internet import defer
+from mock import Mock, MagicMock
+from tests import unittest
+import json
+
+
+class RegisterRestServletTestCase(unittest.TestCase):
+
+    def setUp(self):
+        # do the dance to hook up request data to self.request_data
+        self.request_data = ""
+        self.request = Mock(
+            content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+        )
+        self.request.args = {}
+
+        self.appservice = None
+        self.auth = Mock(get_appservice_by_req=Mock(
+            side_effect=lambda x: defer.succeed(self.appservice))
+        )
+
+        self.auth_result = (False, None, None)
+        self.auth_handler = Mock(
+            check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
+        )
+        self.registration_handler = Mock()
+        self.identity_handler = Mock()
+        self.login_handler = Mock()
+
+        # do the dance to hook it up to the hs global
+        self.handlers = Mock(
+            auth_handler=self.auth_handler,
+            registration_handler=self.registration_handler,
+            identity_handler=self.identity_handler,
+            login_handler=self.login_handler
+        )
+        self.hs = Mock()
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_auth = Mock(return_value=self.auth)
+        self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.config.disable_registration = False
+
+        # init the thing we're testing
+        self.servlet = RegisterRestServlet(self.hs)
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = {
+            "id": "1234"
+        }
+        self.registration_handler.appservice_register = Mock(
+            return_value=(user_id, token)
+        )
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    @defer.inlineCallbacks
+    def test_POST_appservice_registration_invalid(self):
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "username": "kermit"
+        })
+        self.appservice = None  # no application service exists
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (401, None))
+
+    def test_POST_bad_password(self):
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": 666
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    def test_POST_bad_username(self):
+        self.request_data = json.dumps({
+            "username": 777,
+            "password": "monkey"
+        })
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
+
+    @defer.inlineCallbacks
+    def test_POST_user_valid(self):
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=(user_id, token))
+
+        result = yield self.servlet.on_POST(self.request)
+        self.assertEquals(result, (200, {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }))
+
+    def test_POST_disabled_registration(self):
+        self.hs.config.disable_registration = True
+        self.request_data = json.dumps({
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (True, None, {
+            "username": "kermit",
+            "password": "monkey"
+        })
+        self.registration_handler.register = Mock(return_value=("@user:id", "t"))
+        d = self.servlet.on_POST(self.request)
+        return self.assertFailure(d, SynapseError)
\ No newline at end of file
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96caf8c4c1..abee2f631d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.util.async import ObservableDeferred
+
 from synapse.storage._base import Cache, cached
 
 
@@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)
 
     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))
 
         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
 
@@ -96,87 +98,102 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_passthrough(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        self.assertEquals((yield func(self, "foo")), "foo")
-        self.assertEquals((yield func(self, "bar")), "bar")
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
 
     @defer.inlineCallbacks
     def test_hit(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        yield func(self, "foo")
+        a = A()
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        self.assertEquals((yield func(self, "foo")), "foo")
+        self.assertEquals((yield a.func("foo")), "foo")
         self.assertEquals(callcount[0], 1)
 
     @defer.inlineCallbacks
     def test_invalidate(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        yield func(self, "foo")
+        a = A()
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        func.invalidate("foo")
+        a.func.invalidate(("foo",))
 
-        yield func(self, "foo")
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 2)
 
     def test_invalidate_missing(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        func.invalidate("what")
+        A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
     def test_max_entries(self):
         callcount = [0]
 
-        @cached(max_entries=10)
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        for k in range(0,12):
-            yield func(self, k)
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
 
         self.assertEquals(callcount[0], 12)
 
         # There must have been at least 2 evictions, meaning if we calculate
         # all 12 values again, we must get called at least 2 more times
         for k in range(0,12):
-            yield func(self, k)
+            yield a.func(k)
 
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))
 
-    @defer.inlineCallbacks
     def test_prefill(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        d = defer.succeed(123)
+
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return d
+
+        a = A()
 
-        func.prefill("foo", 123)
+        a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals((yield func(self, "foo")), 123)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 78f6004204..2702291178 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             (yield self.store.get_user_by_id(self.user_id))
         )
 
-        result = yield self.store.get_user_by_token(self.tokens[1])
+        result = yield self.store.get_user_by_token(self.tokens[0])
 
         self.assertDictContainsSubset(
             {
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 6a0095d850..8ed48cfb70 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
             yield d
             self.assertTrue(d.called)
 
-            observers[0].assert_called_once("Go")
-            observers[1].assert_called_once("Go")
+            observers[0].assert_called_once_with("Go")
+            observers[1].assert_called_once_with("Go")
 
             self.assertEquals(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0],
diff --git a/tests/utils.py b/tests/utils.py
index 3b5c335911..eb035cf48f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -114,6 +114,8 @@ class MockHttpResource(HttpServer):
         mock_request.method = http_method
         mock_request.uri = path
 
+        mock_request.getClientIP.return_value = "-"
+
         mock_request.requestHeaders.getRawHeaders.return_value=[
             "X-Matrix origin=test,key=,sig="
         ]