summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py188
1 files changed, 126 insertions, 62 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1edc96042b..1f0a39eac4 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         """
 
         @trace
-        def _claim_e2e_one_time_keys(txn):
-            sql = (
-                "SELECT key_id, key_json FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " LIMIT 1"
+        def _claim_e2e_one_time_key_simple(
+            txn, user_id: str, device_id: str, algorithm: str
+        ) -> Optional[Tuple[str, str]]:
+            """Claim OTK for device for DBs that don't support RETURNING.
+
+            Returns:
+                A tuple of key name (algorithm + key ID) and key JSON, if an
+                OTK was found.
+            """
+
+            sql = """
+                SELECT key_id, key_json FROM e2e_one_time_keys_json
+                WHERE user_id = ? AND device_id = ? AND algorithm = ?
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, device_id, algorithm))
+            otk_row = txn.fetchone()
+            if otk_row is None:
+                return None
+
+            key_id, key_json = otk_row
+
+            self.db_pool.simple_delete_one_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                    "key_id": key_id,
+                },
             )
-            fallback_sql = (
-                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " LIMIT 1"
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
-            result = {}
-            delete = []
-            used_fallbacks = []
-            for user_id, device_id, algorithm in query_list:
-                user_result = result.setdefault(user_id, {})
-                device_result = user_result.setdefault(device_id, {})
-                txn.execute(sql, (user_id, device_id, algorithm))
-                otk_row = txn.fetchone()
-                if otk_row is not None:
-                    key_id, key_json = otk_row
-                    device_result[algorithm + ":" + key_id] = key_json
-                    delete.append((user_id, device_id, algorithm, key_id))
-                else:
-                    # no one-time key available, so see if there's a fallback
-                    # key
-                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
-                    fallback_row = txn.fetchone()
-                    if fallback_row is not None:
-                        key_id, key_json, used = fallback_row
-                        device_result[algorithm + ":" + key_id] = key_json
-                        if not used:
-                            used_fallbacks.append(
-                                (user_id, device_id, algorithm, key_id)
-                            )
-
-            # drop any one-time keys that were claimed
-            sql = (
-                "DELETE FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
-                " AND key_id = ?"
+
+            return f"{algorithm}:{key_id}", key_json
+
+        @trace
+        def _claim_e2e_one_time_key_returning(
+            txn, user_id: str, device_id: str, algorithm: str
+        ) -> Optional[Tuple[str, str]]:
+            """Claim OTK for device for DBs that support RETURNING.
+
+            Returns:
+                A tuple of key name (algorithm + key ID) and key JSON, if an
+                OTK was found.
+            """
+
+            # We can use RETURNING to do the fetch and DELETE in once step.
+            sql = """
+                DELETE FROM e2e_one_time_keys_json
+                WHERE user_id = ? AND device_id = ? AND algorithm = ?
+                    AND key_id IN (
+                        SELECT key_id FROM e2e_one_time_keys_json
+                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
+                        LIMIT 1
+                    )
+                RETURNING key_id, key_json
+            """
+
+            txn.execute(
+                sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
             )
-            for user_id, device_id, algorithm, key_id in delete:
-                log_kv(
-                    {
-                        "message": "Executing claim e2e_one_time_keys transaction on database."
-                    }
-                )
-                txn.execute(sql, (user_id, device_id, algorithm, key_id))
-                log_kv({"message": "finished executing and invalidating cache"})
-                self._invalidate_cache_and_stream(
-                    txn, self.count_e2e_one_time_keys, (user_id, device_id)
+            otk_row = txn.fetchone()
+            if otk_row is None:
+                return None
+
+            key_id, key_json = otk_row
+            return f"{algorithm}:{key_id}", key_json
+
+        results = {}
+        for user_id, device_id, algorithm in query_list:
+            if self.database_engine.supports_returning:
+                # If we support RETURNING clause we can use a single query that
+                # allows us to use autocommit mode.
+                _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
+                db_autocommit = True
+            else:
+                _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
+                db_autocommit = False
+
+            row = await self.db_pool.runInteraction(
+                "claim_e2e_one_time_keys",
+                _claim_e2e_one_time_key,
+                user_id,
+                device_id,
+                algorithm,
+                db_autocommit=db_autocommit,
+            )
+            if row:
+                device_results = results.setdefault(user_id, {}).setdefault(
+                    device_id, {}
                 )
-            # mark fallback keys as used
-            for user_id, device_id, algorithm, key_id in used_fallbacks:
-                self.db_pool.simple_update_txn(
-                    txn,
-                    "e2e_fallback_keys_json",
-                    {
+                device_results[row[0]] = row[1]
+                continue
+
+            # No one-time key available, so see if there's a fallback
+            # key
+            row = await self.db_pool.simple_select_one(
+                table="e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                },
+                retcols=("key_id", "key_json", "used"),
+                desc="_get_fallback_key",
+                allow_none=True,
+            )
+            if row is None:
+                continue
+
+            key_id = row["key_id"]
+            key_json = row["key_json"]
+            used = row["used"]
+
+            # Mark fallback key as used if not already.
+            if not used:
+                await self.db_pool.simple_update_one(
+                    table="e2e_fallback_keys_json",
+                    keyvalues={
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
                     },
-                    {"used": True},
+                    updatevalues={"used": True},
+                    desc="_get_fallback_key_set_used",
                 )
-                self._invalidate_cache_and_stream(
-                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+                await self.invalidate_cache_and_stream(
+                    "get_e2e_unused_fallback_key_types", (user_id, device_id)
                 )
 
-            return result
+            device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+            device_results[f"{algorithm}:{key_id}"] = key_json
 
-        return await self.db_pool.runInteraction(
-            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
-        )
+        return results
 
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):