diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..71f62036c0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = set(r[0] for r in updates)
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
-
- update_context = update[3]
- update_stream_id = update[2]
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
- previous_update_stream_id, _ = query_map.get(key, (0, None))
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
- if update_stream_id > previous_update_stream_id:
- query_map[key] = (update_stream_id, update_context)
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
+ if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
+
return now_stream_id, results
- def _get_devices_by_remote_txn(
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
return results
|