summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py46
-rw-r--r--tests/api/test_auth.py18
-rw-r--r--tests/handlers/test_typing.py3
-rw-r--r--tests/rest/client/v1/test_register.py2
-rw-r--r--tests/rest/client/v2_alpha/test_register.py2
-rw-r--r--tests/utils.py18
6 files changed, 66 insertions, 23 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b6a151a7ec..69b3392735 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -1170,7 +1170,8 @@ def has_access_token(request):
         bool: False if no access_token was given, True otherwise.
     """
     query_params = request.args.get("access_token")
-    return bool(query_params)
+    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+    return bool(query_params) or bool(auth_headers)
 
 
 def get_access_token_from_request(request, token_not_found_http_status=401):
@@ -1188,13 +1189,40 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
     Raises:
         AuthError: If there isn't an access_token in the request.
     """
+
+    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
     query_params = request.args.get("access_token")
-    # Try to get the access_token from the query params.
-    if not query_params:
-        raise AuthError(
-            token_not_found_http_status,
-            "Missing access token.",
-            errcode=Codes.MISSING_TOKEN
-        )
+    if auth_headers:
+        # Try the get the access_token from a "Authorization: Bearer"
+        # header
+        if query_params is not None:
+            raise AuthError(
+                token_not_found_http_status,
+                "Mixing Authorization headers and access_token query parameters.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+        if len(auth_headers) > 1:
+            raise AuthError(
+                token_not_found_http_status,
+                "Too many Authorization headers.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+        parts = auth_headers[0].split(" ")
+        if parts[0] == "Bearer" and len(parts) == 2:
+            return parts[1]
+        else:
+            raise AuthError(
+                token_not_found_http_status,
+                "Invalid Authorization header.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+    else:
+        # Try to get the access_token from the query params.
+        if not query_params:
+            raise AuthError(
+                token_not_found_http_status,
+                "Missing access token.",
+                errcode=Codes.MISSING_TOKEN
+            )
 
-    return query_params[0]
+        return query_params[0]
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index e91723ca3d..2cf262bb46 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -20,7 +20,7 @@ from mock import Mock
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
 from synapse.types import UserID
-from tests.utils import setup_test_homeserver
+from tests.utils import setup_test_homeserver, mock_getRawHeaders
 
 import pymacaroons
 
@@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
 
@@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
 
@@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
@@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), masquerading_user_id)
 
@@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
         request.args["user_id"] = [masquerading_user_id]
-        request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index c3108f5181..c718d1f98f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
                     "user_id": self.u_onion.to_string(),
                     "typing": True,
                 }
-            )
+            ),
+            federation_auth=True,
         )
 
         self.on_new_event.assert_has_calls([
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 44ba9ff58f..a6a4e2ffe0 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
 
 
@@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase):
             path='/_matrix/client/api/v1/createUser'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         self.registration_handler = Mock()
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index e9cb416e4b..b4a787c436 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
+from tests.utils import mock_getRawHeaders
 import json
 
 
@@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             path='/_matrix/api/v2_alpha/register'
         )
         self.request.args = {}
+        self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         self.appservice = None
         self.auth = Mock(get_appservice_by_req=Mock(
diff --git a/tests/utils.py b/tests/utils.py
index a91d167f90..5929f1c729 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -116,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func):
     return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
 
 
+def mock_getRawHeaders(headers=None):
+    headers = headers if headers is not None else {}
+
+    def getRawHeaders(name, default=None):
+        return headers.get(name, default)
+
+    return getRawHeaders
+
+
 # This is a mock /resource/ not an entire server
 class MockHttpResource(HttpServer):
 
@@ -128,7 +137,7 @@ class MockHttpResource(HttpServer):
 
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
-    def trigger(self, http_method, path, content, mock_request):
+    def trigger(self, http_method, path, content, mock_request, federation_auth=False):
         """ Fire an HTTP event.
 
         Args:
@@ -156,9 +165,10 @@ class MockHttpResource(HttpServer):
 
         mock_request.getClientIP.return_value = "-"
 
-        mock_request.requestHeaders.getRawHeaders.return_value = [
-            "X-Matrix origin=test,key=,sig="
-        ]
+        headers = {}
+        if federation_auth:
+            headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
+        mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
         mock_request.path = path