summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-04-30 13:47:49 -0400
committerGitHub <noreply@github.com>2020-04-30 13:47:49 -0400
commit627b0f5f2753e6910adb7a877541d50f5936b8a5 (patch)
tree6b4d53f91efa0675fc9df8af85cd163243976d55 /synapse/handlers/auth.py
parentApply federation check for /publicRooms with filter list (#7367) (diff)
downloadsynapse-627b0f5f2753e6910adb7a877541d50f5936b8a5.tar.xz
Persist user interactive authentication sessions (#7302)
By persisting the user interactive authentication sessions to the database, this fixes
situations where a user hits different works throughout their auth session and also
allows sessions to persist through restarts of Synapse.
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py175
1 files changed, 61 insertions, 114 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dbe165ce1e..7613e5b6ab 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.http.server import finish_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
 from synapse.push.mailer import load_jinja2_templates
 from synapse.types import Requester, UserID
-from synapse.util.caches.expiringcache import ExpiringCache
 
 from ._base import BaseHandler
 
@@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
 
         self.bcrypt_rounds = hs.config.bcrypt_rounds
 
-        # This is not a cache per se, but a store of all current sessions that
-        # expire after N hours
-        self.sessions = ExpiringCache(
-            cache_name="register_sessions",
-            clock=hs.get_clock(),
-            expiry_ms=self.SESSION_EXPIRE_MS,
-            reset_expiry_on_get=True,
-        )
-
         account_handler = ModuleApi(hs, self)
         self.password_providers = [
             module(config=config, account_handler=account_handler)
@@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
 
         self._clock = self.hs.get_clock()
 
+        # Expire old UI auth sessions after a period of time.
+        if hs.config.worker_app is None:
+            self._clock.looping_call(
+                run_as_background_process,
+                5 * 60 * 1000,
+                "expire_old_sessions",
+                self._expire_old_sessions,
+            )
+
         # Load the SSO HTML templates.
 
         # The following template is shown to the user during a client login via SSO,
@@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
             if "session" in authdict:
                 sid = authdict["session"]
 
+        # Convert the URI and method to strings.
+        uri = request.uri.decode("utf-8")
+        method = request.uri.decode("utf-8")
+
         # If there's no session ID, create a new session.
         if not sid:
-            session = self._create_session(
-                clientdict, (request.uri, request.method, clientdict), description
+            session = await self.store.create_ui_auth_session(
+                clientdict, uri, method, description
             )
-            session_id = session["id"]
 
         else:
-            session = self._get_session_info(sid)
-            session_id = sid
+            try:
+                session = await self.store.get_ui_auth_session(sid)
+            except StoreError:
+                raise SynapseError(400, "Unknown session ID: %s" % (sid,))
 
             if not clientdict:
                 # This was designed to allow the client to omit the parameters
@@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
                 # on a homeserver.
                 # Revisit: Assuming the REST APIs do sensible validation, the data
                 # isn't arbitrary.
-                clientdict = session["clientdict"]
+                clientdict = session.clientdict
 
             # Ensure that the queried operation does not vary between stages of
             # the UI authentication session. This is done by generating a stable
             # comparator based on the URI, method, and body (minus the auth dict)
             # and storing it during the initial query. Subsequent queries ensure
             # that this comparator has not changed.
-            comparator = (request.uri, request.method, clientdict)
-            if session["ui_auth"] != comparator:
+            comparator = (uri, method, clientdict)
+            if (session.uri, session.method, session.clientdict) != comparator:
                 raise SynapseError(
                     403,
                     "Requested operation has changed during the UI authentication session.",
@@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
 
         if not authdict:
             raise InteractiveAuthIncompleteError(
-                self._auth_dict_for_flows(flows, session_id)
+                self._auth_dict_for_flows(flows, session.session_id)
             )
 
-        creds = session["creds"]
-
         # check auth type currently being presented
         errordict = {}  # type: Dict[str, Any]
         if "type" in authdict:
@@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
             try:
                 result = await self._check_auth_dict(authdict, clientip)
                 if result:
-                    creds[login_type] = result
-                    self._save_session(session)
+                    await self.store.mark_ui_auth_stage_complete(
+                        session.session_id, login_type, result
+                    )
             except LoginError as e:
                 if login_type == LoginType.EMAIL_IDENTITY:
                     # riot used to have a bug where it would request a new
@@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
                 # so that the client can have another go.
                 errordict = e.error_dict()
 
+        creds = await self.store.get_completed_ui_auth_stages(session.session_id)
         for f in flows:
             if len(set(f) - set(creds)) == 0:
                 # it's very useful to know what args are stored, but this can
@@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
                     list(clientdict),
                 )
 
-                return creds, clientdict, session_id
+                return creds, clientdict, session.session_id
 
-        ret = self._auth_dict_for_flows(flows, session_id)
+        ret = self._auth_dict_for_flows(flows, session.session_id)
         ret["completed"] = list(creds)
         ret.update(errordict)
         raise InteractiveAuthIncompleteError(ret)
@@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
         if "session" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
 
-        sess = self._get_session_info(authdict["session"])
-        creds = sess["creds"]
-
         result = await self.checkers[stagetype].check_auth(authdict, clientip)
         if result:
-            creds[stagetype] = result
-            self._save_session(sess)
+            await self.store.mark_ui_auth_stage_complete(
+                authdict["session"], stagetype, result
+            )
             return True
         return False
 
@@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
                 sid = authdict["session"]
         return sid
 
-    def set_session_data(self, session_id: str, key: str, value: Any) -> None:
+    async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
         """
         Store a key-value pair into the sessions data associated with this
         request. This data is stored server-side and cannot be modified by
@@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
             key: The key to store the data under
             value: The data to store
         """
-        sess = self._get_session_info(session_id)
-        sess["serverdict"][key] = value
-        self._save_session(sess)
+        try:
+            await self.store.set_ui_auth_session_data(session_id, key, value)
+        except StoreError:
+            raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
 
-    def get_session_data(
+    async def get_session_data(
         self, session_id: str, key: str, default: Optional[Any] = None
     ) -> Any:
         """
@@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
             key: The key to store the data under
             default: Value to return if the key has not been set
         """
-        sess = self._get_session_info(session_id)
-        return sess["serverdict"].get(key, default)
+        try:
+            return await self.store.get_ui_auth_session_data(session_id, key, default)
+        except StoreError:
+            raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+    async def _expire_old_sessions(self):
+        """
+        Invalidate any user interactive authentication sessions that have expired.
+        """
+        now = self._clock.time_msec()
+        expiration_time = now - self.SESSION_EXPIRE_MS
+        await self.store.delete_old_ui_auth_sessions(expiration_time)
 
     async def _check_auth_dict(
         self, authdict: Dict[str, Any], clientip: str
@@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
             "params": params,
         }
 
-    def _create_session(
-        self,
-        clientdict: Dict[str, Any],
-        ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
-        description: str,
-    ) -> dict:
-        """
-        Creates a new user interactive authentication session.
-
-        The session can be used to track data across multiple requests, e.g. for
-        interactive authentication.
-
-        Each session has the following keys:
-
-            id:
-                A unique identifier for this session. Passed back to the client
-                and returned for each stage.
-            clientdict:
-                The dictionary from the client root level, not the 'auth' key.
-            ui_auth:
-                A tuple which is checked at each stage of the authentication to
-                ensure that the asked for operation has not changed.
-            creds:
-                A map, which maps each auth-type (str) to the relevant identity
-                authenticated by that auth-type (mostly str, but for captcha, bool).
-            serverdict:
-                A map of data that is stored server-side and cannot be modified
-                by the client.
-            description:
-                A string description of the operation that the current
-                authentication is authorising.
-    Returns:
-        The newly created session.
-        """
-        session_id = None
-        while session_id is None or session_id in self.sessions:
-            session_id = stringutils.random_string(24)
-
-        self.sessions[session_id] = {
-            "id": session_id,
-            "clientdict": clientdict,
-            "ui_auth": ui_auth,
-            "creds": {},
-            "serverdict": {},
-            "description": description,
-        }
-
-        return self.sessions[session_id]
-
-    def _get_session_info(self, session_id: str) -> dict:
-        """
-        Gets a session given a session ID.
-
-        The session can be used to track data across multiple requests, e.g. for
-        interactive authentication.
-        """
-        try:
-            return self.sessions[session_id]
-        except KeyError:
-            raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-
     async def get_access_token_for_user_id(
         self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
     ):
@@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
         await self.store.user_delete_threepid(user_id, medium, address)
         return result
 
-    def _save_session(self, session: Dict[str, Any]) -> None:
-        """Update the last used time on the session to now and add it back to the session store."""
-        # TODO: Persistent storage
-        logger.debug("Saving session %s", session)
-        session["last_used"] = self.hs.get_clock().time_msec()
-        self.sessions[session["id"]] = session
-
     async def hash(self, password: str) -> str:
         """Computes a secure hash of password.
 
@@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
@@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
         Returns:
             The HTML to render.
         """
-        session = self._get_session_info(session_id)
+        try:
+            session = await self.store.get_ui_auth_session(session_id)
+        except StoreError:
+            raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
         return self._sso_auth_confirm_template.render(
-            description=session["description"], redirect_url=redirect_url,
+            description=session.description, redirect_url=redirect_url,
         )
 
-    def complete_sso_ui_auth(
+    async def complete_sso_ui_auth(
         self, registered_user_id: str, session_id: str, request: SynapseRequest,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
@@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
                 process.
         """
         # Mark the stage of the authentication as successful.
-        sess = self._get_session_info(session_id)
-        creds = sess["creds"]
-
         # Save the user who authenticated with SSO, this will be used to ensure
         # that the account be modified is also the person who logged in.
-        creds[LoginType.SSO] = registered_user_id
-        self._save_session(sess)
+        await self.store.mark_ui_auth_stage_complete(
+            session_id, LoginType.SSO, registered_user_id
+        )
 
         # Render the HTML and return.
         html_bytes = self._sso_auth_success_template.encode("utf-8")