summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py78
1 files changed, 45 insertions, 33 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 1ece54ccfc..4f40e9ffd6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -53,6 +53,9 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
+
+
 class E2eKeysHandler:
     def __init__(self, hs: "HomeServer"):
         self.config = hs.config
@@ -62,6 +65,7 @@ class E2eKeysHandler:
         self._appservice_handler = hs.get_application_service_handler()
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         federation_registry = hs.get_federation_registry()
 
@@ -855,45 +859,53 @@ class E2eKeysHandler:
     async def _upload_one_time_keys_for_user(
         self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
     ) -> None:
-        logger.info(
-            "Adding one_time_keys %r for device %r for user %r at %d",
-            one_time_keys.keys(),
-            device_id,
-            user_id,
-            time_now,
-        )
+        # We take out a lock so that we don't have to worry about a client
+        # sending duplicate requests.
+        lock_key = f"{user_id}_{device_id}"
+        async with self._worker_lock_handler.acquire_lock(
+            ONE_TIME_KEY_UPLOAD, lock_key
+        ):
+            logger.info(
+                "Adding one_time_keys %r for device %r for user %r at %d",
+                one_time_keys.keys(),
+                device_id,
+                user_id,
+                time_now,
+            )
 
-        # make a list of (alg, id, key) tuples
-        key_list = []
-        for key_id, key_obj in one_time_keys.items():
-            algorithm, key_id = key_id.split(":")
-            key_list.append((algorithm, key_id, key_obj))
+            # make a list of (alg, id, key) tuples
+            key_list = []
+            for key_id, key_obj in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((algorithm, key_id, key_obj))
 
-        # First we check if we have already persisted any of the keys.
-        existing_key_map = await self.store.get_e2e_one_time_keys(
-            user_id, device_id, [k_id for _, k_id, _ in key_list]
-        )
+            # First we check if we have already persisted any of the keys.
+            existing_key_map = await self.store.get_e2e_one_time_keys(
+                user_id, device_id, [k_id for _, k_id, _ in key_list]
+            )
 
-        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
-        for algorithm, key_id, key in key_list:
-            ex_json = existing_key_map.get((algorithm, key_id), None)
-            if ex_json:
-                if not _one_time_keys_match(ex_json, key):
-                    raise SynapseError(
-                        400,
-                        (
-                            "One time key %s:%s already exists. "
-                            "Old key: %s; new key: %r"
+            new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
+            for algorithm, key_id, key in key_list:
+                ex_json = existing_key_map.get((algorithm, key_id), None)
+                if ex_json:
+                    if not _one_time_keys_match(ex_json, key):
+                        raise SynapseError(
+                            400,
+                            (
+                                "One time key %s:%s already exists. "
+                                "Old key: %s; new key: %r"
+                            )
+                            % (algorithm, key_id, ex_json, key),
                         )
-                        % (algorithm, key_id, ex_json, key),
+                else:
+                    new_keys.append(
+                        (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
                     )
-            else:
-                new_keys.append(
-                    (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
-                )
 
-        log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
-        await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+            log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
+            await self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, new_keys
+            )
 
     async def upload_signing_keys_for_user(
         self, user_id: str, keys: JsonDict