summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16570.feature1
-rw-r--r--synapse/handlers/e2e_keys.py14
-rw-r--r--synapse/storage/database.py10
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py60
-rw-r--r--tests/handlers/test_e2e_keys.py77
5 files changed, 162 insertions, 0 deletions
diff --git a/changelog.d/16570.feature b/changelog.d/16570.feature
new file mode 100644
index 0000000000..c807945fa8
--- /dev/null
+++ b/changelog.d/16570.feature
@@ -0,0 +1 @@
+Improve the performance of claiming encryption keys.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8c6432035d..91c5fe007d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -659,6 +659,20 @@ class E2eKeysHandler:
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
+        """
+        Args:
+            query: A chain of maps from (user_id, device_id, algorithm) to the requested
+                number of keys to claim.
+            user: The user who is claiming these keys.
+            timeout: How long to wait for any federation key claim requests before
+                giving up.
+            always_include_fallback_keys: always include a fallback key for local users'
+                devices, even if we managed to claim a one-time-key.
+
+        Returns: a heterogeneous dict with two keys:
+            one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
+            failures: map from remote destination to a JsonDict describing the error.
+        """
         local_query: List[Tuple[str, str, str, int]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b1ece63845..a4e7048368 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -420,6 +420,16 @@ class LoggingTransaction:
         self._do_execute(self.txn.execute, sql, parameters)
 
     def executemany(self, sql: str, *args: Any) -> None:
+        """Repeatedly execute the same piece of SQL with different parameters.
+
+        See https://peps.python.org/pep-0249/#executemany. Note in particular that
+
+        > Use of this method for an operation which produces one or more result sets
+        > constitutes undefined behavior
+
+        so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
+        DELETE FROM... RETURNING.
+        """
         # TODO: we should add a type for *args here. Looking at Cursor.executemany
         # and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
         # Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f70f95eeba..08385d312f 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,6 +24,7 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
     cast,
@@ -1260,6 +1261,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         Returns:
             A map of user ID -> a map device ID -> a map of key ID -> JSON.
         """
+        if isinstance(self.database_engine, PostgresEngine):
+            return await self.db_pool.runInteraction(
+                "_claim_e2e_fallback_keys_bulk",
+                self._claim_e2e_fallback_keys_bulk_txn,
+                query_list,
+                db_autocommit=True,
+            )
+            # Use an UPDATE FROM... RETURNING combined with a VALUES block to do
+            # everything in one query. Note: this is also supported in SQLite 3.33.0,
+            # (see https://www.sqlite.org/lang_update.html#update_from), but we do not
+            # have an equivalent of psycopg2's execute_values to do this in one query.
+        else:
+            return await self._claim_e2e_fallback_keys_simple(query_list)
+
+    def _claim_e2e_fallback_keys_bulk_txn(
+        self,
+        txn: LoggingTransaction,
+        query_list: Iterable[Tuple[str, str, str, bool]],
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Efficient implementation of claim_e2e_fallback_keys for Postgres.
+
+        Safe to autocommit: this is a single query.
+        """
+        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+
+        sql = """
+            WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
+                VALUES ?
+            )
+            UPDATE e2e_fallback_keys_json k
+            SET used = used OR mark_as_used
+            FROM claims
+            WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
+            RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
+        """
+        claimed_keys = cast(
+            List[Tuple[str, str, str, str, str]],
+            txn.execute_values(sql, query_list),
+        )
+
+        seen_user_device: Set[Tuple[str, str]] = set()
+        for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
+            device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+            device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+
+            if (user_id, device_id) in seen_user_device:
+                continue
+            seen_user_device.add((user_id, device_id))
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+            )
+
+        return results
+
+    async def _claim_e2e_fallback_keys_simple(
+        self,
+        query_list: Iterable[Tuple[str, str, str, bool]],
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, device_id, algorithm, mark_as_used in query_list:
             row = await self.db_pool.simple_select_one(
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index c5556f2844..24e405f429 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -322,6 +322,83 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
+    def test_fallback_key_bulk(self) -> None:
+        """Like test_fallback_key, but claims multiple keys in one handler call."""
+        alice = f"@alice:{self.hs.hostname}"
+        brian = f"@brian:{self.hs.hostname}"
+        chris = f"@chris:{self.hs.hostname}"
+
+        # Have three users upload fallback keys for two devices.
+        fallback_keys = {
+            alice: {
+                "alice_dev_1": {"alg1:k1": "fallback_key1"},
+                "alice_dev_2": {"alg2:k2": "fallback_key2"},
+            },
+            brian: {
+                "brian_dev_1": {"alg1:k3": "fallback_key3"},
+                "brian_dev_2": {"alg2:k4": "fallback_key4"},
+            },
+            chris: {
+                "chris_dev_1": {"alg1:k5": "fallback_key5"},
+                "chris_dev_2": {"alg2:k6": "fallback_key6"},
+            },
+        }
+
+        for user_id, devices in fallback_keys.items():
+            for device_id, key_dict in devices.items():
+                self.get_success(
+                    self.handler.upload_keys_for_user(
+                        user_id,
+                        device_id,
+                        {"fallback_keys": key_dict},
+                    )
+                )
+
+        # Each device should have an unused fallback key.
+        for user_id, devices in fallback_keys.items():
+            for device_id in devices:
+                fallback_res = self.get_success(
+                    self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+                )
+                expected_algorithm_name = f"alg{device_id[-1]}"
+                self.assertEqual(fallback_res, [expected_algorithm_name])
+
+        # Claim the fallback key for one device per user.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {
+                    alice: {"alice_dev_1": {"alg1": 1}},
+                    brian: {"brian_dev_2": {"alg2": 1}},
+                    chris: {"chris_dev_2": {"alg2": 1}},
+                },
+                self.requester,
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        expected_claims = {
+            alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
+            brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
+            chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
+        }
+        self.assertEqual(
+            claim_res,
+            {"failures": {}, "one_time_keys": expected_claims},
+        )
+
+        for user_id, devices in fallback_keys.items():
+            for device_id in devices:
+                fallback_res = self.get_success(
+                    self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+                )
+                # Claimed fallback keys should no longer show up as unused.
+                # Unclaimed fallback keys should still be unused.
+                if device_id in expected_claims[user_id]:
+                    self.assertEqual(fallback_res, [])
+                else:
+                    expected_algorithm_name = f"alg{device_id[-1]}"
+                    self.assertEqual(fallback_res, [expected_algorithm_name])
+
     def test_fallback_key_always_returned(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"