diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 86ddb1bb28..024098e9cb 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -442,8 +442,10 @@ class ApplicationServiceApi(SimpleHttpClient):
return False
async def claim_client_keys(
- self, service: "ApplicationService", query: List[Tuple[str, str, str]]
- ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+ self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
+ ) -> Tuple[
+ Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+ ]:
"""Claim one time keys from an application service.
Note that any error (including a timeout) is treated as the application
@@ -469,8 +471,10 @@ class ApplicationServiceApi(SimpleHttpClient):
# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
- for user_id, device, algorithm in query:
- body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
+ for user_id, device, algorithm, count in query:
+ body.setdefault(user_id, {}).setdefault(device, []).extend(
+ [algorithm] * count
+ )
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
@@ -493,11 +497,20 @@ class ApplicationServiceApi(SimpleHttpClient):
# or if some are still missing.
#
# TODO This places a lot of faith in the response shape being correct.
- missing = [
- (user_id, device, algorithm)
- for user_id, device, algorithm in query
- if algorithm not in response.get(user_id, {}).get(device, [])
- ]
+ missing = []
+ for user_id, device, algorithm, count in query:
+ # Count the number of keys in the response for this algorithm by
+ # checking which key IDs start with the algorithm. This uses that
+ # True == 1 in Python to generate a count.
+ response_count = sum(
+ key_id.startswith(f"{algorithm}:")
+ for key_id in response.get(user_id, {}).get(device, {})
+ )
+ count -= response_count
+ # If the appservice responds with fewer keys than requested, then
+ # consider the request unfulfilled.
+ if count > 0:
+ missing.append((user_id, device, algorithm, count))
return response, missing
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ba34573d46..0b2d1a78f7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -235,7 +235,10 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
- self, destination: str, content: JsonDict, timeout: Optional[int]
+ self,
+ destination: str,
+ query: Dict[str, Dict[str, Dict[str, int]]],
+ timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
@@ -247,6 +250,50 @@ class FederationClient(FederationBase):
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
+
+ # Convert the query with counts into a stable and unstable query and check
+ # if attempting to claim more than 1 OTK.
+ content: Dict[str, Dict[str, str]] = {}
+ unstable_content: Dict[str, Dict[str, List[str]]] = {}
+ use_unstable = False
+ for user_id, one_time_keys in query.items():
+ for device_id, algorithms in one_time_keys.items():
+ if any(count > 1 for count in algorithms.values()):
+ use_unstable = True
+ if algorithms:
+ # For the stable query, choose only the first algorithm.
+ content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
+ # For the unstable query, repeat each algorithm by count, then
+ # splat those into chain to get a flattened list of all algorithms.
+ #
+ # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
+ unstable_content.setdefault(user_id, {})[device_id] = list(
+ itertools.chain(
+ *(
+ itertools.repeat(algorithm, count)
+ for algorithm, count in algorithms.items()
+ )
+ )
+ )
+
+ if use_unstable:
+ try:
+ return await self.transport_layer.claim_client_keys_unstable(
+ destination, unstable_content, timeout
+ )
+ except HttpResponseException as e:
+ # If an error is received that is due to an unrecognised endpoint,
+ # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+ # and raise.
+ if not is_unknown_endpoint(e):
+ raise
+
+ logger.debug(
+ "Couldn't claim client keys with the unstable API, falling back to the v1 API"
+ )
+ else:
+ logger.debug("Skipping unstable claim client keys API")
+
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c618f3d7a6..ca43c7bfc0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
@trace
async def on_claim_client_keys(
- self, origin: str, content: JsonDict, always_include_fallback_keys: bool
+ self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
- query = []
- for user_id, device_keys in content.get("one_time_keys", {}).items():
- for device_id, algorithm in device_keys.items():
- query.append((user_id, device_id, algorithm))
-
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bedbd23ded..bc70b94f68 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -650,10 +650,10 @@ class TransportLayerClient:
Response:
{
- "device_keys": {
+ "one_time_keys": {
"<user_id>": {
"<device_id>": {
- "<algorithm>:<key_id>": "<key_base64>"
+ "<algorithm>:<key_id>": <OTK JSON>
}
}
}
@@ -669,7 +669,50 @@ class TransportLayerClient:
path = _create_v1_path("/user/keys/claim")
return await self.client.post_json(
- destination=destination, path=path, data=query_content, timeout=timeout
+ destination=destination,
+ path=path,
+ data={"one_time_keys": query_content},
+ timeout=timeout,
+ )
+
+ async def claim_client_keys_unstable(
+ self, destination: str, query_content: JsonDict, timeout: Optional[int]
+ ) -> JsonDict:
+ """Claim one-time keys for a list of devices hosted on a remote server.
+
+ Request:
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {"<algorithm>": <count>}
+ }
+ }
+ }
+
+ Response:
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": <OTK JSON>
+ }
+ }
+ }
+ }
+
+ Args:
+ destination: The server to query.
+ query_content: The user ids to query.
+ Returns:
+ A dict containing the one-time keys.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
+
+ return await self.client.post_json(
+ destination=destination,
+ path=path,
+ data={"one_time_keys": query_content},
+ timeout=timeout,
)
async def get_missing_events(
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e2340d70d5..36b0362504 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ key_query: List[Tuple[str, str, str, int]] = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithm in device_keys.items():
+ key_query.append((user_id, device_id, algorithm, 1))
+
response = await self.handler.on_claim_client_keys(
- origin, content, always_include_fallback_keys=False
+ key_query, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
- Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
- always includes fallback keys in the response.
+ Identical to the stable endpoint (FederationClientKeysClaimServlet) except
+ it allows for querying for multiple OTKs at once and always includes fallback
+ keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
+ # Generate a count for each algorithm.
+ key_query: List[Tuple[str, str, str, int]] = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithms in device_keys.items():
+ counts = Counter(algorithms)
+ for algorithm, count in counts.items():
+ key_query.append((user_id, device_id, algorithm, count))
+
response = await self.handler.on_claim_client_keys(
- origin, content, always_include_fallback_keys=True
+ key_query, always_include_fallback_keys=True
)
return 200, response
@@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
+ FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 4ca2bc0420..6429545c98 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -841,8 +841,10 @@ class ApplicationServicesHandler:
return True
async def claim_e2e_one_time_keys(
- self, query: Iterable[Tuple[str, str, str]]
- ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+ self, query: Iterable[Tuple[str, str, str, int]]
+ ) -> Tuple[
+ Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+ ]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
@@ -863,18 +865,18 @@ class ApplicationServicesHandler:
services = self.store.get_app_services()
# Partition the users by appservice.
- query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
+ query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
missing = []
- for user_id, device, algorithm in query:
+ for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
- missing.append((user_id, device, algorithm))
+ missing.append((user_id, device, algorithm, count))
continue
# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
- (user_id, device, algorithm)
+ (user_id, device, algorithm, count)
)
continue
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d1ab95126c..24741b667b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -564,7 +564,7 @@ class E2eKeysHandler:
async def claim_local_one_time_keys(
self,
- local_query: List[Tuple[str, str, str]],
+ local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
@@ -581,6 +581,12 @@ class E2eKeysHandler:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
+ # Cap the number of OTKs that can be claimed at once to avoid abuse.
+ local_query = [
+ (user_id, device_id, algorithm, min(count, 5))
+ for user_id, device_id, algorithm, count in local_query
+ ]
+
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
# If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ class E2eKeysHandler:
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
- for user_id, device_id, algorithm in local_query:
+ for user_id, device_id, algorithm, _count in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
@@ -630,13 +636,17 @@ class E2eKeysHandler:
.get(device_id, {})
.keys()
)
+ # Note that it doesn't make sense to request more than 1 fallback key
+ # per (user_id, device_id, algorithm).
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else:
# All fallback keys get marked as used.
fallback_query = [
+ # Note that it doesn't make sense to request more than 1 fallback key
+ # per (user_id, device_id, algorithm).
(user_id, device_id, algorithm, True)
- for user_id, device_id, algorithm in not_found
+ for user_id, device_id, algorithm, count in not_found
]
# For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
self,
- query: Dict[str, Dict[str, Dict[str, str]]],
+ query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
- local_query: List[Tuple[str, str, str]] = []
- remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
+ local_query: List[Tuple[str, str, str, int]] = []
+ remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
- for user_id, one_time_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in one_time_keys.items():
- local_query.append((user_id, device_id, algorithm))
+ for device_id, algorithms in one_time_keys.items():
+ for algorithm, count in algorithms.items():
+ local_query.append((user_id, device_id, algorithm, count))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@@ -692,7 +703,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
- destination, {"one_time_keys": device_keys}, timeout=timeout
+ destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 2a25094109..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -16,7 +16,8 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
@@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithm in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
result = await self.e2e_keys_handler.claim_one_time_keys(
- body, timeout, always_include_fallback_keys=False
+ query, timeout, always_include_fallback_keys=False
)
return 200, result
class UnstableOneTimeKeyServlet(RestServlet):
"""
- Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
- fallback keys in the response.
+ Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+ querying for multiple OTKs at once and always includes fallback keys in the
+ response.
+
+ POST /keys/claim HTTP/1.1
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": ["<algorithm>", ...]
+ } } }
+
+ HTTP/1.1 200 OK
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
"""
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ # Generate a count for each algorithm.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithms in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
result = await self.e2e_keys_handler.claim_one_time_keys(
- body, timeout, always_include_fallback_keys=True
+ query, timeout, always_include_fallback_keys=True
)
return 200, result
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1a4ae55304..4bc391f213 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+ self, query_list: Iterable[Tuple[str, str, str, int]]
+ ) -> Tuple[
+ Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+ ]:
"""Take a list of one time keys out of the database.
Args:
@@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@trace
def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
- ) -> Optional[Tuple[str, str]]:
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
Returns:
@@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT 1
+ LIMIT ?
"""
- txn.execute(sql, (user_id, device_id, algorithm))
- otk_row = txn.fetchone()
- if otk_row is None:
- return None
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
- key_id, key_json = otk_row
-
- self.db_pool.simple_delete_one_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
- "key_id": key_id,
},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return f"{algorithm}:{key_id}", key_json
+ return [
+ (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+ ]
@trace
def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
- ) -> Optional[Tuple[str, str]]:
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
Returns:
@@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT 1
+ LIMIT ?
)
RETURNING key_id, key_json
"""
txn.execute(
- sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
+ sql,
+ (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
)
- otk_row = txn.fetchone()
- if otk_row is None:
- return None
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- key_id, key_json = otk_row
- return f"{algorithm}:{key_id}", key_json
+ return [
+ (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+ ]
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
- missing: List[Tuple[str, str, str]] = []
- for user_id, device_id, algorithm in query_list:
+ missing: List[Tuple[str, str, str, int]] = []
+ for user_id, device_id, algorithm, count in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
@@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
- claim_row = await self.db_pool.runInteraction(
+ claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
device_id,
algorithm,
+ count,
db_autocommit=db_autocommit,
)
- if claim_row:
+ if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
- else:
- missing.append((user_id, device_id, algorithm))
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
|