summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py122
1 files changed, 10 insertions, 112 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbafbbee6b..3d83236b0c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -37,9 +37,7 @@ from typing import (
 
 import attr
 import bcrypt
-import pymacaroons
 import unpaddedbase64
-from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.internet.defer import CancelledError
 from twisted.web.server import Request
@@ -69,7 +67,7 @@ from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
+from synapse.util.macaroons import LoginTokenAttributes
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
@@ -81,6 +79,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
+
 
 def convert_client_dict_legacy_fields_to_identifier(
     submission: JsonDict,
@@ -178,25 +178,13 @@ class SsoLoginExtraAttributes:
     extra_attributes: JsonDict
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class LoginTokenAttributes:
-    """Data we store in a short-term login token"""
-
-    user_id: str
-
-    auth_provider_id: str
-    """The SSO Identity Provider that the user authenticated with, to get this token."""
-
-    auth_provider_session_id: Optional[str]
-    """The session ID advertised by the SSO Identity Provider."""
-
-
 class AuthHandler:
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
+        self.auth_blocking = hs.get_auth_blocking()
         self.clock = hs.get_clock()
         self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -983,7 +971,7 @@ class AuthHandler:
             not is_appservice_ghost
             or self.hs.config.appservice.track_appservice_user_ips
         ):
-            await self.auth.check_auth_blocking(user_id)
+            await self.auth_blocking.check_auth_blocking(user_id)
 
         access_token = self.generate_access_token(target_user_id_obj)
         await self.store.add_access_token_to_user(
@@ -1215,7 +1203,9 @@ class AuthHandler:
                     await self._failed_login_attempts_ratelimiter.can_do_action(
                         None, (medium, address)
                     )
-                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+                raise LoginError(
+                    403, msg=INVALID_USERNAME_OR_PASSWORD, errcode=Codes.FORBIDDEN
+                )
 
             identifier_dict = {"type": "m.id.user", "user": user_id}
 
@@ -1341,7 +1331,7 @@ class AuthHandler:
 
         # We raise a 403 here, but note that if we're doing user-interactive
         # login, it turns all LoginErrors into a 401 anyway.
-        raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
+        raise LoginError(403, msg=INVALID_USERNAME_OR_PASSWORD, errcode=Codes.FORBIDDEN)
 
     async def check_password_provider_3pid(
         self, medium: str, address: str, password: str
@@ -1435,7 +1425,7 @@ class AuthHandler:
         except Exception:
             raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
-        await self.auth.check_auth_blocking(res.user_id)
+        await self.auth_blocking.check_auth_blocking(res.user_id)
         return res
 
     async def delete_access_token(self, access_token: str) -> None:
@@ -1826,98 +1816,6 @@ class AuthHandler:
         return urllib.parse.urlunparse(url_parts)
 
 
-@attr.s(slots=True, auto_attribs=True)
-class MacaroonGenerator:
-    hs: "HomeServer"
-
-    def generate_guest_access_token(self, user_id: str) -> str:
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = access")
-        # Include a nonce, to make sure that each login gets a different
-        # access token.
-        macaroon.add_first_party_caveat(
-            "nonce = %s" % (stringutils.random_string_with_symbols(16),)
-        )
-        macaroon.add_first_party_caveat("guest = true")
-        return macaroon.serialize()
-
-    def generate_short_term_login_token(
-        self,
-        user_id: str,
-        auth_provider_id: str,
-        auth_provider_session_id: Optional[str] = None,
-        duration_in_ms: int = (2 * 60 * 1000),
-    ) -> str:
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = login")
-        now = self.hs.get_clock().time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-        macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
-        if auth_provider_session_id is not None:
-            macaroon.add_first_party_caveat(
-                "auth_provider_session_id = %s" % (auth_provider_session_id,)
-            )
-        return macaroon.serialize()
-
-    def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
-        """Verify a short-term-login macaroon
-
-        Checks that the given token is a valid, unexpired short-term-login token
-        minted by this server.
-
-        Args:
-            token: the login token to verify
-
-        Returns:
-            the user_id that this token is valid for
-
-        Raises:
-            MacaroonVerificationFailedException if the verification failed
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(token)
-        user_id = get_value_from_macaroon(macaroon, "user_id")
-        auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
-
-        auth_provider_session_id: Optional[str] = None
-        try:
-            auth_provider_session_id = get_value_from_macaroon(
-                macaroon, "auth_provider_session_id"
-            )
-        except MacaroonVerificationFailedException:
-            pass
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = login")
-        v.satisfy_general(lambda c: c.startswith("user_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
-        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
-        satisfy_expiry(v, self.hs.get_clock().time_msec)
-        v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
-
-        return LoginTokenAttributes(
-            user_id=user_id,
-            auth_provider_id=auth_provider_id,
-            auth_provider_session_id=auth_provider_session_id,
-        )
-
-    def generate_delete_pusher_token(self, user_id: str) -> str:
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = delete_pusher")
-        return macaroon.serialize()
-
-    def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
-        macaroon = pymacaroons.Macaroon(
-            location=self.hs.config.server.server_name,
-            identifier="key",
-            key=self.hs.config.key.macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        return macaroon
-
-
 def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
     module_api = hs.get_module_api()
     for module, config in hs.config.authproviders.password_providers: