summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5727.feature1
-rw-r--r--synapse/federation/sender/per_destination_queue.py8
-rw-r--r--synapse/handlers/device.py13
-rw-r--r--synapse/handlers/e2e_keys.py137
-rw-r--r--synapse/storage/data_stores/main/devices.py109
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/storage/test_devices.py16
7 files changed, 241 insertions, 47 deletions
diff --git a/changelog.d/5727.feature b/changelog.d/5727.feature
new file mode 100644
index 0000000000..819bebf2d7
--- /dev/null
+++ b/changelog.d/5727.feature
@@ -0,0 +1 @@
+Add federation support for cross-signing.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index b754a09d7a..a5b36b1827 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -360,20 +360,20 @@ class PerDestinationQueue(object):
         last_device_list = self._last_device_list_stream_id
 
         # Retrieve list of new device updates to send to the destination
-        now_stream_id, results = yield self._store.get_devices_by_remote(
+        now_stream_id, results = yield self._store.get_device_updates_by_remote(
             self._destination, last_device_list, limit=limit
         )
         edus = [
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
-                edu_type="m.device_list_update",
+                edu_type=edu_type,
                 content=content,
             )
-            for content in results
+            for (edu_type, content) in results
         ]
 
-        assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+        assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
 
         return (edus, now_stream_id)
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 16b4617f68..26ef5e150c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
     @defer.inlineCallbacks
     def on_federation_query_user_devices(self, user_id):
         stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
-        return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+        master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+        self_signing_key = yield self.store.get_e2e_cross_signing_key(
+            user_id, "self_signing"
+        )
+
+        return {
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+            "master_key": master_key,
+            "self_signing_key": self_signing_key,
+        }
 
     @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 0449034a4e..f09a0b73c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
@@ -49,10 +51,19 @@ class E2eKeysHandler(object):
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
+        self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+        federation_registry = hs.get_federation_registry()
+
+        # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+        federation_registry.register_edu_handler(
+            "org.matrix.signing_key_update",
+            self._edu_updater.incoming_signing_key_update,
+        )
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
-        hs.get_federation_registry().register_query_handler(
+        federation_registry.register_query_handler(
             "client_keys", self.on_federation_query_client_keys
         )
 
@@ -208,13 +219,15 @@ class E2eKeysHandler(object):
                     if user_id in destination_query:
                         results[user_id] = keys
 
-                for user_id, key in remote_result["master_keys"].items():
-                    if user_id in destination_query:
-                        cross_signing_keys["master_keys"][user_id] = key
+                if "master_keys" in remote_result:
+                    for user_id, key in remote_result["master_keys"].items():
+                        if user_id in destination_query:
+                            cross_signing_keys["master_keys"][user_id] = key
 
-                for user_id, key in remote_result["self_signing_keys"].items():
-                    if user_id in destination_query:
-                        cross_signing_keys["self_signing_keys"][user_id] = key
+                if "self_signing_keys" in remote_result:
+                    for user_id, key in remote_result["self_signing_keys"].items():
+                        if user_id in destination_query:
+                            cross_signing_keys["self_signing_keys"][user_id] = key
 
             except Exception as e:
                 failure = _exception_to_failure(e)
@@ -252,7 +265,7 @@ class E2eKeysHandler(object):
 
         Returns:
             defer.Deferred[dict[str, dict[str, dict]]]: map from
-                (master|self_signing|user_signing) -> user_id -> key
+                (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
         """
         master_keys = {}
         self_signing_keys = {}
@@ -344,7 +357,16 @@ class E2eKeysHandler(object):
         """
         device_keys_query = query_body.get("device_keys", {})
         res = yield self.query_local_devices(device_keys_query)
-        return {"device_keys": res}
+        ret = {"device_keys": res}
+
+        # add in the cross-signing keys
+        cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+            device_keys_query, None
+        )
+
+        ret.update(cross_signing_keys)
+
+        return ret
 
     @trace
     @defer.inlineCallbacks
@@ -1058,3 +1080,100 @@ class SignatureListItem:
     target_user_id = attr.ib()
     target_device_id = attr.ib()
     signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+    """Handles incoming signing key updates from federation and updates the DB"""
+
+    def __init__(self, hs, e2e_keys_handler):
+        self.store = hs.get_datastore()
+        self.federation = hs.get_federation_client()
+        self.clock = hs.get_clock()
+        self.e2e_keys_handler = e2e_keys_handler
+
+        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+        # user_id -> list of updates waiting to be handled.
+        self._pending_updates = {}
+
+        # Recently seen stream ids. We don't bother keeping these in the DB,
+        # but they're useful to have them about to reduce the number of spurious
+        # resyncs.
+        self._seen_updates = ExpiringCache(
+            cache_name="signing_key_update_edu",
+            clock=self.clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+            iterable=True,
+        )
+
+    @defer.inlineCallbacks
+    def incoming_signing_key_update(self, origin, edu_content):
+        """Called on incoming signing key update from federation. Responsible for
+        parsing the EDU and adding to pending updates list.
+
+        Args:
+            origin (string): the server that sent the EDU
+            edu_content (dict): the contents of the EDU
+        """
+
+        user_id = edu_content.pop("user_id")
+        master_key = edu_content.pop("master_key", None)
+        self_signing_key = edu_content.pop("self_signing_key", None)
+
+        if get_domain_from_id(user_id) != origin:
+            logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+            return
+
+        room_ids = yield self.store.get_rooms_for_user(user_id)
+        if not room_ids:
+            # We don't share any rooms with this user. Ignore update, as we
+            # probably won't get any further updates.
+            return
+
+        self._pending_updates.setdefault(user_id, []).append(
+            (master_key, self_signing_key)
+        )
+
+        yield self._handle_signing_key_updates(user_id)
+
+    @defer.inlineCallbacks
+    def _handle_signing_key_updates(self, user_id):
+        """Actually handle pending updates.
+
+        Args:
+            user_id (string): the user whose updates we are processing
+        """
+
+        device_handler = self.e2e_keys_handler.device_handler
+
+        with (yield self._remote_edu_linearizer.queue(user_id)):
+            pending_updates = self._pending_updates.pop(user_id, [])
+            if not pending_updates:
+                # This can happen since we batch updates
+                return
+
+            device_ids = []
+
+            logger.info("pending updates: %r", pending_updates)
+
+            for master_key, self_signing_key in pending_updates:
+                if master_key:
+                    yield self.store.set_e2e_cross_signing_key(
+                        user_id, "master", master_key
+                    )
+                    _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+                    # verify_key is a VerifyKey from signedjson, which uses
+                    # .version to denote the portion of the key ID after the
+                    # algorithm and colon, which is the device ID
+                    device_ids.append(verify_key.version)
+                if self_signing_key:
+                    yield self.store.set_e2e_cross_signing_key(
+                        user_id, "self_signing", self_signing_key
+                    )
+                    _, verify_key = get_verify_key_from_cross_signing_key(
+                        self_signing_key
+                    )
+                    device_ids.append(verify_key.version)
+
+            yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..71f62036c0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
     make_in_list_sql_clause,
 )
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
@@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @trace
     @defer.inlineCallbacks
-    def get_devices_by_remote(self, destination, from_stream_id, limit):
-        """Get stream of updates to send to remote servers
+    def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+        """Get a stream of device updates to send to the given remote server.
 
+        Args:
+            destination (str): The host the device updates are intended for
+            from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            limit (int): Maximum number of device updates to return
         Returns:
-            Deferred[tuple[int, list[dict]]]:
+            Deferred[tuple[int, list[tuple[string,dict]]]]:
                 current stream id (ie, the stream id of the last update included in the
-                response), and the list of updates
+                response), and the list of updates, where each update is a pair of EDU
+                type and EDU contents
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
         # stream_id; the rationale being that such a large device list update
         # is likely an error.
         updates = yield self.runInteraction(
-            "get_devices_by_remote",
-            self._get_devices_by_remote_txn,
+            "get_device_updates_by_remote",
+            self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
@@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
         if not updates:
             return now_stream_id, []
 
+        # get the cross-signing keys of the users in the list, so that we can
+        # determine which of the device changes were cross-signing keys
+        users = set(r[0] for r in updates)
+        master_key_by_user = {}
+        self_signing_key_by_user = {}
+        for user in users:
+            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                # verify_key is a VerifyKey from signedjson, which uses
+                # .version to denote the portion of the key ID after the
+                # algorithm and colon, which is the device ID
+                master_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "device_id": verify_key.version,
+                }
+
+            cross_signing_key = yield self.get_e2e_cross_signing_key(
+                user, "self_signing"
+            )
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                self_signing_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "device_id": verify_key.version,
+                }
+
         # if we have exceeded the limit, we need to exclude any results with the
         # same stream_id as the last row.
         if len(updates) > limit:
@@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
         # context which created the Edu.
 
         query_map = {}
-        for update in updates:
-            if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+        cross_signing_keys_by_user = {}
+        for user_id, device_id, update_stream_id, update_context in updates:
+            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
                 # Stop processing updates
                 break
 
-            key = (update[0], update[1])
-
-            update_context = update[3]
-            update_stream_id = update[2]
+            if (
+                user_id in master_key_by_user
+                and device_id == master_key_by_user[user_id]["device_id"]
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["master_key"] = master_key_by_user[user_id]["key_info"]
+            elif (
+                user_id in self_signing_key_by_user
+                and device_id == self_signing_key_by_user[user_id]["device_id"]
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["self_signing_key"] = self_signing_key_by_user[user_id][
+                    "key_info"
+                ]
+            else:
+                key = (user_id, device_id)
 
-            previous_update_stream_id, _ = query_map.get(key, (0, None))
+                previous_update_stream_id, _ = query_map.get(key, (0, None))
 
-            if update_stream_id > previous_update_stream_id:
-                query_map[key] = (update_stream_id, update_context)
+                if update_stream_id > previous_update_stream_id:
+                    query_map[key] = (update_stream_id, update_context)
 
         # If we didn't find any updates with a stream_id lower than the cutoff, it
         # means that there are more than limit updates all of which have the same
@@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
         # devices, in which case E2E isn't going to work well anyway. We'll just
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
-        if not query_map:
+        if not query_map and not cross_signing_keys_by_user:
             return stream_id_cutoff, []
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
 
+        # add the updated cross-signing keys to the results list
+        for user_id, result in iteritems(cross_signing_keys_by_user):
+            result["user_id"] = user_id
+            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            results.append(("org.matrix.signing_key_update", result))
+
         return now_stream_id, results
 
-    def _get_devices_by_remote_txn(
+    def _get_device_updates_by_remote_txn(
         self, txn, destination, from_stream_id, now_stream_id, limit
     ):
         """Return device update information for a given remote destination
@@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             List: List of device updates
         """
+        # get the list of device updates that need to be sent
         sql = """
             SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
             List[Dict]: List of objects representing an device update EDU
 
         """
-        devices = yield self.runInteraction(
-            "_get_e2e_device_keys_txn",
-            self._get_e2e_device_keys_txn,
-            query_map.keys(),
-            include_all_devices=True,
-            include_deleted_devices=True,
+        devices = (
+            yield self.runInteraction(
+                "_get_e2e_device_keys_txn",
+                self._get_e2e_device_keys_txn,
+                query_map.keys(),
+                include_all_devices=True,
+                include_deleted_devices=True,
+            )
+            if query_map
+            else {}
         )
 
         results = []
@@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 else:
                     result["deleted"] = True
 
-                results.append(result)
+                results.append(("m.device_list_update", result))
 
         return results
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f360c8e965..5ec568f4e6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                         "get_received_txn_response",
                         "set_received_txn_response",
                         "get_destination_retry_timings",
-                        "get_devices_by_remote",
+                        "get_device_updates_by_remote",
                         # Bits that user_directory needs
                         "get_user_directory_stream_pos",
                         "get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             retry_timings_res
         )
 
-        self.datastore.get_devices_by_remote.return_value = (0, [])
+        self.datastore.get_device_updates_by_remote.return_value = (0, [])
 
         def get_received_txn_response(*args):
             return defer.succeed(None)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..6f8d990959 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote(self):
+    def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "somehost", -1, limit=100
         )
 
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         self._check_devices_in_updates(device_ids, device_updates)
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote_limited(self):
+    def test_get_device_updates_by_remote_limited(self):
         # Test breaking the update limit in 1, 101, and 1 device_id segments
 
         # first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         #
 
         # first we should get a single update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", -1, limit=100
         )
         self._check_devices_in_updates(device_ids1, device_updates)
 
         # Then we should get an empty list back as the 101 devices broke the limit
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", now_stream_id, limit=100
         )
         self.assertEqual(len(device_updates), 0)
 
         # The 101 devices should've been cleared, so we should now just get one device
         # update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", now_stream_id, limit=100
         )
         self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))
 
-        received_device_ids = {update["device_id"] for update in device_updates}
+        received_device_ids = {
+            update["device_id"] for edu_type, update in device_updates
+        }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
     @defer.inlineCallbacks