summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py16
-rw-r--r--synapse/handlers/cas_handler.py40
-rw-r--r--synapse/handlers/oidc_handler.py40
-rw-r--r--synapse/handlers/room.py21
-rw-r--r--synapse/handlers/saml_handler.py26
-rw-r--r--synapse/handlers/sso.py90
6 files changed, 157 insertions, 76 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a19c556437..648fe91f53 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1472,10 +1472,22 @@ class AuthHandler(BaseHandler):
         # Remove the query parameters from the redirect URL to get a shorter version of
         # it. This is only to display a human-readable URL in the template, but not the
         # URL we redirect users to.
-        redirect_url_no_params = client_redirect_url.split("?")[0]
+        url_parts = urllib.parse.urlsplit(client_redirect_url)
+
+        if url_parts.scheme == "https":
+            # for an https uri, just show the netloc (ie, the hostname. Specifically,
+            # the bit between "//" and "/"; this includes any potential
+            # "username:password@" prefix.)
+            display_url = url_parts.netloc
+        else:
+            # for other uris, strip the query-params (including the login token) and
+            # fragment.
+            display_url = urllib.parse.urlunsplit(
+                (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
+            )
 
         html = self._sso_redirect_confirm_template.render(
-            display_url=redirect_url_no_params,
+            display_url=display_url,
             redirect_url=redirect_url,
             server_name=self._server_name,
             new_user=new_user,
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index bd35d1fb87..81ed44ac87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 from xml.etree import ElementTree as ET
 
 import attr
@@ -49,7 +49,7 @@ class CasError(Exception):
 @attr.s(slots=True, frozen=True)
 class CasResponse:
     username = attr.ib(type=str)
-    attributes = attr.ib(type=Dict[str, Optional[str]])
+    attributes = attr.ib(type=Dict[str, List[Optional[str]]])
 
 
 class CasHandler:
@@ -169,7 +169,7 @@ class CasHandler:
 
         # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
-        attributes = {}
+        attributes = {}  # type: Dict[str, List[Optional[str]]]
         for child in root[0]:
             if child.tag.endswith("user"):
                 user = child.text
@@ -182,7 +182,7 @@ class CasHandler:
                     tag = attribute.tag
                     if "}" in tag:
                         tag = tag.split("}")[1]
-                    attributes[tag] = attribute.text
+                    attributes.setdefault(tag, []).append(attribute.text)
 
         # Ensure a user was found.
         if user is None:
@@ -303,29 +303,10 @@ class CasHandler:
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in cas_response.attributes:
-                self._sso_handler.render_error(
-                    request,
-                    "unauthorised",
-                    "You are not authorised to log in here.",
-                    401,
-                )
-                return
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = cas_response.attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    self._sso_handler.render_error(
-                        request,
-                        "unauthorised",
-                        "You are not authorised to log in here.",
-                        401,
-                    )
-                    return
+        if not self._sso_handler.check_required_attributes(
+            request, cas_response.attributes, self._cas_required_attributes
+        ):
+            return
 
         # Call the mapper to register/login the user
 
@@ -372,9 +353,10 @@ class CasHandler:
             if failures:
                 raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
 
+            # Arbitrarily use the first attribute found.
             display_name = cas_response.attributes.get(
-                self._cas_displayname_attribute, None
-            )
+                self._cas_displayname_attribute, [None]
+            )[0]
 
             return UserAttributes(localpart=localpart, display_name=display_name)
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 71008ec50d..3adc75fa4a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -123,7 +123,6 @@ class OidcHandler:
         Args:
             request: the incoming request from the browser.
         """
-
         # The provider might redirect with an error.
         # In that case, just display it as-is.
         if b"error" in request.args:
@@ -137,8 +136,12 @@ class OidcHandler:
             # either the provider misbehaving or Synapse being misconfigured.
             # The only exception of that is "access_denied", where the user
             # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
+            logger.log(
+                logging.INFO if error == "access_denied" else logging.ERROR,
+                "Received OIDC callback with error: %s %s",
+                error,
+                description,
+            )
 
             self._sso_handler.render_error(request, error, description)
             return
@@ -149,7 +152,7 @@ class OidcHandler:
         # Fetch the session cookie
         session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
         if session is None:
-            logger.info("No session cookie found")
+            logger.info("Received OIDC callback, with no session cookie")
             self._sso_handler.render_error(
                 request, "missing_session", "No session cookie found"
             )
@@ -169,7 +172,7 @@ class OidcHandler:
 
         # Check for the state query parameter
         if b"state" not in request.args:
-            logger.info("State parameter is missing")
+            logger.info("Received OIDC callback, with no state parameter")
             self._sso_handler.render_error(
                 request, "invalid_request", "State parameter is missing"
             )
@@ -183,14 +186,16 @@ class OidcHandler:
                 session, state
             )
         except (MacaroonDeserializationException, ValueError) as e:
-            logger.exception("Invalid session")
+            logger.exception("Invalid session for OIDC callback")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
         except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
+            logger.exception("Could not verify session for OIDC callback")
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
+        logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
+
         oidc_provider = self._providers.get(session_data.idp_id)
         if not oidc_provider:
             logger.error("OIDC session uses unknown IdP %r", oidc_provider)
@@ -565,6 +570,7 @@ class OidcProvider:
         Returns:
             UserInfo: an object representing the user.
         """
+        logger.debug("Using the OAuth2 access_token to request userinfo")
         metadata = await self.load_metadata()
 
         resp = await self._http_client.get_json(
@@ -572,6 +578,8 @@ class OidcProvider:
             headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
         )
 
+        logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+
         return UserInfo(resp)
 
     async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
@@ -600,17 +608,19 @@ class OidcProvider:
             claims_cls = ImplicitIDToken
 
         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-
         jwt = JsonWebToken(alg_values)
 
         claim_options = {"iss": {"values": [metadata["issuer"]]}}
 
+        id_token = token["id_token"]
+        logger.debug("Attempting to decode JWT id_token %r", id_token)
+
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
         try:
             claims = jwt.decode(
-                token["id_token"],
+                id_token,
                 key=jwk_set,
                 claims_cls=claims_cls,
                 claims_options=claim_options,
@@ -620,13 +630,15 @@ class OidcProvider:
             logger.info("Reloading JWKS after decode error")
             jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
             claims = jwt.decode(
-                token["id_token"],
+                id_token,
                 key=jwk_set,
                 claims_cls=claims_cls,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
 
+        logger.debug("Decoded id_token JWT %r; validating", claims)
+
         claims.validate(leeway=120)  # allows 2 min of clock skew
         return UserInfo(claims)
 
@@ -726,19 +738,18 @@ class OidcProvider:
         """
         # Exchange the code with the provider
         try:
-            logger.debug("Exchanging code")
+            logger.debug("Exchanging OAuth2 code for a token")
             token = await self._exchange_code(code)
         except OidcError as e:
-            logger.exception("Could not exchange code")
+            logger.exception("Could not exchange OAuth2 code")
             self._sso_handler.render_error(request, e.error, e.error_description)
             return
 
-        logger.debug("Successfully obtained OAuth2 access token")
+        logger.debug("Successfully obtained OAuth2 token data: %r", token)
 
         # Now that we have a token, get the userinfo, either by decoding the
         # `id_token` or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
-            logger.debug("Fetching userinfo")
             try:
                 userinfo = await self._fetch_userinfo(token)
             except Exception as e:
@@ -746,7 +757,6 @@ class OidcProvider:
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
         else:
-            logger.debug("Extracting userinfo from id_token")
             try:
                 userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 07b2187eb1..1336a23a3a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
+from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -1004,41 +1005,51 @@ class RoomCreationHandler(BaseHandler):
 class RoomContextHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
+        self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
     async def get_event_context(
         self,
-        user: UserID,
+        requester: Requester,
         room_id: str,
         event_id: str,
         limit: int,
         event_filter: Optional[Filter],
+        use_admin_priviledge: bool = False,
     ) -> Optional[JsonDict]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
         Args:
-            user
+            requester
             room_id
             event_id
             limit: The maximum number of events to return in total
                 (excluding state).
             event_filter: the filter to apply to the events returned
                 (excluding the target event_id)
-
+            use_admin_priviledge: if `True`, return all events, regardless
+                of whether `user` has access to them. To be used **ONLY**
+                from the admin API.
         Returns:
             dict, or None if the event isn't found
         """
+        user = requester.user
+        if use_admin_priviledge:
+            await assert_user_is_admin(self.auth, requester.user)
+
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
         users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
-        def filter_evts(events):
-            return filter_events_for_client(
+        async def filter_evts(events):
+            if use_admin_priviledge:
+                return events
+            return await filter_events_for_client(
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..78f130e152 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
@@ -239,12 +238,10 @@ class SamlHandler(BaseHandler):
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for requirement in self._saml2_attribute_requirements:
-            if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._sso_handler.render_error(
-                    request, "unauthorised", "You are not authorised to log in here."
-                )
-                return
+        if not self._sso_handler.check_required_attributes(
+            request, saml2_auth.ava, self._saml2_attribute_requirements
+        ):
+            return
 
         # Call the mapper to register/login the user
         try:
@@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
-    values = ava.get(req.attribute, [])
-    for v in values:
-        if v == req.value:
-            return True
-
-    logger.info(
-        "SAML2 attribute %s did not match required value '%s' (was '%s')",
-        req.attribute,
-        req.value,
-        values,
-    )
-    return False
-
-
 DOT_REPLACE_PATTERN = re.compile(
     ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
 )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b450668f1c..a63fd52485 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Set,
@@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html, respond_with_redirect
@@ -742,7 +745,11 @@ class SsoHandler:
             use_display_name: whether the user wants to use the suggested display name
             emails_to_use: emails that the user would like to use
         """
-        session = self.get_mapping_session(session_id)
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
 
         # update the session with the user's choices
         session.chosen_localpart = localpart
@@ -793,7 +800,12 @@ class SsoHandler:
             session_id,
             terms_version,
         )
-        session = self.get_mapping_session(session_id)
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
         session.terms_accepted_version = terms_version
 
         # we're done; now we can register the user
@@ -808,7 +820,11 @@ class SsoHandler:
             request: HTTP request
             session_id: ID of the username mapping session, extracted from a cookie
         """
-        session = self.get_mapping_session(session_id)
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
 
         logger.info(
             "[session %s] Registering localpart %s",
@@ -880,6 +896,41 @@ class SsoHandler:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
 
+    def check_required_attributes(
+        self,
+        request: SynapseRequest,
+        attributes: Mapping[str, List[Any]],
+        attribute_requirements: Iterable[SsoAttributeRequirement],
+    ) -> bool:
+        """
+        Confirm that the required attributes were present in the SSO response.
+
+        If all requirements are met, this will return True.
+
+        If any requirement is not met, then the request will be finalized by
+        showing an error page to the user and False will be returned.
+
+        Args:
+            request: The request to (potentially) respond to.
+            attributes: The attributes from the SSO IdP.
+            attribute_requirements: The requirements that attributes must meet.
+
+        Returns:
+            True if all requirements are met, False if any attribute fails to
+            meet the requirement.
+
+        """
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for requirement in attribute_requirements:
+            if not _check_attribute_requirement(attributes, requirement):
+                self.render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return False
+
+        return True
+
 
 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie
@@ -890,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     if not session_id:
         raise SynapseError(code=400, msg="missing session_id")
     return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+    attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+    """Check if SSO attributes meet the proper requirements.
+
+    Args:
+        attributes: A mapping of attributes to an iterable of one or more values.
+        requirement: The configured requirement to check.
+
+    Returns:
+        True if the required attribute was found and had a proper value.
+    """
+    if req.attribute not in attributes:
+        logger.info("SSO attribute missing: %s", req.attribute)
+        return False
+
+    # If the requirement is None, the attribute existing is enough.
+    if req.value is None:
+        return True
+
+    values = attributes[req.attribute]
+    if req.value in values:
+        return True
+
+    logger.info(
+        "SSO attribute %s did not match required value '%s' (was '%s')",
+        req.attribute,
+        req.value,
+        values,
+    )
+    return False