summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py18
-rw-r--r--tests/handlers/test_typing.py3
-rw-r--r--tests/rest/client/v1/test_register.py2
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py127
-rw-r--r--tests/rest/client/v2_alpha/test_register.py2
-rw-r--r--tests/utils.py20
6 files changed, 117 insertions, 55 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index e91723ca3d..2cf262bb46 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -20,7 +20,7 @@ from mock import Mock
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
 from synapse.types import UserID
-from tests.utils import setup_test_homeserver
+from tests.utils import setup_test_homeserver, mock_getRawHeaders
 
 import pymacaroons
 
@@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
 
@@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
 
@@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), masquerading_user_id)
 
@@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index c3108f5181..c718d1f98f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
                     "user_id": self.u_onion.to_string(),
                     "typing": True,
                 }
-            )
+            ),
+            federation_auth=True,
         )
 
         self.on_new_event.assert_has_calls([
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 44ba9ff58f..a6a4e2ffe0 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
 
 
@@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase):
             path='/_matrix/client/api/v1/createUser'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         self.registration_handler = Mock()
 
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index d1442aafac..3d27d03cbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -15,78 +15,125 @@
 
 from twisted.internet import defer
 
-from . import V2AlphaRestTestCase
+from tests import unittest
 
 from synapse.rest.client.v2_alpha import filter
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes
 
+import synapse.types
+
+from synapse.types import UserID
+
+from ....utils import MockHttpResource, setup_test_homeserver
+
+PATH_PREFIX = "/_matrix/client/v2_alpha"
+
+
+class FilterTestCase(unittest.TestCase):
 
-class FilterTestCase(V2AlphaRestTestCase):
     USER_ID = "@apple:test"
+    EXAMPLE_FILTER = {"type": ["m.*"]}
+    EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
     TO_REGISTER = [filter]
 
-    def make_datastore_mock(self):
-        datastore = super(FilterTestCase, self).make_datastore_mock()
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
 
-        self._user_filters = {}
+        self.hs = yield setup_test_homeserver(
+            http_client=None,
+            resource_for_client=self.mock_resource,
+            resource_for_federation=self.mock_resource,
+        )
 
-        def add_user_filter(user_localpart, definition):
-            filters = self._user_filters.setdefault(user_localpart, [])
-            filter_id = len(filters)
-            filters.append(definition)
-            return defer.succeed(filter_id)
-        datastore.add_user_filter = add_user_filter
+        self.auth = self.hs.get_auth()
 
-        def get_user_filter(user_localpart, filter_id):
-            if user_localpart not in self._user_filters:
-                raise StoreError(404, "No user")
-            filters = self._user_filters[user_localpart]
-            if filter_id >= len(filters):
-                raise StoreError(404, "No filter")
-            return defer.succeed(filters[filter_id])
-        datastore.get_user_filter = get_user_filter
+        def get_user_by_access_token(token=None, allow_guest=False):
+            return {
+                "user": UserID.from_string(self.USER_ID),
+                "token_id": 1,
+                "is_guest": False,
+            }
 
-        return datastore
+        def get_user_by_req(request, allow_guest=False, rights="access"):
+            return synapse.types.create_requester(
+                UserID.from_string(self.USER_ID), 1, False, None)
+
+        self.auth.get_user_by_access_token = get_user_by_access_token
+        self.auth.get_user_by_req = get_user_by_req
+
+        self.store = self.hs.get_datastore()
+        self.filtering = self.hs.get_filtering()
+
+        for r in self.TO_REGISTER:
+            r.register_servlets(self.hs, self.mock_resource)
 
     @defer.inlineCallbacks
     def test_add_filter(self):
         (code, response) = yield self.mock_resource.trigger(
-            "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}'
+            "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
         )
         self.assertEquals(200, code)
         self.assertEquals({"filter_id": "0"}, response)
+        filter = yield self.store.get_user_filter(
+            user_localpart='apple',
+            filter_id=0,
+        )
+        self.assertEquals(filter, self.EXAMPLE_FILTER)
 
-        self.assertIn("apple", self._user_filters)
-        self.assertEquals(len(self._user_filters["apple"]), 1)
-        self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
+    @defer.inlineCallbacks
+    def test_add_filter_for_other_user(self):
+        (code, response) = yield self.mock_resource.trigger(
+            "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
+        )
+        self.assertEquals(403, code)
+        self.assertEquals(response['errcode'], Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
-    def test_get_filter(self):
-        self._user_filters["apple"] = [
-            {"type": ["m.*"]}
-        ]
+    def test_add_filter_non_local_user(self):
+        _is_mine = self.hs.is_mine
+        self.hs.is_mine = lambda target_user: False
+        (code, response) = yield self.mock_resource.trigger(
+            "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
+        )
+        self.hs.is_mine = _is_mine
+        self.assertEquals(403, code)
+        self.assertEquals(response['errcode'], Codes.FORBIDDEN)
 
+    @defer.inlineCallbacks
+    def test_get_filter(self):
+        filter_id = yield self.filtering.add_user_filter(
+            user_localpart='apple',
+            user_filter=self.EXAMPLE_FILTER
+        )
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/0" % (self.USER_ID)
+            "/user/%s/filter/%s" % (self.USER_ID, filter_id)
         )
         self.assertEquals(200, code)
-        self.assertEquals({"type": ["m.*"]}, response)
+        self.assertEquals(self.EXAMPLE_FILTER, response)
 
     @defer.inlineCallbacks
-    def test_get_filter_no_id(self):
-        self._user_filters["apple"] = [
-            {"type": ["m.*"]}
-        ]
+    def test_get_filter_non_existant(self):
+        (code, response) = yield self.mock_resource.trigger_get(
+            "/user/%s/filter/12382148321" % (self.USER_ID)
+        )
+        self.assertEquals(400, code)
+        self.assertEquals(response['errcode'], Codes.NOT_FOUND)
 
+    # Currently invalid params do not have an appropriate errcode
+    # in errors.py
+    @defer.inlineCallbacks
+    def test_get_filter_invalid_id(self):
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/2" % (self.USER_ID)
+            "/user/%s/filter/foobar" % (self.USER_ID)
         )
-        self.assertEquals(404, code)
+        self.assertEquals(400, code)
 
+    # No ID also returns an invalid_id error
     @defer.inlineCallbacks
-    def test_get_filter_no_user(self):
+    def test_get_filter_no_id(self):
         (code, response) = yield self.mock_resource.trigger_get(
-            "/user/%s/filter/0" % (self.USER_ID)
+            "/user/%s/filter/" % (self.USER_ID)
         )
-        self.assertEquals(404, code)
+        self.assertEquals(400, code)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index e9cb416e4b..b4a787c436 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
 
 
@@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             path='/_matrix/api/v2_alpha/register'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         self.appservice = None
         self.auth = Mock(get_appservice_by_req=Mock(
diff --git a/tests/utils.py b/tests/utils.py
index f74526b6a7..5929f1c729 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -116,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func):
     return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
 
 
+def mock_getRawHeaders(headers=None):
+    headers = headers if headers is not None else {}
+
+    def getRawHeaders(name, default=None):
+        return headers.get(name, default)
+
+    return getRawHeaders
+
+
 # This is a mock /resource/ not an entire server
 class MockHttpResource(HttpServer):
 
@@ -128,7 +137,7 @@ class MockHttpResource(HttpServer):
 
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
-    def trigger(self, http_method, path, content, mock_request):
+    def trigger(self, http_method, path, content, mock_request, federation_auth=False):
         """ Fire an HTTP event.
 
         Args:
@@ -156,9 +165,10 @@ class MockHttpResource(HttpServer):
 
         mock_request.getClientIP.return_value = "-"
 
-        mock_request.requestHeaders.getRawHeaders.return_value = [
-            "X-Matrix origin=test,key=,sig="
-        ]
+        headers = {}
+        if federation_auth:
+            headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
+        mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
         mock_request.path = path
@@ -189,7 +199,7 @@ class MockHttpResource(HttpServer):
                     )
                     defer.returnValue((code, response))
                 except CodeMessageException as e:
-                    defer.returnValue((e.code, cs_error(e.msg)))
+                    defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
 
         raise KeyError("No event can handle %s" % path)