summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/5656.misc1
-rw-r--r--synapse/api/auth.py127
-rw-r--r--synapse/api/errors.py33
-rw-r--r--synapse/rest/client/v1/directory.py10
-rw-r--r--synapse/rest/client/v1/room.py9
-rw-r--r--tests/api/test_auth.py31
6 files changed, 111 insertions, 100 deletions
diff --git a/changelog.d/5656.misc b/changelog.d/5656.misc
new file mode 100644
index 0000000000..a8de20a7d0
--- /dev/null
+++ b/changelog.d/5656.misc
@@ -0,0 +1 @@
+Clean up exception handling around client access tokens.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 86f145649c..afc6400948 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -25,7 +25,13 @@ from twisted.internet import defer
 import synapse.types
 from synapse import event_auth
 from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    InvalidClientTokenError,
+    MissingClientTokenError,
+    ResourceLimitError,
+)
 from synapse.config.server import is_threepid_reserved
 from synapse.types import UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
@@ -63,7 +69,6 @@ class Auth(object):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
-        self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
 
         self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
         register_cache("cache", "token_cache", self.token_cache)
@@ -189,18 +194,17 @@ class Auth(object):
         Returns:
             defer.Deferred: resolves to a ``synapse.types.Requester`` object
         Raises:
-            AuthError if no user by that token exists or the token is invalid.
+            InvalidClientCredentialsError if no user by that token exists or the token
+                is invalid.
+            AuthError if access is denied for the user in the access token
         """
-        # Can optionally look elsewhere in the request (e.g. headers)
         try:
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
                 b"User-Agent", default=[b""]
             )[0].decode("ascii", "surrogateescape")
 
-            access_token = self.get_access_token_from_request(
-                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-            )
+            access_token = self.get_access_token_from_request(request)
 
             user_id, app_service = yield self._get_appservice_user_id(request)
             if user_id:
@@ -264,18 +268,12 @@ class Auth(object):
                 )
             )
         except KeyError:
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                "Missing access token.",
-                errcode=Codes.MISSING_TOKEN,
-            )
+            raise MissingClientTokenError()
 
     @defer.inlineCallbacks
     def _get_appservice_user_id(self, request):
         app_service = self.store.get_app_service_by_token(
-            self.get_access_token_from_request(
-                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-            )
+            self.get_access_token_from_request(request)
         )
         if app_service is None:
             defer.returnValue((None, None))
@@ -313,7 +311,8 @@ class Auth(object):
                `token_id` (int|None): access token id. May be None if guest
                `device_id` (str|None): device corresponding to access token
         Raises:
-            AuthError if no user by that token exists or the token is invalid.
+            InvalidClientCredentialsError if no user by that token exists or the token
+                is invalid.
         """
 
         if rights == "access":
@@ -331,11 +330,7 @@ class Auth(object):
                 if not guest:
                     # non-guest access tokens must be in the database
                     logger.warning("Unrecognised access token - not in store.")
-                    raise AuthError(
-                        self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                        "Unrecognised access token.",
-                        errcode=Codes.UNKNOWN_TOKEN,
-                    )
+                    raise InvalidClientTokenError()
 
                 # Guest access tokens are not stored in the database (there can
                 # only be one access token per guest, anyway).
@@ -350,16 +345,10 @@ class Auth(object):
                 # guest tokens.
                 stored_user = yield self.store.get_user_by_id(user_id)
                 if not stored_user:
-                    raise AuthError(
-                        self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                        "Unknown user_id %s" % user_id,
-                        errcode=Codes.UNKNOWN_TOKEN,
-                    )
+                    raise InvalidClientTokenError("Unknown user_id %s" % user_id)
                 if not stored_user["is_guest"]:
-                    raise AuthError(
-                        self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                        "Guest access token used for regular user",
-                        errcode=Codes.UNKNOWN_TOKEN,
+                    raise InvalidClientTokenError(
+                        "Guest access token used for regular user"
                     )
                 ret = {
                     "user": user,
@@ -386,11 +375,7 @@ class Auth(object):
             ValueError,
         ) as e:
             logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                "Invalid macaroon passed.",
-                errcode=Codes.UNKNOWN_TOKEN,
-            )
+            raise InvalidClientTokenError("Invalid macaroon passed.")
 
     def _parse_and_validate_macaroon(self, token, rights="access"):
         """Takes a macaroon and tries to parse and validate it. This is cached
@@ -430,11 +415,7 @@ class Auth(object):
                 macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
             )
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
-            raise AuthError(
-                self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                "Invalid macaroon passed.",
-                errcode=Codes.UNKNOWN_TOKEN,
-            )
+            raise InvalidClientTokenError("Invalid macaroon passed.")
 
         if not has_expiry and rights == "access":
             self.token_cache[token] = (user_id, guest)
@@ -453,17 +434,14 @@ class Auth(object):
             (str) user id
 
         Raises:
-            AuthError if there is no user_id caveat in the macaroon
+            InvalidClientCredentialsError if there is no user_id caveat in the
+                macaroon
         """
         user_prefix = "user_id = "
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(user_prefix):
                 return caveat.caveat_id[len(user_prefix) :]
-        raise AuthError(
-            self.TOKEN_NOT_FOUND_HTTP_STATUS,
-            "No user caveat in macaroon",
-            errcode=Codes.UNKNOWN_TOKEN,
-        )
+        raise InvalidClientTokenError("No user caveat in macaroon")
 
     def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
         """
@@ -531,22 +509,13 @@ class Auth(object):
         defer.returnValue(user_info)
 
     def get_appservice_by_req(self, request):
-        try:
-            token = self.get_access_token_from_request(
-                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-            )
-            service = self.store.get_app_service_by_token(token)
-            if not service:
-                logger.warn("Unrecognised appservice access token.")
-                raise AuthError(
-                    self.TOKEN_NOT_FOUND_HTTP_STATUS,
-                    "Unrecognised access token.",
-                    errcode=Codes.UNKNOWN_TOKEN,
-                )
-            request.authenticated_entity = service.sender
-            return defer.succeed(service)
-        except KeyError:
-            raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")
+        token = self.get_access_token_from_request(request)
+        service = self.store.get_app_service_by_token(token)
+        if not service:
+            logger.warn("Unrecognised appservice access token.")
+            raise InvalidClientTokenError()
+        request.authenticated_entity = service.sender
+        return defer.succeed(service)
 
     def is_server_admin(self, user):
         """ Check if the given user is a local server admin.
@@ -692,20 +661,16 @@ class Auth(object):
         return bool(query_params) or bool(auth_headers)
 
     @staticmethod
-    def get_access_token_from_request(request, token_not_found_http_status=401):
+    def get_access_token_from_request(request):
         """Extracts the access_token from the request.
 
         Args:
             request: The http request.
-            token_not_found_http_status(int): The HTTP status code to set in the
-                AuthError if the token isn't found. This is used in some of the
-                legacy APIs to change the status code to 403 from the default of
-                401 since some of the old clients depended on auth errors returning
-                403.
         Returns:
             unicode: The access_token
         Raises:
-            AuthError: If there isn't an access_token in the request.
+            MissingClientTokenError: If there isn't a single access_token in the
+                request
         """
 
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -714,34 +679,20 @@ class Auth(object):
             # Try the get the access_token from a "Authorization: Bearer"
             # header
             if query_params is not None:
-                raise AuthError(
-                    token_not_found_http_status,
-                    "Mixing Authorization headers and access_token query parameters.",
-                    errcode=Codes.MISSING_TOKEN,
+                raise MissingClientTokenError(
+                    "Mixing Authorization headers and access_token query parameters."
                 )
             if len(auth_headers) > 1:
-                raise AuthError(
-                    token_not_found_http_status,
-                    "Too many Authorization headers.",
-                    errcode=Codes.MISSING_TOKEN,
-                )
+                raise MissingClientTokenError("Too many Authorization headers.")
             parts = auth_headers[0].split(b" ")
             if parts[0] == b"Bearer" and len(parts) == 2:
                 return parts[1].decode("ascii")
             else:
-                raise AuthError(
-                    token_not_found_http_status,
-                    "Invalid Authorization header.",
-                    errcode=Codes.MISSING_TOKEN,
-                )
+                raise MissingClientTokenError("Invalid Authorization header.")
         else:
             # Try to get the access_token from the query params.
             if not query_params:
-                raise AuthError(
-                    token_not_found_http_status,
-                    "Missing access token.",
-                    errcode=Codes.MISSING_TOKEN,
-                )
+                raise MissingClientTokenError()
 
             return query_params[0].decode("ascii")
 
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 28b5c2af9b..41fd04cd54 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -210,7 +210,9 @@ class NotFoundError(SynapseError):
 
 
 class AuthError(SynapseError):
-    """An error raised when there was a problem authorising an event."""
+    """An error raised when there was a problem authorising an event, and at various
+    other poorly-defined times.
+    """
 
     def __init__(self, *args, **kwargs):
         if "errcode" not in kwargs:
@@ -218,6 +220,35 @@ class AuthError(SynapseError):
         super(AuthError, self).__init__(*args, **kwargs)
 
 
+class InvalidClientCredentialsError(SynapseError):
+    """An error raised when there was a problem with the authorisation credentials
+    in a client request.
+
+    https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens:
+
+    When credentials are required but missing or invalid, the HTTP call will
+    return with a status of 401 and the error code, M_MISSING_TOKEN or
+    M_UNKNOWN_TOKEN respectively.
+    """
+
+    def __init__(self, msg, errcode):
+        super().__init__(code=401, msg=msg, errcode=errcode)
+
+
+class MissingClientTokenError(InvalidClientCredentialsError):
+    """Raised when we couldn't find the access token in a request"""
+
+    def __init__(self, msg="Missing access token"):
+        super().__init__(msg=msg, errcode="M_MISSING_TOKEN")
+
+
+class InvalidClientTokenError(InvalidClientCredentialsError):
+    """Raised when we didn't understand the access token in a request"""
+
+    def __init__(self, msg="Unrecognised access token"):
+        super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
+
+
 class ResourceLimitError(SynapseError):
     """
     Any error raised when there is a problem with resource usage.
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index dd0d38ea5c..57542c2b4b 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -18,7 +18,13 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    InvalidClientCredentialsError,
+    NotFoundError,
+    SynapseError,
+)
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.types import RoomAlias
@@ -97,7 +103,7 @@ class ClientDirectoryServer(RestServlet):
                 room_alias.to_string(),
             )
             defer.returnValue((200, {}))
-        except AuthError:
+        except InvalidClientCredentialsError:
             # fallback to default user behaviour if they aren't an AS
             pass
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cca7e45ddb..7709c2d705 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -24,7 +24,12 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    InvalidClientCredentialsError,
+    SynapseError,
+)
 from synapse.api.filtering import Filter
 from synapse.events.utils import format_event_for_client_v2
 from synapse.http.servlet import (
@@ -307,7 +312,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         try:
             yield self.auth.get_user_by_req(request, allow_guest=True)
-        except AuthError as e:
+        except InvalidClientCredentialsError as e:
             # Option to allow servers to require auth when accessing
             # /publicRooms via CS API. This is especially helpful in private
             # federations.
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index ddf2b578b3..ee92ceeb60 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -21,7 +21,14 @@ from twisted.internet import defer
 
 import synapse.handlers.auth
 from synapse.api.auth import Auth
-from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    InvalidClientCredentialsError,
+    InvalidClientTokenError,
+    MissingClientTokenError,
+    ResourceLimitError,
+)
 from synapse.types import UserID
 
 from tests import unittest
@@ -70,7 +77,9 @@ class AuthTestCase(unittest.TestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
-        self.failureResultOf(d, AuthError)
+        f = self.failureResultOf(d, InvalidClientTokenError).value
+        self.assertEqual(f.code, 401)
+        self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_user_missing_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto"}
@@ -79,7 +88,9 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
-        self.failureResultOf(d, AuthError)
+        f = self.failureResultOf(d, MissingClientTokenError).value
+        self.assertEqual(f.code, 401)
+        self.assertEqual(f.errcode, "M_MISSING_TOKEN")
 
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token(self):
@@ -133,7 +144,9 @@ class AuthTestCase(unittest.TestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
-        self.failureResultOf(d, AuthError)
+        f = self.failureResultOf(d, InvalidClientTokenError).value
+        self.assertEqual(f.code, 401)
+        self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
@@ -143,7 +156,9 @@ class AuthTestCase(unittest.TestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
-        self.failureResultOf(d, AuthError)
+        f = self.failureResultOf(d, InvalidClientTokenError).value
+        self.assertEqual(f.code, 401)
+        self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
@@ -153,7 +168,9 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
-        self.failureResultOf(d, AuthError)
+        f = self.failureResultOf(d, MissingClientTokenError).value
+        self.assertEqual(f.code, 401)
+        self.assertEqual(f.errcode, "M_MISSING_TOKEN")
 
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
@@ -280,7 +297,7 @@ class AuthTestCase(unittest.TestCase):
         request.args[b"access_token"] = [guest_tok.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
-        with self.assertRaises(AuthError) as cm:
+        with self.assertRaises(InvalidClientCredentialsError) as cm:
             yield self.auth.get_user_by_req(request, allow_guest=True)
 
         self.assertEqual(401, cm.exception.code)