summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/appservice.py14
-rw-r--r--synapse/handlers/e2e_keys.py31
2 files changed, 29 insertions, 16 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 4ca2bc0420..6429545c98 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -841,8 +841,10 @@ class ApplicationServicesHandler:
         return True
 
     async def claim_e2e_one_time_keys(
-        self, query: Iterable[Tuple[str, str, str]]
-    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+        self, query: Iterable[Tuple[str, str, str, int]]
+    ) -> Tuple[
+        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+    ]:
         """Claim one time keys from application services.
 
         Users which are exclusively owned by an application service are sent a
@@ -863,18 +865,18 @@ class ApplicationServicesHandler:
         services = self.store.get_app_services()
 
         # Partition the users by appservice.
-        query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
+        query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
         missing = []
-        for user_id, device, algorithm in query:
+        for user_id, device, algorithm, count in query:
             if not self.store.get_if_app_services_interested_in_user(user_id):
-                missing.append((user_id, device, algorithm))
+                missing.append((user_id, device, algorithm, count))
                 continue
 
             # Find the associated appservice.
             for service in services:
                 if service.is_exclusive_user(user_id):
                     query_by_appservice.setdefault(service.id, []).append(
-                        (user_id, device, algorithm)
+                        (user_id, device, algorithm, count)
                     )
                     continue
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d1ab95126c..24741b667b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -564,7 +564,7 @@ class E2eKeysHandler:
 
     async def claim_local_one_time_keys(
         self,
-        local_query: List[Tuple[str, str, str]],
+        local_query: List[Tuple[str, str, str, int]],
         always_include_fallback_keys: bool,
     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
         """Claim one time keys for local users.
@@ -581,6 +581,12 @@ class E2eKeysHandler:
             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
         """
 
+        # Cap the number of OTKs that can be claimed at once to avoid abuse.
+        local_query = [
+            (user_id, device_id, algorithm, min(count, 5))
+            for user_id, device_id, algorithm, count in local_query
+        ]
+
         otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
 
         # If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ class E2eKeysHandler:
             # from the appservice for that user ID / device ID. If it is found,
             # check if any of the keys match the requested algorithm & are a
             # fallback key.
-            for user_id, device_id, algorithm in local_query:
+            for user_id, device_id, algorithm, _count in local_query:
                 # Check if the appservice responded for this query.
                 as_result = appservice_results.get(user_id, {}).get(device_id, {})
                 found_otk = False
@@ -630,13 +636,17 @@ class E2eKeysHandler:
                         .get(device_id, {})
                         .keys()
                     )
+                    # Note that it doesn't make sense to request more than 1 fallback key
+                    # per (user_id, device_id, algorithm).
                     fallback_query.append((user_id, device_id, algorithm, mark_as_used))
 
         else:
             # All fallback keys get marked as used.
             fallback_query = [
+                # Note that it doesn't make sense to request more than 1 fallback key
+                # per (user_id, device_id, algorithm).
                 (user_id, device_id, algorithm, True)
-                for user_id, device_id, algorithm in not_found
+                for user_id, device_id, algorithm, count in not_found
             ]
 
         # For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ class E2eKeysHandler:
     @trace
     async def claim_one_time_keys(
         self,
-        query: Dict[str, Dict[str, Dict[str, str]]],
+        query: Dict[str, Dict[str, Dict[str, int]]],
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
-        local_query: List[Tuple[str, str, str]] = []
-        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
+        local_query: List[Tuple[str, str, str, int]] = []
+        remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 
-        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in one_time_keys.items():
-                    local_query.append((user_id, device_id, algorithm))
+                for device_id, algorithms in one_time_keys.items():
+                    for algorithm, count in algorithms.items():
+                        local_query.append((user_id, device_id, algorithm, count))
             else:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@@ -692,7 +703,7 @@ class E2eKeysHandler:
             device_keys = remote_queries[destination]
             try:
                 remote_result = await self.federation.claim_client_keys(
-                    destination, {"one_time_keys": device_keys}, timeout=timeout
+                    destination, device_keys, timeout=timeout
                 )
                 for user_id, keys in remote_result["one_time_keys"].items():
                     if user_id in device_keys: