diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dbe165ce1e..7613e5b6ab 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
-from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
- # This is not a cache per se, but a store of all current sessions that
- # expire after N hours
- self.sessions = ExpiringCache(
- cache_name="register_sessions",
- clock=hs.get_clock(),
- expiry_ms=self.SESSION_EXPIRE_MS,
- reset_expiry_on_get=True,
- )
-
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
@@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
+ # Expire old UI auth sessions after a period of time.
+ if hs.config.worker_app is None:
+ self._clock.looping_call(
+ run_as_background_process,
+ 5 * 60 * 1000,
+ "expire_old_sessions",
+ self._expire_old_sessions,
+ )
+
# Load the SSO HTML templates.
# The following template is shown to the user during a client login via SSO,
@@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
if "session" in authdict:
sid = authdict["session"]
+ # Convert the URI and method to strings.
+ uri = request.uri.decode("utf-8")
+ method = request.uri.decode("utf-8")
+
# If there's no session ID, create a new session.
if not sid:
- session = self._create_session(
- clientdict, (request.uri, request.method, clientdict), description
+ session = await self.store.create_ui_auth_session(
+ clientdict, uri, method, description
)
- session_id = session["id"]
else:
- session = self._get_session_info(sid)
- session_id = sid
+ try:
+ session = await self.store.get_ui_auth_session(sid)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (sid,))
if not clientdict:
# This was designed to allow the client to omit the parameters
@@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
- clientdict = session["clientdict"]
+ clientdict = session.clientdict
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
- comparator = (request.uri, request.method, clientdict)
- if session["ui_auth"] != comparator:
+ comparator = (uri, method, clientdict)
+ if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
@@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session_id)
+ self._auth_dict_for_flows(flows, session.session_id)
)
- creds = session["creds"]
-
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
@@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
- creds[login_type] = result
- self._save_session(session)
+ await self.store.mark_ui_auth_stage_complete(
+ session.session_id, login_type, result
+ )
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
@@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
# so that the client can have another go.
errordict = e.error_dict()
+ creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
@@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
list(clientdict),
)
- return creds, clientdict, session_id
+ return creds, clientdict, session.session_id
- ret = self._auth_dict_for_flows(flows, session_id)
+ ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(authdict["session"])
- creds = sess["creds"]
-
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
- creds[stagetype] = result
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ authdict["session"], stagetype, result
+ )
return True
return False
@@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
- def set_session_data(self, session_id: str, key: str, value: Any) -> None:
+ async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
value: The data to store
"""
- sess = self._get_session_info(session_id)
- sess["serverdict"][key] = value
- self._save_session(sess)
+ try:
+ await self.store.set_ui_auth_session_data(session_id, key, value)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- def get_session_data(
+ async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
@@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
default: Value to return if the key has not been set
"""
- sess = self._get_session_info(session_id)
- return sess["serverdict"].get(key, default)
+ try:
+ return await self.store.get_ui_auth_session_data(session_id, key, default)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+ async def _expire_old_sessions(self):
+ """
+ Invalidate any user interactive authentication sessions that have expired.
+ """
+ now = self._clock.time_msec()
+ expiration_time = now - self.SESSION_EXPIRE_MS
+ await self.store.delete_old_ui_auth_sessions(expiration_time)
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
@@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
"params": params,
}
- def _create_session(
- self,
- clientdict: Dict[str, Any],
- ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
- description: str,
- ) -> dict:
- """
- Creates a new user interactive authentication session.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
-
- Each session has the following keys:
-
- id:
- A unique identifier for this session. Passed back to the client
- and returned for each stage.
- clientdict:
- The dictionary from the client root level, not the 'auth' key.
- ui_auth:
- A tuple which is checked at each stage of the authentication to
- ensure that the asked for operation has not changed.
- creds:
- A map, which maps each auth-type (str) to the relevant identity
- authenticated by that auth-type (mostly str, but for captcha, bool).
- serverdict:
- A map of data that is stored server-side and cannot be modified
- by the client.
- description:
- A string description of the operation that the current
- authentication is authorising.
- Returns:
- The newly created session.
- """
- session_id = None
- while session_id is None or session_id in self.sessions:
- session_id = stringutils.random_string(24)
-
- self.sessions[session_id] = {
- "id": session_id,
- "clientdict": clientdict,
- "ui_auth": ui_auth,
- "creds": {},
- "serverdict": {},
- "description": description,
- }
-
- return self.sessions[session_id]
-
- def _get_session_info(self, session_id: str) -> dict:
- """
- Gets a session given a session ID.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
- """
- try:
- return self.sessions[session_id]
- except KeyError:
- raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
@@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
await self.store.user_delete_threepid(user_id, medium, address)
return result
- def _save_session(self, session: Dict[str, Any]) -> None:
- """Update the last used time on the session to now and add it back to the session store."""
- # TODO: Persistent storage
- logger.debug("Saving session %s", session)
- session["last_used"] = self.hs.get_clock().time_msec()
- self.sessions[session["id"]] = session
-
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
else:
return False
- def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
@@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
Returns:
The HTML to render.
"""
- session = self._get_session_info(session_id)
+ try:
+ session = await self.store.get_ui_auth_session(session_id)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
- description=session["description"], redirect_url=redirect_url,
+ description=session.description, redirect_url=redirect_url,
)
- def complete_sso_ui_auth(
+ async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
process.
"""
# Mark the stage of the authentication as successful.
- sess = self._get_session_info(session_id)
- creds = sess["creds"]
-
# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
- creds[LoginType.SSO] = registered_user_id
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ session_id, LoginType.SSO, registered_user_id
+ )
# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 5cb3f9d133..64aaa1335c 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -206,7 +206,7 @@ class CasHandler:
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
- self._auth_handler.complete_sso_ui_auth(
+ await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 7c9454b504..96f2dd36ad 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -149,7 +149,7 @@ class SamlHandler:
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
- self._auth_handler.complete_sso_ui_auth(
+ await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
|