summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py45
-rw-r--r--synapse/handlers/auth.py17
-rw-r--r--synapse/server.pyi4
-rw-r--r--tests/handlers/test_auth.py52
4 files changed, 87 insertions, 31 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 59db76debc..0db26fcfd7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -675,27 +675,18 @@ class Auth(object):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
 
-            user_prefix = "user_id = "
-            user = None
-            user_id = None
-            guest = False
-            for caveat in macaroon.caveats:
-                if caveat.caveat_id.startswith(user_prefix):
-                    user_id = caveat.caveat_id[len(user_prefix):]
-                    user = UserID.from_string(user_id)
-                elif caveat.caveat_id == "guest = true":
-                    guest = True
+            user_id = self.get_user_id_from_macaroon(macaroon)
+            user = UserID.from_string(user_id)
 
             self.validate_macaroon(
                 macaroon, rights, self.hs.config.expire_access_token,
                 user_id=user_id,
             )
 
-            if user is None:
-                raise AuthError(
-                    self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-                    errcode=Codes.UNKNOWN_TOKEN
-                )
+            guest = False
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id == "guest = true":
+                    guest = True
 
             if guest:
                 ret = {
@@ -743,6 +734,29 @@ class Auth(object):
                 errcode=Codes.UNKNOWN_TOKEN
             )
 
+    def get_user_id_from_macaroon(self, macaroon):
+        """Retrieve the user_id given by the caveats on the macaroon.
+
+        Does *not* validate the macaroon.
+
+        Args:
+            macaroon (pymacaroons.Macaroon): The macaroon to validate
+
+        Returns:
+            (str) user id
+
+        Raises:
+            AuthError if there is no user_id caveat in the macaroon
+        """
+        user_prefix = "user_id = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(user_prefix):
+                return caveat.caveat_id[len(user_prefix):]
+        raise AuthError(
+            self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+            errcode=Codes.UNKNOWN_TOKEN
+        )
+
     def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
@@ -754,6 +768,7 @@ class Auth(object):
             verify_expiry(bool): Whether to verify whether the macaroon has expired.
                 This should really always be True, but no clients currently implement
                 token refresh, so we can't enforce expiry yet.
+            user_id (str): The user_id required
         """
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2e138f328f..1d3641b7a7 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -720,10 +720,11 @@ class AuthHandler(BaseHandler):
 
     def validate_short_term_login_token_and_get_user_id(self, login_token):
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(login_token)
             auth_api = self.hs.get_auth()
-            auth_api.validate_macaroon(macaroon, "login", True)
-            return self.get_user_from_macaroon(macaroon)
+            macaroon = pymacaroons.Macaroon.deserialize(login_token)
+            user_id = auth_api.get_user_id_from_macaroon(macaroon)
+            auth_api.validate_macaroon(macaroon, "login", True, user_id)
+            return user_id
         except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
             raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
 
@@ -736,16 +737,6 @@ class AuthHandler(BaseHandler):
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
 
-    def get_user_from_macaroon(self, macaroon):
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix):]
-        raise AuthError(
-            self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
-            errcode=Codes.UNKNOWN_TOKEN
-        )
-
     @defer.inlineCallbacks
     def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index c0aa868c4f..9570df5537 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,4 @@
+import synapse.api.auth
 import synapse.handlers
 import synapse.handlers.auth
 import synapse.handlers.device
@@ -6,6 +7,9 @@ import synapse.storage
 import synapse.state
 
 class HomeServer(object):
+    def get_auth(self) -> synapse.api.auth.Auth:
+        pass
+
     def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
         pass
 
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 21077cbe9a..4a8cd19acf 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -14,11 +14,13 @@
 # limitations under the License.
 
 import pymacaroons
+from twisted.internet import defer
 
+import synapse
+import synapse.api.errors
 from synapse.handlers.auth import AuthHandler
 from tests import unittest
 from tests.utils import setup_test_homeserver
-from twisted.internet import defer
 
 
 class AuthHandlers(object):
@@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
     def setUp(self):
         self.hs = yield setup_test_homeserver(handlers=None)
         self.hs.handlers = AuthHandlers(self.hs)
+        self.auth_handler = self.hs.handlers.auth_handler
 
     def test_token_is_a_macaroon(self):
         self.hs.config.macaroon_secret_key = "this key is a huge secret"
 
-        token = self.hs.handlers.auth_handler.generate_access_token("some_user")
+        token = self.auth_handler.generate_access_token("some_user")
         # Check that we can parse the thing with pymacaroons
         macaroon = pymacaroons.Macaroon.deserialize(token)
         # The most basic of sanity checks
@@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.config.macaroon_secret_key = "this key is a massive secret"
         self.hs.clock.now = 5000
 
-        token = self.hs.handlers.auth_handler.generate_access_token("a_user")
+        token = self.auth_handler.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
         def verify_gen(caveat):
@@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_type)
         v.satisfy_general(verify_expiry)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+    def test_short_term_login_token_gives_user_id(self):
+        self.hs.clock.now = 1000
+
+        token = self.auth_handler.generate_short_term_login_token(
+            "a_user", 5000
+        )
+
+        self.assertEqual(
+            "a_user",
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                token
+            )
+        )
+
+        # when we advance the clock, the token should be rejected
+        self.hs.clock.now = 6000
+        with self.assertRaises(synapse.api.errors.AuthError):
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                token
+            )
+
+    def test_short_term_login_token_cannot_replace_user_id(self):
+        token = self.auth_handler.generate_short_term_login_token(
+            "a_user", 5000
+        )
+        macaroon = pymacaroons.Macaroon.deserialize(token)
+
+        self.assertEqual(
+            "a_user",
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                macaroon.serialize()
+            )
+        )
+
+        # add another "user_id" caveat, which might allow us to override the
+        # user_id.
+        macaroon.add_first_party_caveat("user_id = b_user")
+
+        with self.assertRaises(synapse.api.errors.AuthError):
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                macaroon.serialize()
+            )