diff options
Diffstat (limited to '')
-rw-r--r-- | tests/api/test_auth.py | 18 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 3 | ||||
-rw-r--r-- | tests/rest/client/v1/test_register.py | 2 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 2 | ||||
-rw-r--r-- | tests/utils.py | 18 |
5 files changed, 29 insertions, 14 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e91723ca3d..2cf262bb46 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -20,7 +20,7 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, mock_getRawHeaders import pymacaroons @@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), masquerading_user_id) @@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index c3108f5181..c718d1f98f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase): "user_id": self.u_onion.to_string(), "typing": True, } - ) + ), + federation_auth=True, ) self.on_new_event.assert_has_calls([ diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 44ba9ff58f..a6a4e2ffe0 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet from twisted.internet import defer from mock import Mock from tests import unittest +from tests.utils import mock_getRawHeaders import json @@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase): path='/_matrix/client/api/v1/createUser' ) self.request.args = {} + self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.registration_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index e9cb416e4b..b4a787c436 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError from twisted.internet import defer from mock import Mock from tests import unittest +from tests.utils import mock_getRawHeaders import json @@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase): path='/_matrix/api/v2_alpha/register' ) self.request.args = {} + self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.appservice = None self.auth = Mock(get_appservice_by_req=Mock( diff --git a/tests/utils.py b/tests/utils.py index a91d167f90..5929f1c729 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -116,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func): return getcallargs(pattern_func, *invoked_args, **invoked_kargs) +def mock_getRawHeaders(headers=None): + headers = headers if headers is not None else {} + + def getRawHeaders(name, default=None): + return headers.get(name, default) + + return getRawHeaders + + # This is a mock /resource/ not an entire server class MockHttpResource(HttpServer): @@ -128,7 +137,7 @@ class MockHttpResource(HttpServer): @patch('twisted.web.http.Request') @defer.inlineCallbacks - def trigger(self, http_method, path, content, mock_request): + def trigger(self, http_method, path, content, mock_request, federation_auth=False): """ Fire an HTTP event. Args: @@ -156,9 +165,10 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" - mock_request.requestHeaders.getRawHeaders.return_value = [ - "X-Matrix origin=test,key=,sig=" - ] + headers = {} + if federation_auth: + headers["Authorization"] = ["X-Matrix origin=test,key=,sig="] + mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it mock_request.path = path |