summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py58
-rw-r--r--synapse/handlers/e2e_keys.py22
2 files changed, 62 insertions, 18 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 2cbb695bb1..230d170258 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any, Dict, Optional
 
 from six import iteritems, itervalues
 
@@ -30,7 +31,11 @@ from synapse.api.errors import (
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.types import (
+    RoomStreamToken,
+    get_domain_from_id,
+    get_verify_key_from_cross_signing_key,
+)
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -795,6 +800,13 @@ class DeviceListUpdater(object):
         stream_id = result["stream_id"]
         devices = result["devices"]
 
+        # Get the master key and the self-signing key for this user if provided in the
+        # response (None if not in the response).
+        # The response will not contain the user signing key, as this key is only used by
+        # its owner, thus it doesn't make sense to send it over federation.
+        master_key = result.get("master_key")
+        self_signing_key = result.get("self_signing_key")
+
         # If the remote server has more than ~1000 devices for this user
         # we assume that something is going horribly wrong (e.g. a bot
         # that logs in and creates a new device every time it tries to
@@ -824,6 +836,13 @@ class DeviceListUpdater(object):
 
         yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
         device_ids = [device["device_id"] for device in devices]
+
+        # Handle cross-signing keys.
+        cross_signing_device_ids = yield self.process_cross_signing_key_update(
+            user_id, master_key, self_signing_key,
+        )
+        device_ids = device_ids + cross_signing_device_ids
+
         yield self.device_handler.notify_device_update(user_id, device_ids)
 
         # We clobber the seen updates since we've re-synced from a given
@@ -831,3 +850,40 @@ class DeviceListUpdater(object):
         self._seen_updates[user_id] = {stream_id}
 
         defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def process_cross_signing_key_update(
+        self,
+        user_id: str,
+        master_key: Optional[Dict[str, Any]],
+        self_signing_key: Optional[Dict[str, Any]],
+    ) -> list:
+        """Process the given new master and self-signing key for the given remote user.
+
+        Args:
+            user_id: The ID of the user these keys are for.
+            master_key: The dict of the cross-signing master key as returned by the
+                remote server.
+            self_signing_key: The dict of the cross-signing self-signing key as returned
+                by the remote server.
+
+        Return:
+            The device IDs for the given keys.
+        """
+        device_ids = []
+
+        if master_key:
+            yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+            _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+            # verify_key is a VerifyKey from signedjson, which uses
+            # .version to denote the portion of the key ID after the
+            # algorithm and colon, which is the device ID
+            device_ids.append(verify_key.version)
+        if self_signing_key:
+            yield self.store.set_e2e_cross_signing_key(
+                user_id, "self_signing", self_signing_key
+            )
+            _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
+            device_ids.append(verify_key.version)
+
+        return device_ids
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8f1bc0323c..774a252619 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1291,6 +1291,7 @@ class SigningKeyEduUpdater(object):
         """
 
         device_handler = self.e2e_keys_handler.device_handler
+        device_list_updater = device_handler.device_list_updater
 
         with (yield self._remote_edu_linearizer.queue(user_id)):
             pending_updates = self._pending_updates.pop(user_id, [])
@@ -1303,22 +1304,9 @@ class SigningKeyEduUpdater(object):
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                if master_key:
-                    yield self.store.set_e2e_cross_signing_key(
-                        user_id, "master", master_key
-                    )
-                    _, verify_key = get_verify_key_from_cross_signing_key(master_key)
-                    # verify_key is a VerifyKey from signedjson, which uses
-                    # .version to denote the portion of the key ID after the
-                    # algorithm and colon, which is the device ID
-                    device_ids.append(verify_key.version)
-                if self_signing_key:
-                    yield self.store.set_e2e_cross_signing_key(
-                        user_id, "self_signing", self_signing_key
-                    )
-                    _, verify_key = get_verify_key_from_cross_signing_key(
-                        self_signing_key
-                    )
-                    device_ids.append(verify_key.version)
+                new_device_ids = yield device_list_updater.process_cross_signing_key_update(
+                    user_id, master_key, self_signing_key,
+                )
+                device_ids = device_ids + new_device_ids
 
             yield device_handler.notify_device_update(user_id, device_ids)