diff --git a/changelog.d/16010.misc b/changelog.d/16010.misc
new file mode 100644
index 0000000000..1e1a148069
--- /dev/null
+++ b/changelog.d/16010.misc
@@ -0,0 +1 @@
+Update dehydrated devices implementation.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b7bf70a72d..5ae427d52c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
+ self.db_pool = hs.get_datastores().main.db_pool
self.device_list_updater = DeviceListUpdater(hs, self)
@@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
+ keys_for_device: Optional[JsonDict] = None,
) -> str:
- """Store a dehydrated device for a user. If the user had a previous
- dehydrated device, it is removed.
+ """Store a dehydrated device for a user, optionally storing the keys associated with
+ it as well. If the user had a previous dehydrated device, it is removed.
Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
+ keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
@@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
device_id,
initial_device_display_name,
)
+
+ time_now = self.clock.time_msec()
+
old_device_id = await self.store.store_dehydrated_device(
- user_id, device_id, device_data
+ user_id, device_id, device_data, time_now, keys_for_device
)
+
if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])
+
return device_id
async def rehydrate_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 15e94a03cb..17ff8821d9 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -367,19 +367,6 @@ class DeviceMessageHandler:
errcode=Codes.INVALID_PARAM,
)
- # if we have a since token, delete any to-device messages before that token
- # (since we now know that the device has received them)
- deleted = await self.store.delete_messages_for_device(
- user_id, device_id, since_stream_id
- )
- logger.debug(
- "Deleted %d to-device messages up to %d for user_id %s device_id %s",
- deleted,
- since_stream_id,
- user_id,
- device_id,
- )
-
to_token = self.event_sources.get_current_token().to_device_key
messages, stream_id = await self.store.get_messages_for_device(
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 51f17f80da..925f037743 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -29,7 +29,6 @@ from synapse.http.servlet import (
parse_integer,
)
from synapse.http.site import SynapseRequest
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
@@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler
- if hs.config.worker.worker_app is None:
- # if main process
- self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
- else:
- # then a worker
- self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
-
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
"Device key(s) not found, these must be provided.",
)
- # TODO: Those two operations, creating a device and storing the
- # device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
- )
-
- # TODO: Do we need to do something with the result here?
- await self.key_uploader(
- user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+ device_info,
)
return 200, {"device_id": device_id}
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
def _store_dehydrated_device_txn(
- self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ device_data: str,
+ time: int,
+ keys: Optional[JsonDict] = None,
) -> Optional[str]:
+ # TODO: make keys non-optional once support for msc2697 is dropped
+ if keys:
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ # Type ignore - this function is defined on EndToEndKeyStore which we do
+ # have access to due to hs.get_datastore() "magic"
+ self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
+ txn, user_id, device_id, time, device_keys
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append(
+ (
+ algorithm,
+ key_id,
+ encode_canonical_json(key_obj).decode("ascii"),
+ )
+ )
+ self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+ fallback_keys = keys.get("fallback_keys", None)
+ if fallback_keys:
+ self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)
+
return old_device_id
async def store_dehydrated_device(
- self, user_id: str, device_id: str, device_data: JsonDict
+ self,
+ user_id: str,
+ device_id: str,
+ device_data: JsonDict,
+ time_now: int,
+ keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
+ time_now: current time at the request in milliseconds
+ keys: keys for the dehydrated device
+
Returns:
device id of the user's previous dehydrated device, if any
"""
+
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
+ time_now,
+ keys,
)
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("new_keys", str(new_keys))
- # We are protected from race between lookup and insertion due to
- # a unique constraint. If there is a race of two calls to
- # `add_e2e_one_time_keys` then they'll conflict and we will only
- # insert one set.
- self.db_pool.simple_insert_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- keys=(
- "user_id",
- "device_id",
- "algorithm",
- "key_id",
- "ts_added_ms",
- "key_json",
- ),
- values=[
- (user_id, device_id, algorithm, key_id, time_now, json_bytes)
- for algorithm, key_id, json_bytes in new_keys
- ],
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
await self.db_pool.runInteraction(
- "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+ "add_e2e_one_time_keys_insert",
+ self._add_e2e_one_time_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ new_keys,
+ )
+
+ def _add_e2e_one_time_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
+ """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+ Args:
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+ that the key JSON must be in canonical JSON form
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("new_keys", str(new_keys))
+ # We are protected from race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
+ values=[
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
@cached(max_entries=10000)
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id: str,
fallback_keys: JsonDict,
) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
+
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
"""
- def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("time_now", time_now)
- set_tag("device_keys", str(device_keys))
+ return await self.db_pool.runInteraction(
+ "set_e2e_device_keys",
+ self._set_e2e_device_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ device_keys,
+ )
- old_key_json = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
+ def _set_e2e_device_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ device_keys: JsonDict,
+ ) -> bool:
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", str(device_keys))
+
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
- if old_key_json == new_key_json:
- log_kv({"Message": "Device key already stored."})
- return False
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
- self.db_pool.simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
- log_kv({"message": "Device keys stored."})
- return True
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
- return await self.db_pool.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
)
+ log_kv({"message": "Device keys stored."})
+ return True
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 647ee09279..e1e58fa6e6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(res["events"]), 1)
self.assertEqual(res["events"][0]["content"]["body"], "foo")
- # Fetch the message of the dehydrated device again, which should return nothing
- # and delete the old messages
+ # Fetch the message of the dehydrated device again, which should return
+ # the same message as it has not been deleted
res = self.get_success(
self.message_handler.get_events_for_dehydrated_device(
requester=requester,
device_id=stored_dehydrated_device_id,
- since_token=res["next_batch"],
+ since_token=None,
limit=10,
)
)
self.assertTrue(len(res["next_batch"]) > 1)
- self.assertEqual(len(res["events"]), 0)
+ self.assertEqual(len(res["events"]), 1)
+ self.assertEqual(res["events"][0]["content"]["body"], "foo")
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 3cf29c10ea..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
"<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
},
},
+ "fallback_keys": {
+ "alg1:device1": "f4llb4ckk3y",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": "true",
+ "key": "f4llb4ckk3y",
+ "signatures": {
+ "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+ },
+ },
+ },
+ "one_time_keys": {"alg1:k1": "0net1m3k3y"},
}
channel = self.make_request(
"PUT",
@@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
}
self.assertEqual(device_data, expected_device_data)
+ # test that the keys are correctly uploaded
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ user: ["device1"],
+ },
+ },
+ token,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["device_keys"][user][device_id]["keys"],
+ content["device_keys"]["keys"],
+ )
+ # first claim should return the onetime key we uploaded
+ res = self.get_success(
+ self.hs.get_e2e_keys_handler().claim_one_time_keys(
+ {user: {device_id: {"alg1": 1}}},
+ UserID.from_string(user),
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res,
+ {
+ "failures": {},
+ "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+ },
+ )
+ # second claim should return fallback key
+ res2 = self.get_success(
+ self.hs.get_e2e_keys_handler().claim_one_time_keys(
+ {user: {device_id: {"alg1": 1}}},
+ UserID.from_string(user),
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res2,
+ {
+ "failures": {},
+ "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+ },
+ )
+
# create another device for the user
(
new_device_id,
@@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
expected_content = {"body": "test_message"}
self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+ # fetch messages again and make sure that the message was not deleted
+ channel = self.make_request(
+ "POST",
+ f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+ content={},
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
next_batch_token = channel.json_body.get("next_batch")
- # fetch messages again and make sure that the message was deleted and we are returned an
- # empty array
+ # make sure fetching messages with next batch token works - there are no unfetched
+ # messages so we should receive an empty array
content = {"next_batch": next_batch_token}
channel = self.make_request(
"POST",
|