summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_transactions.py69
-rw-r--r--tests/rest/client/v1/test_profile.py13
-rw-r--r--tests/rest/client/v1/test_register.py78
-rw-r--r--tests/rest/client/v1/test_rooms.py8
-rw-r--r--tests/rest/client/v1/test_typing.py5
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py127
-rw-r--r--tests/rest/client/v2_alpha/test_register.py32
7 files changed, 268 insertions, 64 deletions
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
new file mode 100644
index 0000000000..d7cea30260
--- /dev/null
+++ b/tests/rest/client/test_transactions.py
@@ -0,0 +1,69 @@
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
+from twisted.internet import defer
+from mock import Mock, call
+from tests import unittest
+from tests.utils import MockClock
+
+
+class HttpTransactionCacheTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.clock = MockClock()
+        self.cache = HttpTransactionCache(self.clock)
+
+        self.mock_http_response = (200, "GOOD JOB!")
+        self.mock_key = "foo"
+
+    @defer.inlineCallbacks
+    def test_executes_given_function(self):
+        cb = Mock(
+            return_value=defer.succeed(self.mock_http_response)
+        )
+        res = yield self.cache.fetch_or_execute(
+            self.mock_key, cb, "some_arg", keyword="arg"
+        )
+        cb.assert_called_once_with("some_arg", keyword="arg")
+        self.assertEqual(res, self.mock_http_response)
+
+    @defer.inlineCallbacks
+    def test_deduplicates_based_on_key(self):
+        cb = Mock(
+            return_value=defer.succeed(self.mock_http_response)
+        )
+        for i in range(3):  # invoke multiple times
+            res = yield self.cache.fetch_or_execute(
+                self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+            )
+            self.assertEqual(res, self.mock_http_response)
+        # expect only a single call to do the work
+        cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
+
+    @defer.inlineCallbacks
+    def test_cleans_up(self):
+        cb = Mock(
+            return_value=defer.succeed(self.mock_http_response)
+        )
+        yield self.cache.fetch_or_execute(
+            self.mock_key, cb, "an arg"
+        )
+        # should NOT have cleaned up yet
+        self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
+
+        yield self.cache.fetch_or_execute(
+            self.mock_key, cb, "an arg"
+        )
+        # still using cache
+        cb.assert_called_once_with("an arg")
+
+        self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
+
+        yield self.cache.fetch_or_execute(
+            self.mock_key, cb, "an arg"
+        )
+        # no longer using cache
+        self.assertEqual(cb.call_count, 2)
+        self.assertEqual(
+            cb.call_args_list,
+            [call("an arg",), call("an arg",)]
+        )
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index af02fce8fb..1e95e97538 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,17 +14,14 @@
 # limitations under the License.
 
 """Tests REST events for /profile paths."""
-from tests import unittest
-from twisted.internet import defer
-
 from mock import Mock
+from twisted.internet import defer
 
-from ....utils import MockHttpResource, setup_test_homeserver
-
+import synapse.types
 from synapse.api.errors import SynapseError, AuthError
-from synapse.types import Requester, UserID
-
 from synapse.rest.client.v1 import profile
+from tests import unittest
+from ....utils import MockHttpResource, setup_test_homeserver
 
 myid = "@1234ABCD:test"
 PATH_PREFIX = "/_matrix/client/api/v1"
@@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         def _get_user_by_req(request=None, allow_guest=False):
-            return Requester(UserID.from_string(myid), "", False)
+            return synapse.types.create_requester(myid)
 
         hs.get_v1auth().get_user_by_req = _get_user_by_req
 
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
new file mode 100644
index 0000000000..a6a4e2ffe0
--- /dev/null
+++ b/tests/rest/client/v1/test_register.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1.register import CreateUserRestServlet
+from twisted.internet import defer
+from mock import Mock
+from tests import unittest
+from tests.utils import mock_getRawHeaders
+import json
+
+
+class CreateUserServletTestCase(unittest.TestCase):
+
+    def setUp(self):
+        # do the dance to hook up request data to self.request_data
+        self.request_data = ""
+        self.request = Mock(
+            content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+            path='/_matrix/client/api/v1/createUser'
+        )
+        self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        self.registration_handler = Mock()
+
+        self.appservice = Mock(sender="@as:test")
+        self.datastore = Mock(
+            get_app_service_by_token=Mock(return_value=self.appservice)
+        )
+
+        # do the dance to hook things up to the hs global
+        handlers = Mock(
+            registration_handler=self.registration_handler,
+        )
+        self.hs = Mock()
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_datastore = Mock(return_value=self.datastore)
+        self.hs.get_handlers = Mock(return_value=handlers)
+        self.servlet = CreateUserRestServlet(self.hs)
+
+    @defer.inlineCallbacks
+    def test_POST_createuser_with_valid_user(self):
+        user_id = "@someone:interesting"
+        token = "my token"
+        self.request.args = {
+            "access_token": "i_am_an_app_service"
+        }
+        self.request_data = json.dumps({
+            "localpart": "someone",
+            "displayname": "someone interesting",
+            "duration_seconds": 200
+        })
+
+        self.registration_handler.get_or_create_user = Mock(
+            return_value=(user_id, token)
+        )
+
+        (code, result) = yield self.servlet.on_POST(self.request)
+        self.assertEquals(code, 200)
+
+        det_data = {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname
+        }
+        self.assertDictContainsSubset(det_data, result)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4ab8b35e6b..4fe99ebc0b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
         # set [invite/join/left] of self, set [invite/join/left] of other,
         # expect all 404s because room doesn't exist on any server
         for usr in [self.user_id, self.rmcreator_id]:
-            yield self.join(room=room, user=usr, expect_code=404)
-            yield self.leave(room=room, user=usr, expect_code=404)
+            yield self.join(room=room, user=usr, expect_code=403)
+            yield self.leave(room=room, user=usr, expect_code=403)
 
     @defer.inlineCallbacks
     def test_membership_private_room_perms(self):
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_topo_token_is_accepted(self):
-        token = "t1-0_0_0_0_0_0"
+        token = "t1-0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
-        token = "s0_0_0_0_0_0"
+        token = "s0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index d0037a53ef..a269e6f56e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
         # Need another user to make notifications actually work
         yield self.join(self.room_id, user="@jim:red")
 
-    def tearDown(self):
-        self.hs.get_handlers().typing_notification_handler.tearDown()
-
     @defer.inlineCallbacks
     def test_set_typing(self):
         (code, _) = yield self.mock_resource.trigger(
@@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
 
-        self.clock.advance_time(31)
+        self.clock.advance_time(36)
 
         self.assertEquals(self.event_source.get_current_key(), 2)
 
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index d1442aafac..3d27d03cbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -15,78 +15,125 @@
 
 from twisted.internet import defer
 
-from . import V2AlphaRestTestCase
+from tests import unittest
 
 from synapse.rest.client.v2_alpha import filter
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes
 
+import synapse.types
+
+from synapse.types import UserID
+
+from ....utils import MockHttpResource, setup_test_homeserver
+
+PATH_PREFIX = "/_matrix/client/v2_alpha"
+
+
+class FilterTestCase(unittest.TestCase):
 
-class FilterTestCase(V2AlphaRestTestCase):
     USER_ID = "@apple:test"
+    EXAMPLE_FILTER = {"type": ["m.*"]}
+    EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
     TO_REGISTER = [filter]
 
-    def make_datastore_mock(self):
-        datastore = super(FilterTestCase, self).make_datastore_mock()
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
 
-        self._user_filters = {}
+        self.hs = yield setup_test_homeserver(
+            http_client=None,
+            resource_for_client=self.mock_resource,
+            resource_for_federation=self.mock_resource,
+        )
 
-        def add_user_filter(user_localpart, definition):
-            filters = self._user_filters.setdefault(user_localpart, [])
-            filter_id = len(filters)
-            filters.append(definition)
-            return defer.succeed(filter_id)
-        datastore.add_user_filter = add_user_filter
+        self.auth = self.hs.get_auth()
 
-        def get_user_filter(user_localpart, filter_id):
-            if user_localpart not in self._user_filters:
-                raise StoreError(404, "No user")
-            filters = self._user_filters[user_localpart]
-            if filter_id >= len(filters):
-                raise StoreError(404, "No filter")
-            return defer.succeed(filters[filter_id])
-        datastore.get_user_filter = get_user_filter
+        def get_user_by_access_token(token=None, allow_guest=False):
+            return {
+                "user": UserID.from_string(self.USER_ID),
+                "token_id": 1,
+                "is_guest": False,
+            }
 
-        return datastore
+        def get_user_by_req(request, allow_guest=False, rights="access"):
+            return synapse.types.create_requester(
+                UserID.from_string(self.USER_ID), 1, False, None)
+
+        self.auth.get_user_by_access_token = get_user_by_access_token
+        self.auth.get_user_by_req = get_user_by_req
+
+        self.store = self.hs.get_datastore()
+        self.filtering = self.hs.get_filtering()
+
+        for r in self.TO_REGISTER:
+            r.register_servlets(self.hs, self.mock_resource)
 
     @defer.inlineCallbacks
     def test_add_filter(self):
         (code, response) = yield self.mock_resource.trigger(
-            "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}'
+            "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
         )
         self.assertEquals(200, code)
         self.assertEquals({"filter_id": "0"}, response)
+        filter = yield self.store.get_user_filter(
+            user_localpart='apple',
+            filter_id=0,
+        )
+        self.assertEquals(filter, self.EXAMPLE_FILTER)
 
-        self.assertIn("apple", self._user_filters)
-        self.assertEquals(len(self._user_filters["apple"]), 1)
-        self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
+    @defer.inlineCallbacks
+    def test_add_filter_for_other_user(self):
+        (code, response) = yield self.mock_resource.trigger(
+            "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
+        )
+        self.assertEquals(403, code)
+        self.assertEquals(response['errcode'], Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
-    def test_get_filter(self):
-        self._user_filters["apple"] = [
-            {"type": ["m.*"]}
-        ]
+    def test_add_filter_non_local_user(self):
+        _is_mine = self.hs.is_mine
+        self.hs.is_mine = lambda target_user: False
+        (code, response) = yield self.mock_resource.trigger(
+            "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
+        )
+        self.hs.is_mine = _is_mine
+        self.assertEquals(403, code)
+        self.assertEquals(response['errcode'], Codes.FORBIDDEN)
 
+    @defer.inlineCallbacks
+    def test_get_filter(self):
+        filter_id = yield self.filtering.add_user_filter(
+            user_localpart='apple',
+            user_filter=self.EXAMPLE_FILTER
+        )
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/0" % (self.USER_ID)
+            "/user/%s/filter/%s" % (self.USER_ID, filter_id)
         )
         self.assertEquals(200, code)
-        self.assertEquals({"type": ["m.*"]}, response)
+        self.assertEquals(self.EXAMPLE_FILTER, response)
 
     @defer.inlineCallbacks
-    def test_get_filter_no_id(self):
-        self._user_filters["apple"] = [
-            {"type": ["m.*"]}
-        ]
+    def test_get_filter_non_existant(self):
+        (code, response) = yield self.mock_resource.trigger_get(
+            "/user/%s/filter/12382148321" % (self.USER_ID)
+        )
+        self.assertEquals(400, code)
+        self.assertEquals(response['errcode'], Codes.NOT_FOUND)
 
+    # Currently invalid params do not have an appropriate errcode
+    # in errors.py
+    @defer.inlineCallbacks
+    def test_get_filter_invalid_id(self):
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/2" % (self.USER_ID)
+            "/user/%s/filter/foobar" % (self.USER_ID)
         )
-        self.assertEquals(404, code)
+        self.assertEquals(400, code)
 
+    # No ID also returns an invalid_id error
     @defer.inlineCallbacks
-    def test_get_filter_no_user(self):
+    def test_get_filter_no_id(self):
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/0" % (self.USER_ID)
+            "/user/%s/filter/" % (self.USER_ID)
         )
-        self.assertEquals(404, code)
+        self.assertEquals(400, code)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index affd42c015..b6173ab2ee 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
 
 
@@ -16,10 +17,11 @@ class RegisterRestServletTestCase(unittest.TestCase):
             path='/_matrix/api/v2_alpha/register'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         self.appservice = None
         self.auth = Mock(get_appservice_by_req=Mock(
-            side_effect=lambda x: defer.succeed(self.appservice))
+            side_effect=lambda x: self.appservice)
         )
 
         self.auth_result = (False, None, None, None)
@@ -30,10 +32,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.registration_handler = Mock()
         self.identity_handler = Mock()
         self.login_handler = Mock()
+        self.device_handler = Mock()
 
         # do the dance to hook it up to the hs global
         self.handlers = Mock(
-            auth_handler=self.auth_handler,
             registration_handler=self.registration_handler,
             identity_handler=self.identity_handler,
             login_handler=self.login_handler
@@ -42,6 +44,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.hostname = "superbig~testing~thing.com"
         self.hs.get_auth = Mock(return_value=self.auth)
         self.hs.get_handlers = Mock(return_value=self.handlers)
+        self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
+        self.hs.get_device_handler = Mock(return_value=self.device_handler)
         self.hs.config.enable_registration = True
 
         # init the thing we're testing
@@ -61,8 +65,12 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "id": "1234"
         }
         self.registration_handler.appservice_register = Mock(
-            return_value=(user_id, token)
+            return_value=user_id
         )
+        self.auth_handler.get_access_token_for_user_id = Mock(
+            return_value=token
+        )
+
         (code, result) = yield self.servlet.on_POST(self.request)
         self.assertEquals(code, 200)
         det_data = {
@@ -71,7 +79,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "home_server": self.hs.hostname
         }
         self.assertDictContainsSubset(det_data, result)
-        self.assertIn("refresh_token", result)
 
     @defer.inlineCallbacks
     def test_POST_appservice_registration_invalid(self):
@@ -105,26 +112,35 @@ class RegisterRestServletTestCase(unittest.TestCase):
     def test_POST_user_valid(self):
         user_id = "@kermit:muppet"
         token = "kermits_access_token"
+        device_id = "frogfone"
         self.request_data = json.dumps({
             "username": "kermit",
-            "password": "monkey"
+            "password": "monkey",
+            "device_id": device_id,
         })
         self.registration_handler.check_username = Mock(return_value=True)
         self.auth_result = (True, None, {
             "username": "kermit",
             "password": "monkey"
         }, None)
-        self.registration_handler.register = Mock(return_value=(user_id, token))
+        self.registration_handler.register = Mock(return_value=(user_id, None))
+        self.auth_handler.get_access_token_for_user_id = Mock(
+            return_value=token
+        )
+        self.device_handler.check_device_registered = \
+            Mock(return_value=device_id)
 
         (code, result) = yield self.servlet.on_POST(self.request)
         self.assertEquals(code, 200)
         det_data = {
             "user_id": user_id,
             "access_token": token,
-            "home_server": self.hs.hostname
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
         }
         self.assertDictContainsSubset(det_data, result)
-        self.assertIn("refresh_token", result)
+        self.auth_handler.get_login_tuple_for_user_id(
+            user_id, device_id=device_id, initial_device_display_name=None)
 
     def test_POST_disabled_registration(self):
         self.hs.config.enable_registration = False