summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8765.misc1
-rw-r--r--synapse/handlers/oidc_handler.py92
-rw-r--r--synapse/handlers/saml_handler.py77
-rw-r--r--synapse/handlers/sso.py90
-rw-r--r--synapse/server.py5
-rw-r--r--tests/handlers/test_oidc.py14
6 files changed, 159 insertions, 120 deletions
diff --git a/changelog.d/8765.misc b/changelog.d/8765.misc
new file mode 100644
index 0000000000..053f9acc9c
--- /dev/null
+++ b/changelog.d/8765.misc
@@ -0,0 +1 @@
+Consolidate logic between the OpenID Connect and SAML code.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 331d4e7e96..be8562d47b 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -34,7 +34,8 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,17 +84,12 @@ class OidcError(Exception):
         return self.error
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the UserInfo object
-    """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
     """Handles requests related to the OpenID Connect login flow.
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
+        super().__init__(hs)
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
@@ -120,36 +116,13 @@ class OidcHandler:
         self._http_client = hs.get_proxied_http_client()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
-        self._datastore = hs.get_datastore()
-        self._clock = hs.get_clock()
-        self._hostname = hs.hostname  # type: str
         self._server_name = hs.config.server_name  # type: str
         self._macaroon_secret_key = hs.config.macaroon_secret_key
-        self._error_template = hs.config.sso_error_template
 
         # identifier for the external_ids table
         self._auth_provider_id = "oidc"
 
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
-
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
-
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error. Those include
-                well-known OAuth2/OIDC error types like invalid_request or
-                access_denied.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def _validate_metadata(self):
         """Verifies the provider metadata.
@@ -571,7 +544,7 @@ class OidcHandler:
 
         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call
-        ``self._render_error`` which displays an HTML page for the error.
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
 
         Most of the OpenID Connect logic happens here:
 
@@ -609,7 +582,7 @@ class OidcHandler:
             if error != "access_denied":
                 logger.error("Error from the OIDC provider: %s %s", error, description)
 
-            self._render_error(request, error, description)
+            self._sso_handler.render_error(request, error, description)
             return
 
         # otherwise, it is presumably a successful response. see:
@@ -619,7 +592,9 @@ class OidcHandler:
         session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
         if session is None:
             logger.info("No session cookie found")
-            self._render_error(request, "missing_session", "No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
             return
 
         # Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +612,9 @@ class OidcHandler:
         # Check for the state query parameter
         if b"state" not in request.args:
             logger.info("State parameter is missing")
-            self._render_error(request, "invalid_request", "State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
             return
 
         state = request.args[b"state"][0].decode()
@@ -651,17 +628,19 @@ class OidcHandler:
             ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
-            self._render_error(request, "invalid_session", str(e))
+            self._sso_handler.render_error(request, "invalid_session", str(e))
             return
         except MacaroonInvalidSignatureException as e:
             logger.exception("Could not verify session")
-            self._render_error(request, "mismatching_session", str(e))
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
         # Exchange the code with the provider
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
-            self._render_error(request, "invalid_request", "Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
             return
 
         logger.debug("Exchanging code")
@@ -670,7 +649,7 @@ class OidcHandler:
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
-            self._render_error(request, e.error, e.error_description)
+            self._sso_handler.render_error(request, e.error, e.error_description)
             return
 
         logger.debug("Successfully obtained OAuth2 access token")
@@ -683,7 +662,7 @@ class OidcHandler:
                 userinfo = await self._fetch_userinfo(token)
             except Exception as e:
                 logger.exception("Could not fetch userinfo")
-                self._render_error(request, "fetch_error", str(e))
+                self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
         else:
             logger.debug("Extracting userinfo from id_token")
@@ -691,7 +670,7 @@ class OidcHandler:
                 userinfo = await self._parse_id_token(token, nonce=nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
-                self._render_error(request, "invalid_token", str(e))
+                self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # Pull out the user-agent and IP from the request.
@@ -705,7 +684,7 @@ class OidcHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +749,7 @@ class OidcHandler:
             macaroon.add_first_party_caveat(
                 "ui_auth_session_id = %s" % (ui_auth_session_id,)
             )
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
 
@@ -845,7 +824,7 @@ class OidcHandler:
         if not caveat.startswith(prefix):
             return False
         expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         return now < expiry
 
     async def _map_userinfo_to_user(
@@ -885,20 +864,14 @@ class OidcHandler:
         # to be strings.
         remote_user_id = str(remote_user_id)
 
-        logger.info(
-            "Looking for existing mapping for user %s:%s",
-            self._auth_provider_id,
-            remote_user_id,
-        )
-
-        registered_user_id = await self._datastore.get_user_by_external_id(
+        # first of all, check if we already have a mapping for this user
+        previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
             self._auth_provider_id, remote_user_id,
         )
+        if previously_registered_user_id:
+            return previously_registered_user_id
 
-        if registered_user_id is not None:
-            logger.info("Found existing mapping %s", registered_user_id)
-            return registered_user_id
-
+        # Otherwise, generate a new user.
         try:
             attributes = await self._user_mapping_provider.map_user_attributes(
                 userinfo, token
@@ -917,8 +890,8 @@ class OidcHandler:
 
         localpart = map_username_to_mxid_localpart(attributes["localpart"])
 
-        user_id = UserID(localpart, self._hostname).to_string()
-        users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+        user_id = UserID(localpart, self.server_name).to_string()
+        users = await self.store.get_users_by_id_case_insensitive(user_id)
         if users:
             if self._allow_existing_users:
                 if len(users) == 1:
@@ -942,7 +915,8 @@ class OidcHandler:
                 default_display_name=attributes["display_name"],
                 user_agent_ips=(user_agent, ip_address),
             )
-        await self._datastore.record_user_external_id(
+
+        await self.store.record_user_external_id(
             self._auth_provider_id, remote_user_id, registered_user_id,
         )
         return registered_user_id
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..aee772239a 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -42,10 +43,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the SAML2 response to a user."""
-
-
 @attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
@@ -57,17 +54,13 @@ class Saml2SessionData:
     ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
-class SamlHandler:
+class SamlHandler(BaseHandler):
     def __init__(self, hs: "synapse.server.HomeServer"):
-        self.hs = hs
+        super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
-        self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
-        self._clock = hs.get_clock()
-        self._datastore = hs.get_datastore()
-        self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +81,9 @@ class SamlHandler:
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
 
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
-
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -130,7 +106,7 @@ class SamlHandler:
         # Since SAML sessions timeout it is useful to log when they were created.
         logger.info("Initiating a new SAML session: %s" % (reqid,))
 
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         self._outstanding_requests_dict[reqid] = Saml2SessionData(
             creation_time=now, ui_auth_session_id=ui_auth_session_id,
         )
@@ -171,12 +147,12 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsolicited_response", "Unexpected SAML2 login."
             )
             return
         except Exception as e:
-            self._render_error(
+            self._sso_handler.render_error(
                 request,
                 "invalid_response",
                 "Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +160,7 @@ class SamlHandler:
             return
 
         if saml2_auth.not_signed:
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsigned_respond", "SAML2 response was not signed."
             )
             return
@@ -210,7 +186,7 @@ class SamlHandler:
         # attributes.
         for requirement in self._saml2_attribute_requirements:
             if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._render_error(
+                self._sso_handler.render_error(
                     request, "unauthorised", "You are not authorised to log in here."
                 )
                 return
@@ -226,7 +202,7 @@ class SamlHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Complete the interactive auth session or the login.
@@ -274,17 +250,11 @@ class SamlHandler:
 
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
-            logger.info(
-                "Looking for existing mapping for user %s:%s",
-                self._auth_provider_id,
-                remote_user_id,
+            previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
+                self._auth_provider_id, remote_user_id,
             )
-            registered_user_id = await self._datastore.get_user_by_external_id(
-                self._auth_provider_id, remote_user_id
-            )
-            if registered_user_id is not None:
-                logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id
+            if previously_registered_user_id:
+                return previously_registered_user_id
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -294,7 +264,7 @@ class SamlHandler:
             ):
                 attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
                 user_id = UserID(
-                    map_username_to_mxid_localpart(attrval), self._hostname
+                    map_username_to_mxid_localpart(attrval), self.server_name
                 ).to_string()
                 logger.info(
                     "Looking for existing account based on mapped %s %s",
@@ -302,11 +272,11 @@ class SamlHandler:
                     user_id,
                 )
 
-                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     registered_user_id = list(users.keys())[0]
                     logger.info("Grandfathering mapping to %s", registered_user_id)
-                    await self._datastore.record_user_external_id(
+                    await self.store.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
                     return registered_user_id
@@ -335,8 +305,8 @@ class SamlHandler:
                 emails = attribute_dict.get("emails", [])
 
                 # Check if this mxid already exists
-                if not await self._datastore.get_users_by_id_case_insensitive(
-                    UserID(localpart, self._hostname).to_string()
+                if not await self.store.get_users_by_id_case_insensitive(
+                    UserID(localpart, self.server_name).to_string()
                 ):
                     # This mxid is free
                     break
@@ -348,7 +318,6 @@ class SamlHandler:
                 )
 
             logger.info("Mapped SAML user to local part %s", localpart)
-
             registered_user_id = await self._registration_handler.register_user(
                 localpart=localpart,
                 default_display_name=displayname,
@@ -356,13 +325,13 @@ class SamlHandler:
                 user_agent_ips=(user_agent, ip_address),
             )
 
-            await self._datastore.record_user_external_id(
+            await self.store.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
             return registered_user_id
 
     def expire_sessions(self):
-        expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+        expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
         for reqid, data in self._outstanding_requests_dict.items():
             if data.creation_time < expire_before:
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..9cb1866a71
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+    """Used to catch errors when mapping the UserInfo object
+    """
+
+
+class SsoHandler(BaseHandler):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+        self._error_template = hs.config.sso_error_template
+
+    def render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Renders the error template and responds with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
+    async def get_sso_user_by_remote_user_id(
+        self, auth_provider_id: str, remote_user_id: str
+    ) -> Optional[str]:
+        """
+        Maps the user ID of a remote IdP to a mxid for a previously seen user.
+
+        If the user has not been seen yet, this will return None.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The user ID according to the remote IdP. This might
+                be an e-mail address, a GUID, or some other form. It must be
+                unique and immutable.
+
+        Returns:
+            The mxid of a previously seen user.
+        """
+        # Check if we already have a mapping for this user.
+        logger.info(
+            "Looking for existing mapping for user %s:%s",
+            auth_provider_id,
+            remote_user_id,
+        )
+        previously_registered_user_id = await self.store.get_user_by_external_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        # A match was found, return the user ID.
+        if previously_registered_user_id is not None:
+            logger.info("Found existing mapping %s", previously_registered_user_id)
+            return previously_registered_user_id
+
+        # No match.
+        return None
diff --git a/synapse/server.py b/synapse/server.py
index 21a232bbd9..12a783de17 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
 from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
 from synapse.handlers.search import SearchHandler
 from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.sso import SsoHandler
 from synapse.handlers.stats import StatsHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@@ -391,6 +392,10 @@ class HomeServer(metaclass=abc.ABCMeta):
             return FollowerTypingHandler(self)
 
     @cache_in_self
+    def get_sso_handler(self) -> SsoHandler:
+        return SsoHandler(self)
+
+    @cache_in_self
     def get_sync_handler(self) -> SyncHandler:
         return SyncHandler(self)
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0d51705849..630e6da808 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
         self.handler = OidcHandler(hs)
+        # Mock the render error method.
+        self.render_error = Mock(return_value=None)
+        self.handler._sso_handler.render_error = self.render_error
 
         return hs
 
@@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return patch.dict(self.handler._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
-        args = self.handler._render_error.call_args[0]
+        args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
-        self.handler._render_error.reset_mock()
+        self.render_error.reset_mock()
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
@@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
-        self.handler._render_error = Mock()
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "preferred_username": "bar",
         }
         user_id = "@foo:domain.org"
-        self.handler._render_error = Mock(return_value=None)
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             userinfo, token, user_agent, ip_address
         )
         self.handler._fetch_userinfo.assert_not_called()
-        self.handler._render_error.assert_not_called()
+        self.render_error.assert_not_called()
 
         # Handle mapping errors
         self.handler._map_userinfo_to_user = simple_async_mock(
@@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             userinfo, token, user_agent, ip_address
         )
         self.handler._fetch_userinfo.assert_called_once_with(token)
-        self.handler._render_error.assert_not_called()
+        self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
         self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
-        self.handler._render_error = Mock(return_value=None)
         request = Mock(spec=["args", "getCookie", "addCookie"])
 
         # Missing cookie