summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/devices.py79
1 files changed, 36 insertions, 43 deletions
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 6ac165068e..0b12bc58c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -92,8 +92,12 @@ class DeviceWorkerStore(SQLBaseStore):
     @trace
     @defer.inlineCallbacks
     def get_devices_by_remote(self, destination, from_stream_id, limit):
-        """Get stream of updates to send to remote servers
+        """Get a stream of device updates to send to the given remote server.
 
+        Args:
+            destination (str): The host the device updates are intended for
+            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            limit (int): Maximum number of device updates to return
         Returns:
             Deferred[tuple[int, list[tuple[string,dict]]]]:
                 current stream id (ie, the stream id of the last update included in the
@@ -131,7 +135,8 @@ class DeviceWorkerStore(SQLBaseStore):
         if not updates:
             return now_stream_id, []
 
-        # get the cross-signing keys of the users the list
+        # get the cross-signing keys of the users in the list, so that we can
+        # determine which of the device changes were cross-signing keys
         users = set(r[0] for r in updates)
         master_key_by_user = {}
         self_signing_key_by_user = {}
@@ -141,9 +146,12 @@ class DeviceWorkerStore(SQLBaseStore):
                 key_id, verify_key = get_verify_key_from_cross_signing_key(
                     cross_signing_key
                 )
+                # verify_key is a VerifyKey from signedjson, which uses
+                # .version to denote the portion of the key ID after the
+                # algorithm and colon, which is the device ID
                 master_key_by_user[user] = {
                     "key_info": cross_signing_key,
-                    "pubkey": verify_key.version,
+                    "device_id": verify_key.version,
                 }
 
             cross_signing_key = yield self.get_e2e_cross_signing_key(
@@ -155,7 +163,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 )
                 self_signing_key_by_user[user] = {
                     "key_info": cross_signing_key,
-                    "pubkey": verify_key.version,
+                    "device_id": verify_key.version,
                 }
 
         # if we have exceeded the limit, we need to exclude any results with the
@@ -182,69 +190,54 @@ class DeviceWorkerStore(SQLBaseStore):
         # context which created the Edu.
 
         query_map = {}
-        for update in updates:
-            if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+        cross_signing_keys_by_user = {}
+        for user_id, device_id, update_stream_id, update_context in updates:
+            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
                 # Stop processing updates
                 break
 
-            # skip over cross-signing keys
             if (
-                update[0] in master_key_by_user
-                and update[1] == master_key_by_user[update[0]]["pubkey"]
-            ) or (
-                update[0] in master_key_by_user
-                and update[1] == self_signing_key_by_user[update[0]]["pubkey"]
+                user_id in master_key_by_user
+                and device_id == master_key_by_user[user_id]["device_id"]
             ):
-                continue
-
-            key = (update[0], update[1])
-
-            update_context = update[3]
-            update_stream_id = update[2]
-
-            previous_update_stream_id, _ = query_map.get(key, (0, None))
-
-            if update_stream_id > previous_update_stream_id:
-                query_map[key] = (update_stream_id, update_context)
-
-        # If we didn't find any updates with a stream_id lower than the cutoff, it
-        # means that there are more than limit updates all of which have the same
-        # steam_id.
-
-        # figure out which cross-signing keys were changed by intersecting the
-        # update list with the master/self-signing key by user maps
-        cross_signing_keys_by_user = {}
-        for user_id, device_id, stream, _opentracing_context in updates:
-            if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None):
                 result = cross_signing_keys_by_user.setdefault(user_id, {})
                 result["master_key"] = master_key_by_user[user_id]["key_info"]
-            elif device_id == self_signing_key_by_user.get(user_id, {}).get(
-                "pubkey", None
+            elif (
+                user_id in master_key_by_user
+                and device_id == self_signing_key_by_user[user_id]["device_id"]
             ):
                 result = cross_signing_keys_by_user.setdefault(user_id, {})
                 result["self_signing_key"] = self_signing_key_by_user[user_id][
                     "key_info"
                 ]
+            else:
+                key = (user_id, device_id)
 
-        cross_signing_results = []
+                previous_update_stream_id, _ = query_map.get(key, (0, None))
 
-        # add the updated cross-signing keys to the results list
-        for user_id, result in iteritems(cross_signing_keys_by_user):
-            result["user_id"] = user_id
-            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
-            cross_signing_results.append(("org.matrix.signing_key_update", result))
+                if update_stream_id > previous_update_stream_id:
+                    query_map[key] = (update_stream_id, update_context)
+
+        # If we didn't find any updates with a stream_id lower than the cutoff, it
+        # means that there are more than limit updates all of which have the same
+        # steam_id.
 
         # That should only happen if a client is spamming the server with new
         # devices, in which case E2E isn't going to work well anyway. We'll just
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
-        if not query_map and not cross_signing_results:
+        if not query_map and not cross_signing_keys_by_user:
             return stream_id_cutoff, []
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
-        results.extend(cross_signing_results)
+
+        # add the updated cross-signing keys to the results list
+        for user_id, result in iteritems(cross_signing_keys_by_user):
+            result["user_id"] = user_id
+            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            results.append(("org.matrix.signing_key_update", result))
 
         return now_stream_id, results