diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e36..968cf6f174 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -408,7 +409,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@@ -416,7 +417,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ except (
+ pymacaroons.exceptions.MacaroonException,
+ KeyError,
+ TypeError,
+ ValueError,
+ ):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@@ -424,27 +430,6 @@ class Auth:
return user_id, guest
- def get_user_id_from_macaroon(self, macaroon):
- """Retrieve the user_id given by the caveats on the macaroon.
-
- Does *not* validate the macaroon.
-
- Args:
- macaroon (pymacaroons.Macaroon): The macaroon to validate
-
- Returns:
- (str) user id
-
- Raises:
- InvalidClientCredentialsError if there is no user_id caveat in the
- macaroon
- """
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix) :]
- raise InvalidClientTokenError("No user caveat in macaroon")
-
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -465,21 +450,13 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
- def _verify_expiry(self, caveat):
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.hs.get_clock().time_msec()
- return now < expiry
-
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+ """Data we store in a short-term login token"""
+
+ user_id = attr.ib(type=str)
+
+ # the SSO Identity Provider that the user authenticated with, to get this token
+ auth_provider_id = attr.ib(type=str)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
- async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
- auth_api = self.hs.get_auth()
- user_id = None
+ async def validate_short_term_login_token(
+ self, login_token: str
+ ) -> LoginTokenAttributes:
try:
- macaroon = pymacaroons.Macaroon.deserialize(login_token)
- user_id = auth_api.get_user_id_from_macaroon(macaroon)
- auth_api.validate_macaroon(macaroon, "login", user_id)
+ res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- await self.auth.check_auth_blocking(user_id)
- return user_id
+ await self.auth.check_auth_blocking(res.user_id)
+ return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ auth_provider_id: The id of the SSO Identity provider that was used for
+ login. This will be stored in the login token for future tracking in
+ prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
self._complete_sso_login(
registered_user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ auth_provider_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
+ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+ """Verify a short-term-login macaroon
+
+ Checks that the given token is a valid, unexpired short-term-login token
+ minted by this server.
+
+ Args:
+ token: the login token to verify
+
+ Returns:
+ the user_id that this token is valid for
+
+ Raises:
+ MacaroonVerificationFailedException if the verification failed
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
+ auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = login")
+ v.satisfy_general(lambda c: c.startswith("user_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ satisfy_expiry(v, self.hs.get_clock().time_msec)
+ v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+ return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except (MacaroonDeserializationException, ValueError) as e:
+ except (MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -745,7 +746,7 @@ class OidcProvider:
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
+ ui_auth_session_id=ui_auth_session_id or "",
),
)
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
- if session_data.ui_auth_session_id:
- macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
- )
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+ )
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
The data extracted from the session cookie
Raises:
- ValueError if an expected caveat is missing from the macaroon.
+ KeyError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
- # Sometimes there's a UI auth session ID, it seems to be OK to attempt
- # to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
- nonce = self._get_value_from_macaroon(macaroon, "nonce")
- idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
- client_redirect_url = self._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
- try:
- ui_auth_session_id = self._get_value_from_macaroon(
- macaroon, "ui_auth_session_id"
- ) # type: Optional[str]
- except ValueError:
- ui_auth_session_id = None
-
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ idp_id = get_value_from_macaroon(macaroon, "idp_id")
+ client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+ ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
ui_auth_session_id=ui_auth_session_id,
)
- def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
- """Extracts a caveat value from a macaroon token.
-
- Args:
- macaroon: the token
- key: the key of the caveat to extract
-
- Returns:
- The extracted value
-
- Raises:
- ValueError: if the caveat was not in the macaroon
- """
- prefix = key + " = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(prefix):
- return caveat.caveat_id[len(prefix) :]
- raise ValueError("No %s caveat in macaroon" % (key,))
-
- def _verify_expiry(self, caveat: str) -> bool:
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
- return now < expiry
-
@attr.s(frozen=True, slots=True)
class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
- # The session ID of the ongoing UI Auth (None if this is a login)
- ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ # The session ID of the ongoing UI Auth ("" if this is a login)
+ ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..8a22dab54a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
@@ -886,6 +887,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
+ auth_provider_id: str = "",
) -> str:
- """Generate a login token suitable for m.login.token authentication"""
+ """Generate a login token suitable for m.login.token authentication
+
+ Args:
+ user_id: gives the ID of the user that the token is for
+
+ duration_in_ms: the time that the token will be valid for
+
+ auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+ to get this token, if any. This is encoded in the token so that
+ /login can report stats on number of successful logins by IdP.
+ """
return self._hs.get_macaroon_generator().generate_short_term_login_token(
- user_id, duration_in_ms
+ user_id,
+ auth_provider_id,
+ duration_in_ms,
)
@defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
"""
self._auth_handler._complete_sso_login(
registered_user_id,
+ "<unknown>",
request,
client_redirect_url,
)
@@ -286,6 +302,7 @@ class ModuleApi:
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
+ auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
+ auth_provider_id: the ID of the SSO IdP which was used to log in. This
+ is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url, new_user=new_user
+ registered_user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ new_user=new_user,
)
@defer.inlineCallbacks
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..1ec3a47ffb 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -283,12 +283,10 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
- user_id, login_submission, self.auth_handler._sso_login_callback
+ res.user_id, login_submission, self.auth_handler._sso_login_callback
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+ and returns the extracted value.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ MacaroonVerificationFailedException: if there are conflicting values for the
+ caveat in the macaroon, or if the caveat was not found in the macaroon.
+ """
+ prefix = key + " = "
+ result = None # type: Optional[str]
+ for caveat in macaroon.caveats:
+ if not caveat.caveat_id.startswith(prefix):
+ continue
+
+ val = caveat.caveat_id[len(prefix) :]
+
+ if result is None:
+ # first time we found this caveat: record the value
+ result = val
+ elif val != result:
+ # on subsequent occurrences, raise if the value is different.
+ raise MacaroonVerificationFailedException(
+ "Conflicting values for caveat " + key
+ )
+
+ if result is not None:
+ return result
+
+ # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+ # Note that it is insecure to generate a macaroon without all the caveats you
+ # might need (because there is nothing stopping people from adding extra caveats),
+ # so if the caveat isn't there, something odd must be going on.
+ raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+ """Make a macaroon verifier which accepts 'time' caveats
+
+ Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+ the given macaroon verifier.
+
+ Args:
+ v: the macaroon verifier
+ get_time_ms: a callable which will return the timestamp after which the caveat
+ should be considered expired. Normally the current time.
+ """
+
+ def verify_expiry_caveat(caveat: str):
+ time_msec = get_time_ms()
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ return time_msec < expiry
+
+ v.satisfy_general(verify_expiry_caveat)
|