author      Hubert Chathi <hubert@uhoreg.ca>    2019-05-22 16:42:00 -0400
committer   Hubert Chathi <hubert@uhoreg.ca>    2019-10-22 19:04:35 -0400
commit      8d3542a64e2689a00ed87f9bd58fe3e1d3b10ed8 (patch)
tree        d96814f4ae9fa1f54defd9336889c3a3a36d5a63
parent      Merge pull request #5726 from matrix-org/uhoreg/e2e_cross-signing2-part2 (diff)
download    synapse-8d3542a64e2689a00ed87f9bd58fe3e1d3b10ed8.tar.xz
implement federation parts of cross-signing
-rw-r--r--  synapse/federation/sender/per_destination_queue.py |   4
-rw-r--r--  synapse/handlers/device.py                          |  13
-rw-r--r--  synapse/handlers/e2e_keys.py                        | 116
-rw-r--r--  synapse/storage/devices.py                          |  56
4 files changed, 179 insertions(+), 10 deletions(-)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index fad980b893..0486af2dbf 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -366,10 +366,10 @@ class PerDestinationQueue(object):
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
-                edu_type="m.device_list_update",
+                edu_type=edu_type,
                 content=content,
             )
-            for content in results
+            for (edu_type, content) in results
         ]
 
         assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
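
The hunk above generalises the outgoing queue: get_devices_by_remote now yields
(edu_type, content) pairs rather than bare contents, so device-list and
signing-key updates can travel through the same EDU-building loop. A
hypothetical sample of the new return shape (values are illustrative, not taken
from the patch):

    results = [
        (
            "m.device_list_update",
            {"user_id": "@alice:example.org", "device_id": "JLAFKJWSCS"},
        ),
        # the signing-key EDU carries the key JSON sketched after the device.py diff
        ("m.signing_key_update", {"user_id": "@alice:example.org"}),
    ]
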
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5f23ee4488..cd6eb52316 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -458,7 +458,18 @@ class DeviceHandler(DeviceWorkerHandler):
     @defer.inlineCallbacks
     def on_federation_query_user_devices(self, user_id):
         stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
-        return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+        master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+        self_signing_key = yield self.store.get_e2e_cross_signing_key(
+            user_id, "self_signing"
+        )
+
+        return {
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+            "master_key": master_key,
+            "self_signing_key": self_signing_key,
+        }
 
     @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
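
With this change a federation device query response also carries the user's
public cross-signing keys (when the user has uploaded them). A sketch of the
response shape, assuming the usual cross-signing key JSON of user_id, usage,
keys and signatures; all values here are made up:

    response = {
        "user_id": "@alice:example.org",
        "stream_id": 5,
        "devices": [{"device_id": "JLAFKJWSCS", "keys": {}}],
        "master_key": {
            "user_id": "@alice:example.org",
            "usage": ["master"],
            "keys": {"ed25519:base64+master+pubkey": "base64+master+pubkey"},
        },
        "self_signing_key": {
            "user_id": "@alice:example.org",
            "usage": ["self_signing"],
            "keys": {"ed25519:base64+self+signing+pubkey": "base64+self+signing+pubkey"},
            "signatures": {
                "@alice:example.org": {"ed25519:base64+master+pubkey": "base64+sig"}
            },
        },
    }
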
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5ea54f60be..849ee04f93 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -36,6 +36,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
@@ -49,10 +51,17 @@ class E2eKeysHandler(object):
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
+        self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+        federation_registry = hs.get_federation_registry()
+
+        federation_registry.register_edu_handler(
+            "m.signing_key_update", self._edu_updater.incoming_signing_key_update,
+        )
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
-        hs.get_federation_registry().register_query_handler(
+        federation_registry.register_query_handler(
             "client_keys", self.on_federation_query_client_keys
         )
 
@@ -343,7 +352,15 @@ class E2eKeysHandler(object):
         """
         device_keys_query = query_body.get("device_keys", {})
         res = yield self.query_local_devices(device_keys_query)
-        return {"device_keys": res}
+        ret = {"device_keys": res}
+
+        # add in the cross-signing keys
+        cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query)
+
+        for key, value in iteritems(cross_signing_keys):
+            ret[key + "_keys"] = value
+
+        return ret
 
     @trace
     @defer.inlineCallbacks
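
on_federation_query_client_keys now folds cross-signing keys into the
client_keys query response. Assuming query_cross_signing_keys returns a map of
key type ("master", "self_signing") to per-user keys, which is an inference
from the key + "_keys" renaming rather than something the patch states, the
merge works roughly like this:

    cross_signing_keys = {
        "master": {
            "@alice:example.org": {"usage": ["master"], "keys": {"ed25519:abc": "abc"}}
        },
        "self_signing": {
            "@alice:example.org": {"usage": ["self_signing"], "keys": {"ed25519:def": "def"}}
        },
    }

    ret = {"device_keys": {}}
    for key, value in cross_signing_keys.items():
        ret[key + "_keys"] = value

    # ret now carries "master_keys" and "self_signing_keys" alongside
    # "device_keys", the naming remote servers use for cross-signing keys in
    # key query responses.
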
@@ -1047,3 +1064,98 @@ class SignatureListItem:
     target_user_id = attr.ib()
     target_device_id = attr.ib()
     signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+    """Handles incoming signing key updates from federation and updates the DB"""
+
+    def __init__(self, hs, e2e_keys_handler):
+        self.store = hs.get_datastore()
+        self.federation = hs.get_federation_client()
+        self.clock = hs.get_clock()
+        self.e2e_keys_handler = e2e_keys_handler
+
+        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+        # user_id -> list of updates waiting to be handled.
+        self._pending_updates = {}
+
+        # Recently seen stream ids. We don't bother keeping these in the DB,
+        # but they're useful to have around to reduce the number of spurious
+        # resyncs.
+        self._seen_updates = ExpiringCache(
+            cache_name="signing_key_update_edu",
+            clock=self.clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+            iterable=True,
+        )
+
+    @defer.inlineCallbacks
+    def incoming_signing_key_update(self, origin, edu_content):
+        """Called on incoming signing key update from federation. Responsible for
+        parsing the EDU and adding to pending updates list.
+
+        Args:
+            origin (string): the server that sent the EDU
+            edu_content (dict): the contents of the EDU
+        """
+
+        user_id = edu_content.pop("user_id")
+        master_key = edu_content.pop("master_key", None)
+        self_signing_key = edu_content.pop("self_signing_key", None)
+
+        if get_domain_from_id(user_id) != origin:
+            # TODO: Raise?
+            logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+            return
+
+        room_ids = yield self.store.get_rooms_for_user(user_id)
+        if not room_ids:
+            # We don't share any rooms with this user. Ignore update, as we
+            # probably won't get any further updates.
+            return
+
+        self._pending_updates.setdefault(user_id, []).append(
+            (master_key, self_signing_key, edu_content)
+        )
+
+        yield self._handle_signing_key_updates(user_id)
+
+    @defer.inlineCallbacks
+    def _handle_signing_key_updates(self, user_id):
+        """Actually handle pending updates.
+
+        Args:
+            user_id (string): the user whose updates we are processing
+        """
+
+        device_handler = self.e2e_keys_handler.device_handler
+
+        with (yield self._remote_edu_linearizer.queue(user_id)):
+            pending_updates = self._pending_updates.pop(user_id, [])
+            if not pending_updates:
+                # This can happen since we batch updates
+                return
+
+            device_ids = []
+
+            logger.info("pending updates: %r", pending_updates)
+
+            for master_key, self_signing_key, edu_content in pending_updates:
+                if master_key:
+                    yield self.store.set_e2e_cross_signing_key(
+                        user_id, "master", master_key
+                    )
+                    _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+                    device_id = verify_key.version
+                    device_ids.append(device_id)
+                if self_signing_key:
+                    yield self.store.set_e2e_cross_signing_key(
+                        user_id, "self_signing", self_signing_key
+                    )
+                    _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
+                    device_id = verify_key.version
+                    device_ids.append(device_id)
+
+            yield device_handler.notify_device_update(user_id, device_ids)
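
One non-obvious detail in the updater above: notify_device_update expects
device IDs, and for a cross-signing key the public key itself plays that role.
get_verify_key_from_cross_signing_key returns (key_id, verify_key), and
verify_key.version is the part of the key ID after "ed25519:", i.e. the
unpadded-base64 public key. A small standalone sketch of that derivation using
signedjson directly (the key value is made up):

    from signedjson.key import decode_verify_key_bytes
    from unpaddedbase64 import decode_base64

    pubkey = "nqOvzeuGWT+sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"  # made-up unpadded base64
    master_key = {
        "user_id": "@alice:example.org",
        "usage": ["master"],
        # a cross-signing key is identified by its own public key
        "keys": {"ed25519:" + pubkey: pubkey},
    }

    key_id, key_b64 = next(iter(master_key["keys"].items()))
    verify_key = decode_verify_key_bytes(key_id, decode_base64(key_b64))

    # the "device ID" recorded for this key is the key ID minus its algorithm prefix
    assert verify_key.version == pubkey
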
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index f7a3542348..182e95fa21 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -94,9 +94,10 @@ class DeviceWorkerStore(SQLBaseStore):
         """Get stream of updates to send to remote servers
 
         Returns:
-            Deferred[tuple[int, list[dict]]]:
+            Deferred[tuple[int, list[tuple[str, dict]]]]:
                 current stream id (ie, the stream id of the last update included in the
-                response), and the list of updates
+                response), and the list of updates, where each update is a pair of EDU
+                type and EDU contents
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -129,6 +130,25 @@ class DeviceWorkerStore(SQLBaseStore):
         if not updates:
             return now_stream_id, []
 
+        # get the cross-signing keys of the users in the list
+        users = set(r[0] for r in updates)
+        master_key_by_user = {}
+        self_signing_key_by_user = {}
+        for user in users:
+            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                master_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "pubkey": verify_key.version,
+                }
+
+            cross_signing_key = yield self.get_e2e_cross_signing_key(
+                user, "self_signing"
+            )
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                self_signing_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "pubkey": verify_key.version,
+                }
+
         # if we have exceeded the limit, we need to exclude any results with the
         # same stream_id as the last row.
         if len(updates) > limit:
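
The maps built in this hunk pair each user's raw key JSON ("key_info") with the
key's public part ("pubkey"), because a cross-signing key's public key doubles
as its pseudo device ID in the device-list stream; that is what the comparisons
against update[1] below rely on. A hypothetical entry, with made-up values
(only users who have actually uploaded a key of the given type get an entry):

    master_key_by_user = {
        "@alice:example.org": {
            "key_info": {
                "user_id": "@alice:example.org",
                "usage": ["master"],
                "keys": {"ed25519:base64+master+pubkey": "base64+master+pubkey"},
            },
            # unpadded-base64 public key, i.e. the device ID used in the pokes
            "pubkey": "base64+master+pubkey",
        }
    }
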
@@ -158,6 +178,10 @@ class DeviceWorkerStore(SQLBaseStore):
                 # Stop processing updates
                 break
 
+            # cross-signing key changes are sent as m.signing_key_update EDUs below
+            if update[1] in (
+                master_key_by_user.get(update[0], {}).get("pubkey"),
+                self_signing_key_by_user.get(update[0], {}).get("pubkey"),
+            ):
+                continue
+
             key = (update[0], update[1])
 
             update_context = update[3]
@@ -172,16 +196,37 @@ class DeviceWorkerStore(SQLBaseStore):
         # means that there are more than limit updates all of which have the same
         # stream_id.
 
+        # figure out which cross-signing keys were changed by intersecting the
+        # update list with the master/self-signing key maps built above
+        cross_signing_keys_by_user = {}
+        for user_id, device_id, _, _ in updates:
+            if device_id == master_key_by_user.get(user_id, {}).get("pubkey"):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["master_key"] = master_key_by_user[user_id]["key_info"]
+            elif device_id == self_signing_key_by_user.get(user_id, {}).get("pubkey"):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["self_signing_key"] = self_signing_key_by_user[user_id]["key_info"]
+
+        cross_signing_results = []
+
+        # add the updated cross-signing keys to the results list
+        for user_id, result in iteritems(cross_signing_keys_by_user):
+            result["user_id"] = user_id
+            cross_signing_results.append(("m.signing_key_update", result))
+
         # That should only happen if a client is spamming the server with new
         # devices, in which case E2E isn't going to work well anyway. We'll just
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
-        if not query_map:
+        if not query_map and not cross_signing_results:
             return stream_id_cutoff, []
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
+        results.extend(cross_signing_results)
 
         return now_stream_id, results
 
@@ -200,6 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             List: List of device updates
         """
+        # get the list of device updates that need to be sent
         sql = """
             SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -231,7 +277,7 @@ class DeviceWorkerStore(SQLBaseStore):
             query_map.keys(),
             include_all_devices=True,
             include_deleted_devices=True,
-        )
+        ) if query_map else {}
 
         results = []
         for user_id, user_devices in iteritems(devices):
@@ -262,7 +308,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 else:
                     result["deleted"] = True
 
-                results.append(result)
+                results.append(("m.device_list_update", result))
 
         return results
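
Two small behaviours in the last hunks are easy to miss: the trailing
"if query_map else {}" skips the device-key lookup entirely when only
cross-signing keys changed, and each per-device result is now tagged with its
EDU type so it lines up with the m.signing_key_update entries appended in
get_devices_by_remote. A compact sketch of both, with a stand-in lookup
function rather than Synapse's real one:

    def lookup_device_keys(query_map):
        # stand-in for the real lookup; never reached when query_map is empty
        raise RuntimeError("unexpected lookup with no device updates")

    query_map = {}  # only cross-signing keys changed for this destination
    devices = lookup_device_keys(query_map) if query_map else {}

    results = []
    for user_id, user_devices in devices.items():
        for device_id in user_devices:
            results.append(
                ("m.device_list_update", {"user_id": user_id, "device_id": device_id})
            )

    # cross-signing updates, already in (edu_type, content) form, are appended after
    results.extend([("m.signing_key_update", {"user_id": "@alice:example.org"})])
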