diff --git a/changelog.d/16565.feature b/changelog.d/16565.feature
new file mode 100644
index 0000000000..c807945fa8
--- /dev/null
+++ b/changelog.d/16565.feature
@@ -0,0 +1 @@
+Improve the performance of claiming encryption keys.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 91c5fe007d..d340d4aebe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -753,6 +753,16 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
+ """
+ Args:
+ user_id: user whose keys are being uploaded.
+ device_id: device whose keys are being uploaded.
+ keys: the body of a /keys/upload request.
+
+ Returns a dictionary with one field:
+ "one_time_key_counts": A mapping from algorithm to number of keys for that
+ algorithm, including those previously persisted.
+ """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 08385d312f..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str, int]]
+ self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
@@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
- A tuple pf:
+ A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
- A copy of the input which has not been fulfilled.
+ A copy of the input which has not been fulfilled. The returned counts
+ may be less than the input counts. In this case, the returned counts
+ are the number of claims that were not fulfilled.
"""
-
- @trace
- def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that don't support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- sql = """
- SELECT key_id, key_json FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- """
-
- txn.execute(sql, (user_id, device_id, algorithm, count))
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self.db_pool.simple_delete_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- column="key_id",
- values=[otk_row[0] for otk_row in otk_rows],
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- },
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
- @trace
- def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- # We can use RETURNING to do the fetch and DELETE in once step.
- sql = """
- DELETE FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- AND key_id IN (
- SELECT key_id FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- )
- RETURNING key_id, key_json
- """
-
- txn.execute(
- sql,
- (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
- )
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
- for user_id, device_id, algorithm, count in query_list:
- if self.database_engine.supports_returning:
- # If we support RETURNING clause we can use a single query that
- # allows us to use autocommit mode.
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
- db_autocommit = True
- else:
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
- db_autocommit = False
+ if isinstance(self.database_engine, PostgresEngine):
+ # If we can use execute_values we can use a single batch query
+ # in autocommit mode.
+ unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+ for user_id, device_id, algorithm, count in query_list:
+ unfulfilled_claim_counts[user_id, device_id, algorithm] = count
- claim_rows = await self.db_pool.runInteraction(
+ bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
- _claim_e2e_one_time_key,
- user_id,
- device_id,
- algorithm,
- count,
- db_autocommit=db_autocommit,
+ self._claim_e2e_one_time_keys_bulk,
+ query_list,
+ db_autocommit=True,
)
- if claim_rows:
+
+ for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- for claim_row in claim_rows:
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+ unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
# Did we get enough OTKs?
- count -= len(claim_rows)
- if count:
- missing.append((user_id, device_id, algorithm, count))
+ missing = [
+ (user, device, alg, count)
+ for (user, device, alg), count in unfulfilled_claim_counts.items()
+ if count > 0
+ ]
+ else:
+ for user_id, device_id, algorithm, count in query_list:
+ claim_rows = await self.db_pool.runInteraction(
+ "claim_e2e_one_time_keys",
+ self._claim_e2e_one_time_key_simple,
+ user_id,
+ device_id,
+ algorithm,
+ count,
+ db_autocommit=False,
+ )
+ if claim_rows:
+ device_results = results.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ )
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
@@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results
+ @trace
+ def _claim_e2e_one_time_key_simple(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
+ """Claim OTK for device for DBs that don't support RETURNING.
+
+ Returns:
+ A tuple of key name (algorithm + key ID) and key JSON, if an
+ OTK was found.
+ """
+
+ sql = """
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+ @trace
+ def _claim_e2e_one_time_keys_bulk(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, int]],
+ ) -> List[Tuple[str, str, str, str, str]]:
+ """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+ Args:
+ query_list: Collection of tuples (user_id, device_id, algorithm, count)
+ as passed to claim_e2e_one_time_keys.
+
+ Returns:
+ A list of tuples (user_id, device_id, algorithm, key_id, key_json)
+ for each OTK claimed.
+ """
+ sql = """
+ WITH claims(user_id, device_id, algorithm, claim_count) AS (
+ VALUES ?
+ ), ranked_keys AS (
+ SELECT
+ user_id, device_id, algorithm, key_id, claim_count,
+ ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+ FROM e2e_one_time_keys_json
+ JOIN claims USING (user_id, device_id, algorithm)
+ )
+ DELETE FROM e2e_one_time_keys_json k
+ WHERE (user_id, device_id, algorithm, key_id) IN (
+ SELECT user_id, device_id, algorithm, key_id
+ FROM ranked_keys
+ WHERE r <= claim_count
+ )
+ RETURNING user_id, device_id, algorithm, key_id, key_json;
+ """
+ otk_rows = cast(
+ List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, _, _, _ in otk_rows:
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return otk_rows
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 24e405f429..90b4da9ad5 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -174,6 +174,164 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
+ def test_claim_one_time_key_bulk(self) -> None:
+ """Like test_claim_one_time_key but claims multiple keys in one handler call."""
+ # Apologies to the reader. This test is a little too verbose. It is particularly
+ # tricky to make assertions neatly with all these nested dictionaries in play.
+
+ # Three users with two devices each. Each device uses two algorithms.
+ # Each algorithm is invoked with two keys.
+ alice = f"@alice:{self.hs.hostname}"
+ brian = f"@brian:{self.hs.hostname}"
+ chris = f"@chris:{self.hs.hostname}"
+ one_time_keys = {
+ alice: {
+ "alice_dev_1": {
+ "alg1:k1": {"dummy_id": 1},
+ "alg1:k2": {"dummy_id": 2},
+ "alg2:k3": {"dummy_id": 3},
+ "alg2:k4": {"dummy_id": 4},
+ },
+ "alice_dev_2": {
+ "alg1:k5": {"dummy_id": 5},
+ "alg1:k6": {"dummy_id": 6},
+ "alg2:k7": {"dummy_id": 7},
+ "alg2:k8": {"dummy_id": 8},
+ },
+ },
+ brian: {
+ "brian_dev_1": {
+ "alg1:k9": {"dummy_id": 9},
+ "alg1:k10": {"dummy_id": 10},
+ "alg2:k11": {"dummy_id": 11},
+ "alg2:k12": {"dummy_id": 12},
+ },
+ "brian_dev_2": {
+ "alg1:k13": {"dummy_id": 13},
+ "alg1:k14": {"dummy_id": 14},
+ "alg2:k15": {"dummy_id": 15},
+ "alg2:k16": {"dummy_id": 16},
+ },
+ },
+ chris: {
+ "chris_dev_1": {
+ "alg1:k17": {"dummy_id": 17},
+ "alg1:k18": {"dummy_id": 18},
+ "alg2:k19": {"dummy_id": 19},
+ "alg2:k20": {"dummy_id": 20},
+ },
+ "chris_dev_2": {
+ "alg1:k21": {"dummy_id": 21},
+ "alg1:k22": {"dummy_id": 22},
+ "alg2:k23": {"dummy_id": 23},
+ "alg2:k24": {"dummy_id": 24},
+ },
+ },
+ }
+ for user_id, devices in one_time_keys.items():
+ for device_id, keys_dict in devices.items():
+ counts = self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {"one_time_keys": keys_dict},
+ )
+ )
+ # The upload should report 2 keys per algorithm.
+ expected_counts = {
+ "one_time_key_counts": {
+ # See count_e2e_one_time_keys for why this is hardcoded.
+ "signed_curve25519": 0,
+ "alg1": 2,
+ "alg2": 2,
+ },
+ }
+ self.assertEqual(counts, expected_counts)
+
+ # Claim a variety of keys.
+ # Raw format, easier to make test assertions about.
+ claims_to_make = {
+ (alice, "alice_dev_1", "alg1"): 1,
+ (alice, "alice_dev_1", "alg2"): 2,
+ (alice, "alice_dev_2", "alg2"): 1,
+ (brian, "brian_dev_1", "alg1"): 2,
+ (brian, "brian_dev_2", "alg2"): 9001,
+ (chris, "chris_dev_2", "alg2"): 1,
+ }
+ # Convert to the format the handler wants.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for (user_id, device_id, algorithm), count in claims_to_make.items():
+ query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ query,
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+
+ # No failures, please!
+ self.assertEqual(claim_res["failures"], {})
+
+ # Check that we get exactly the (user, device, algorithm)s we asked for.
+ got_otks = claim_res["one_time_keys"]
+ claimed_user_device_algorithms = {
+ (user_id, device_id, alg_key_id.split(":")[0])
+ for user_id, devices in got_otks.items()
+ for device_id, key_dict in devices.items()
+ for alg_key_id in key_dict
+ }
+ self.assertEqual(claimed_user_device_algorithms, set(claims_to_make))
+
+ # Now check the keys we got are what we expected.
+ def assertExactlyOneOtk(
+ user_id: str, device_id: str, *alg_key_pairs: str
+ ) -> None:
+ key_dict = got_otks[user_id][device_id]
+ found = 0
+ for alg_key in alg_key_pairs:
+ if alg_key in key_dict:
+ expected_key_json = one_time_keys[user_id][device_id][alg_key]
+ self.assertEqual(key_dict[alg_key], expected_key_json)
+ found += 1
+ self.assertEqual(found, 1)
+
+ def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None:
+ key_dict = got_otks[user_id][device_id]
+ for alg_key in alg_key_pairs:
+ expected_key_json = one_time_keys[user_id][device_id][alg_key]
+ self.assertEqual(key_dict[alg_key], expected_key_json)
+
+ # Expect a single arbitrary key to be returned.
+ assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2")
+ assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8")
+ assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24")
+
+ assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4")
+ assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
+ assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16")
+
+ # Now check the unused key counts.
+ for user_id, devices in one_time_keys.items():
+ for device_id in devices:
+ counts_by_alg = self.get_success(
+ self.store.count_e2e_one_time_keys(user_id, device_id)
+ )
+ # Somewhat fiddly to compute the expected count dict.
+ expected_counts_by_alg = {
+ "signed_curve25519": 0,
+ }
+ for alg in ["alg1", "alg2"]:
+ claim_count = claims_to_make.get((user_id, device_id, alg), 0)
+ remaining_count = max(0, 2 - claim_count)
+ if remaining_count > 0:
+ expected_counts_by_alg[alg] = remaining_count
+
+ self.assertEqual(
+ counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
+ )
+
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
|