diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 4ab75a351e..0f320b3764 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1072,7 +1072,7 @@ class SignatureListItem:
class SigningKeyEduUpdater(object):
- "Handles incoming signing key updates from federation and updates the DB"
+ """Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs, e2e_keys_handler):
self.store = hs.get_datastore()
@@ -1111,7 +1111,6 @@ class SigningKeyEduUpdater(object):
self_signing_key = edu_content.pop("self_signing_key", None)
if get_domain_from_id(user_id) != origin:
- # TODO: Raise?
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
@@ -1122,7 +1121,7 @@ class SigningKeyEduUpdater(object):
return
self._pending_updates.setdefault(user_id, []).append(
- (master_key, self_signing_key, edu_content)
+ (master_key, self_signing_key)
)
yield self._handle_signing_key_updates(user_id)
@@ -1147,22 +1146,21 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates)
- for master_key, self_signing_key, edu_content in pending_updates:
+ for master_key, self_signing_key in pending_updates:
if master_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "master", master_key
)
- device_id = get_verify_key_from_cross_signing_key(master_key)[
- 1
- ].version
- device_ids.append(device_id)
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
- device_id = get_verify_key_from_cross_signing_key(self_signing_key)[
- 1
- ].version
- device_ids.append(device_id)
+ _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
+ device_ids.append(verify_key.version)
yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 6ac165068e..0b12bc58c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -92,8 +92,12 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
@@ -131,7 +135,8 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates:
return now_stream_id, []
- # get the cross-signing keys of the users the list
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
users = set(r[0] for r in updates)
master_key_by_user = {}
self_signing_key_by_user = {}
@@ -141,9 +146,12 @@ class DeviceWorkerStore(SQLBaseStore):
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
master_key_by_user[user] = {
"key_info": cross_signing_key,
- "pubkey": verify_key.version,
+ "device_id": verify_key.version,
}
cross_signing_key = yield self.get_e2e_cross_signing_key(
@@ -155,7 +163,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
self_signing_key_by_user[user] = {
"key_info": cross_signing_key,
- "pubkey": verify_key.version,
+ "device_id": verify_key.version,
}
# if we have exceeded the limit, we need to exclude any results with the
@@ -182,69 +190,54 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu.
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- # skip over cross-signing keys
if (
- update[0] in master_key_by_user
- and update[1] == master_key_by_user[update[0]]["pubkey"]
- ) or (
- update[0] in master_key_by_user
- and update[1] == self_signing_key_by_user[update[0]]["pubkey"]
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
):
- continue
-
- key = (update[0], update[1])
-
- update_context = update[3]
- update_stream_id = update[2]
-
- previous_update_stream_id, _ = query_map.get(key, (0, None))
-
- if update_stream_id > previous_update_stream_id:
- query_map[key] = (update_stream_id, update_context)
-
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # figure out which cross-signing keys were changed by intersecting the
- # update list with the master/self-signing key by user maps
- cross_signing_keys_by_user = {}
- for user_id, device_id, stream, _opentracing_context in updates:
- if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
- elif device_id == self_signing_key_by_user.get(user_id, {}).get(
- "pubkey", None
+ elif (
+ user_id in master_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]
+ else:
+ key = (user_id, device_id)
- cross_signing_results = []
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
- # add the updated cross-signing keys to the results list
- for user_id, result in iteritems(cross_signing_keys_by_user):
- result["user_id"] = user_id
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
- cross_signing_results.append(("org.matrix.signing_key_update", result))
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
+
+ # If we didn't find any updates with a stream_id lower than the cutoff, it
+ # means that there are more than limit updates all of which have the same
+ # steam_id.
# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map and not cross_signing_results:
+ if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
- results.extend(cross_signing_results)
+
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
return now_stream_id, results
|