diff --git a/changelog.d/7302.bugfix b/changelog.d/7302.bugfix
new file mode 100644
index 0000000000..820646d1f9
--- /dev/null
+++ b/changelog.d/7302.bugfix
@@ -0,0 +1 @@
+Persist user interactive authentication sessions across workers and Synapse restarts.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index d125327f08..0ace7b787d 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
@@ -439,6 +440,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
+ UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dbe165ce1e..7613e5b6ab 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
-from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
- # This is not a cache per se, but a store of all current sessions that
- # expire after N hours
- self.sessions = ExpiringCache(
- cache_name="register_sessions",
- clock=hs.get_clock(),
- expiry_ms=self.SESSION_EXPIRE_MS,
- reset_expiry_on_get=True,
- )
-
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
@@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
+ # Expire old UI auth sessions after a period of time.
+ if hs.config.worker_app is None:
+ self._clock.looping_call(
+ run_as_background_process,
+ 5 * 60 * 1000,
+ "expire_old_sessions",
+ self._expire_old_sessions,
+ )
+
# Load the SSO HTML templates.
# The following template is shown to the user during a client login via SSO,
@@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
if "session" in authdict:
sid = authdict["session"]
+ # Convert the URI and method to strings.
+ uri = request.uri.decode("utf-8")
+ method = request.uri.decode("utf-8")
+
# If there's no session ID, create a new session.
if not sid:
- session = self._create_session(
- clientdict, (request.uri, request.method, clientdict), description
+ session = await self.store.create_ui_auth_session(
+ clientdict, uri, method, description
)
- session_id = session["id"]
else:
- session = self._get_session_info(sid)
- session_id = sid
+ try:
+ session = await self.store.get_ui_auth_session(sid)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (sid,))
if not clientdict:
# This was designed to allow the client to omit the parameters
@@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
- clientdict = session["clientdict"]
+ clientdict = session.clientdict
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
- comparator = (request.uri, request.method, clientdict)
- if session["ui_auth"] != comparator:
+ comparator = (uri, method, clientdict)
+ if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
@@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session_id)
+ self._auth_dict_for_flows(flows, session.session_id)
)
- creds = session["creds"]
-
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
@@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
- creds[login_type] = result
- self._save_session(session)
+ await self.store.mark_ui_auth_stage_complete(
+ session.session_id, login_type, result
+ )
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
@@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
# so that the client can have another go.
errordict = e.error_dict()
+ creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
@@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
list(clientdict),
)
- return creds, clientdict, session_id
+ return creds, clientdict, session.session_id
- ret = self._auth_dict_for_flows(flows, session_id)
+ ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(authdict["session"])
- creds = sess["creds"]
-
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
- creds[stagetype] = result
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ authdict["session"], stagetype, result
+ )
return True
return False
@@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
- def set_session_data(self, session_id: str, key: str, value: Any) -> None:
+ async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
value: The data to store
"""
- sess = self._get_session_info(session_id)
- sess["serverdict"][key] = value
- self._save_session(sess)
+ try:
+ await self.store.set_ui_auth_session_data(session_id, key, value)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- def get_session_data(
+ async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
@@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
default: Value to return if the key has not been set
"""
- sess = self._get_session_info(session_id)
- return sess["serverdict"].get(key, default)
+ try:
+ return await self.store.get_ui_auth_session_data(session_id, key, default)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+ async def _expire_old_sessions(self):
+ """
+ Invalidate any user interactive authentication sessions that have expired.
+ """
+ now = self._clock.time_msec()
+ expiration_time = now - self.SESSION_EXPIRE_MS
+ await self.store.delete_old_ui_auth_sessions(expiration_time)
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
@@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
"params": params,
}
- def _create_session(
- self,
- clientdict: Dict[str, Any],
- ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
- description: str,
- ) -> dict:
- """
- Creates a new user interactive authentication session.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
-
- Each session has the following keys:
-
- id:
- A unique identifier for this session. Passed back to the client
- and returned for each stage.
- clientdict:
- The dictionary from the client root level, not the 'auth' key.
- ui_auth:
- A tuple which is checked at each stage of the authentication to
- ensure that the asked for operation has not changed.
- creds:
- A map, which maps each auth-type (str) to the relevant identity
- authenticated by that auth-type (mostly str, but for captcha, bool).
- serverdict:
- A map of data that is stored server-side and cannot be modified
- by the client.
- description:
- A string description of the operation that the current
- authentication is authorising.
- Returns:
- The newly created session.
- """
- session_id = None
- while session_id is None or session_id in self.sessions:
- session_id = stringutils.random_string(24)
-
- self.sessions[session_id] = {
- "id": session_id,
- "clientdict": clientdict,
- "ui_auth": ui_auth,
- "creds": {},
- "serverdict": {},
- "description": description,
- }
-
- return self.sessions[session_id]
-
- def _get_session_info(self, session_id: str) -> dict:
- """
- Gets a session given a session ID.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
- """
- try:
- return self.sessions[session_id]
- except KeyError:
- raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
@@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
await self.store.user_delete_threepid(user_id, medium, address)
return result
- def _save_session(self, session: Dict[str, Any]) -> None:
- """Update the last used time on the session to now and add it back to the session store."""
- # TODO: Persistent storage
- logger.debug("Saving session %s", session)
- session["last_used"] = self.hs.get_clock().time_msec()
- self.sessions[session["id"]] = session
-
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
else:
return False
- def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
@@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
Returns:
The HTML to render.
"""
- session = self._get_session_info(session_id)
+ try:
+ session = await self.store.get_ui_auth_session(session_id)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
- description=session["description"], redirect_url=redirect_url,
+ description=session.description, redirect_url=redirect_url,
)
- def complete_sso_ui_auth(
+ async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
process.
"""
# Mark the stage of the authentication as successful.
- sess = self._get_session_info(session_id)
- creds = sess["creds"]
-
# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
- creds[LoginType.SSO] = registered_user_id
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ session_id, LoginType.SSO, registered_user_id
+ )
# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 5cb3f9d133..64aaa1335c 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -206,7 +206,7 @@ class CasHandler:
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
- self._auth_handler.complete_sso_ui_auth(
+ await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 7c9454b504..96f2dd36ad 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -149,7 +149,7 @@ class SamlHandler:
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
- self._auth_handler.complete_sso_ui_auth(
+ await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 11599f5005..24dd3d3e96 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
- def on_GET(self, request, stagetype):
+ async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
- html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+ html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d1b5c49989..af08cc6cce 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
- registered_user_id = self.auth_handler.get_session_data(
+ registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
@@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
- self.auth_handler.set_session_data(
+ await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index bd7c3a00ea..ceba10882c 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -66,6 +66,7 @@ from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
+from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@@ -112,6 +113,7 @@ class DataStore(
StatsStore,
RelationsStore,
CacheInvalidationStore,
+ UIAuthStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
new file mode 100644
index 0000000000..dcb593fc2d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
@@ -0,0 +1,36 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions(
+ session_id TEXT NOT NULL, -- The session ID passed to the client.
+ creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
+ serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
+ clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
+ uri TEXT NOT NULL, -- The URI the UI authentication session is using.
+ method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
+ -- The clientdict, uri, and method make up an tuple that must be immutable
+ -- throughout the lifetime of the UI Auth session.
+ description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
+ UNIQUE (session_id)
+);
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
+ session_id TEXT NOT NULL, -- The corresponding UI Auth session.
+ stage_type TEXT NOT NULL, -- The stage type.
+ result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
+ UNIQUE (session_id, stage_type),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
new file mode 100644
index 0000000000..c8eebc9378
--- /dev/null
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from typing import Any, Dict, Optional, Union
+
+import attr
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.types import JsonDict
+
+
+@attr.s
+class UIAuthSessionData:
+ session_id = attr.ib(type=str)
+ # The dictionary from the client root level, not the 'auth' key.
+ clientdict = attr.ib(type=JsonDict)
+ # The URI and method the session was intiatied with. These are checked at
+ # each stage of the authentication to ensure that the asked for operation
+ # has not changed.
+ uri = attr.ib(type=str)
+ method = attr.ib(type=str)
+ # A string description of the operation that the current authentication is
+ # authorising.
+ description = attr.ib(type=str)
+
+
+class UIAuthWorkerStore(SQLBaseStore):
+ """
+ Manage user interactive authentication sessions.
+ """
+
+ async def create_ui_auth_session(
+ self, clientdict: JsonDict, uri: str, method: str, description: str,
+ ) -> UIAuthSessionData:
+ """
+ Creates a new user interactive authentication session.
+
+ The session can be used to track the stages necessary to authenticate a
+ user across multiple HTTP requests.
+
+ Args:
+ clientdict:
+ The dictionary from the client root level, not the 'auth' key.
+ uri:
+ The URI this session was initiated with, this is checked at each
+ stage of the authentication to ensure that the asked for
+ operation has not changed.
+ method:
+ The method this session was initiated with, this is checked at each
+ stage of the authentication to ensure that the asked for
+ operation has not changed.
+ description:
+ A string description of the operation that the current
+ authentication is authorising.
+ Returns:
+ The newly created session.
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # The clientdict gets stored as JSON.
+ clientdict_json = json.dumps(clientdict)
+
+ # autogen a session ID and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db.simple_insert(
+ table="ui_auth_sessions",
+ values={
+ "session_id": session_id,
+ "clientdict": clientdict_json,
+ "uri": uri,
+ "method": method,
+ "description": description,
+ "serverdict": "{}",
+ "creation_time": self.hs.get_clock().time_msec(),
+ },
+ desc="create_ui_auth_session",
+ )
+ return UIAuthSessionData(
+ session_id, clientdict, uri, method, description
+ )
+ except self.db.engine.module.IntegrityError:
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
+ """Retrieve a UI auth session.
+
+ Args:
+ session_id: The ID of the session.
+ Returns:
+ A dict containing the device information.
+ Raises:
+ StoreError if the session is not found.
+ """
+ result = await self.db.simple_select_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("clientdict", "uri", "method", "description"),
+ desc="get_ui_auth_session",
+ )
+
+ result["clientdict"] = json.loads(result["clientdict"])
+
+ return UIAuthSessionData(session_id, **result)
+
+ async def mark_ui_auth_stage_complete(
+ self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+ ):
+ """
+ Mark a session stage as completed.
+
+ Args:
+ session_id: The ID of the corresponding session.
+ stage_type: The completed stage type.
+ result: The result of the stage verification.
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ # Add (or update) the results of the current stage to the database.
+ #
+ # Note that we need to allow for the same stage to complete multiple
+ # times here so that registration is idempotent.
+ try:
+ await self.db.simple_upsert(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id, "stage_type": stage_type},
+ values={"result": json.dumps(result)},
+ desc="mark_ui_auth_stage_complete",
+ )
+ except self.db.engine.module.IntegrityError:
+ raise StoreError(400, "Unknown session ID: %s" % (session_id,))
+
+ async def get_completed_ui_auth_stages(
+ self, session_id: str
+ ) -> Dict[str, Union[str, bool, JsonDict]]:
+ """
+ Retrieve the completed stages of a UI authentication session.
+
+ Args:
+ session_id: The ID of the session.
+ Returns:
+ The completed stages mapped to the result of the verification of
+ that auth-type.
+ """
+ results = {}
+ for row in await self.db.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ):
+ results[row["stage_type"]] = json.loads(row["result"])
+
+ return results
+
+ async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+ """
+ Store a key-value pair into the sessions data associated with this
+ request. This data is stored server-side and cannot be modified by
+ the client.
+
+ Args:
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ value: The data to store
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ await self.db.runInteraction(
+ "set_ui_auth_session_data",
+ self._set_ui_auth_session_data_txn,
+ session_id,
+ key,
+ value,
+ )
+
+ def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ # Get the current value.
+ result = self.db.simple_select_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ )
+
+ # Update it and add it back to the database.
+ serverdict = json.loads(result["serverdict"])
+ serverdict[key] = value
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ updatevalues={"serverdict": json.dumps(serverdict)},
+ )
+
+ async def get_ui_auth_session_data(
+ self, session_id: str, key: str, default: Optional[Any] = None
+ ) -> Any:
+ """
+ Retrieve data stored with set_session_data
+
+ Args:
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ default: Value to return if the key has not been set
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ result = await self.db.simple_select_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ desc="get_ui_auth_session_data",
+ )
+
+ serverdict = json.loads(result["serverdict"])
+
+ return serverdict.get(key, default)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ def delete_old_ui_auth_sessions(self, expiration_time: int):
+ """
+ Remove sessions which were last used earlier than the expiration time.
+
+ Args:
+ expiration_time: The latest time that is still considered valid.
+ This is an epoch time in milliseconds.
+
+ """
+ return self.db.runInteraction(
+ "delete_old_ui_auth_sessions",
+ self._delete_old_ui_auth_sessions_txn,
+ expiration_time,
+ )
+
+ def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ # Get the expired sessions.
+ sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
+ txn.execute(sql, [expiration_time])
+ session_ids = [r[0] for r in txn.fetchall()]
+
+ # Delete the corresponding completed credentials.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
+ # Finally, delete the sessions.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 3bc2e8b986..215a949442 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
+ db_conn.execute("PRAGMA foreign_keys = ON;")
def is_deadlock(self, error):
return False
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 624bf5ada2..587be7b2e7 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, 403)
+
+ def test_complete_operation_unknown_session(self):
+ """
+ Attempting to mark an invalid session as complete should error.
+ """
+
+ # Make the initial request to register. (Later on a different password
+ # will be used.)
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.render(request)
+
+ # Returns a 401 as per the spec
+ self.assertEqual(request.code, 401)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ # Attempt to complete an unknown session, which should return an error.
+ unknown_session = session + "unknown"
+ request, channel = self.make_request(
+ "POST",
+ "auth/m.login.recaptcha/fallback/web?session="
+ + unknown_session
+ + "&g-recaptcha-response=a",
+ )
+ self.render(request)
+ self.assertEqual(request.code, 400)
diff --git a/tests/utils.py b/tests/utils.py
index 037cb134f0..f9be62b499 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -512,8 +512,8 @@ class MockClock(object):
return t
- def looping_call(self, function, interval):
- self.loopers.append([function, interval / 1000.0, self.now])
+ def looping_call(self, function, interval, *args, **kwargs):
+ self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -543,9 +543,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
- func, interval, last = looped
+ func, interval, last, args, kwargs = looped
if last + interval < self.now:
- func()
+ func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):
diff --git a/tox.ini b/tox.ini
index 2630857436..eccc44e436 100644
--- a/tox.ini
+++ b/tox.ini
@@ -200,8 +200,9 @@ commands = mypy \
synapse/replication \
synapse/rest \
synapse/spam_checker_api \
- synapse/storage/engines \
+ synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
+ synapse/storage/engines \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
|