summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/devices.py88
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py134
2 files changed, 145 insertions, 77 deletions
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index f7a3542348..6ac165068e 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -37,6 +37,7 @@ from synapse.storage._base import (
     make_in_list_sql_clause,
 )
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
@@ -94,9 +95,10 @@ class DeviceWorkerStore(SQLBaseStore):
         """Get stream of updates to send to remote servers
 
         Returns:
-            Deferred[tuple[int, list[dict]]]:
+            Deferred[tuple[int, list[tuple[string,dict]]]]:
                 current stream id (ie, the stream id of the last update included in the
-                response), and the list of updates
+                response), and the list of updates, where each update is a pair of EDU
+                type and EDU contents
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -129,6 +131,33 @@ class DeviceWorkerStore(SQLBaseStore):
         if not updates:
             return now_stream_id, []
 
+        # get the cross-signing keys of the users the list
+        users = set(r[0] for r in updates)
+        master_key_by_user = {}
+        self_signing_key_by_user = {}
+        for user in users:
+            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                master_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "pubkey": verify_key.version,
+                }
+
+            cross_signing_key = yield self.get_e2e_cross_signing_key(
+                user, "self_signing"
+            )
+            if cross_signing_key:
+                key_id, verify_key = get_verify_key_from_cross_signing_key(
+                    cross_signing_key
+                )
+                self_signing_key_by_user[user] = {
+                    "key_info": cross_signing_key,
+                    "pubkey": verify_key.version,
+                }
+
         # if we have exceeded the limit, we need to exclude any results with the
         # same stream_id as the last row.
         if len(updates) > limit:
@@ -158,6 +187,16 @@ class DeviceWorkerStore(SQLBaseStore):
                 # Stop processing updates
                 break
 
+            # skip over cross-signing keys
+            if (
+                update[0] in master_key_by_user
+                and update[1] == master_key_by_user[update[0]]["pubkey"]
+            ) or (
+                update[0] in master_key_by_user
+                and update[1] == self_signing_key_by_user[update[0]]["pubkey"]
+            ):
+                continue
+
             key = (update[0], update[1])
 
             update_context = update[3]
@@ -172,16 +211,40 @@ class DeviceWorkerStore(SQLBaseStore):
         # means that there are more than limit updates all of which have the same
         # steam_id.
 
+        # figure out which cross-signing keys were changed by intersecting the
+        # update list with the master/self-signing key by user maps
+        cross_signing_keys_by_user = {}
+        for user_id, device_id, stream, _opentracing_context in updates:
+            if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["master_key"] = master_key_by_user[user_id]["key_info"]
+            elif device_id == self_signing_key_by_user.get(user_id, {}).get(
+                "pubkey", None
+            ):
+                result = cross_signing_keys_by_user.setdefault(user_id, {})
+                result["self_signing_key"] = self_signing_key_by_user[user_id][
+                    "key_info"
+                ]
+
+        cross_signing_results = []
+
+        # add the updated cross-signing keys to the results list
+        for user_id, result in iteritems(cross_signing_keys_by_user):
+            result["user_id"] = user_id
+            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            cross_signing_results.append(("org.matrix.signing_key_update", result))
+
         # That should only happen if a client is spamming the server with new
         # devices, in which case E2E isn't going to work well anyway. We'll just
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
-        if not query_map:
+        if not query_map and not cross_signing_results:
             return stream_id_cutoff, []
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
+        results.extend(cross_signing_results)
 
         return now_stream_id, results
 
@@ -200,6 +263,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             List: List of device updates
         """
+        # get the list of device updates that need to be sent
         sql = """
             SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@@ -225,12 +289,16 @@ class DeviceWorkerStore(SQLBaseStore):
             List[Dict]: List of objects representing an device update EDU
 
         """
-        devices = yield self.runInteraction(
-            "_get_e2e_device_keys_txn",
-            self._get_e2e_device_keys_txn,
-            query_map.keys(),
-            include_all_devices=True,
-            include_deleted_devices=True,
+        devices = (
+            yield self.runInteraction(
+                "_get_e2e_device_keys_txn",
+                self._get_e2e_device_keys_txn,
+                query_map.keys(),
+                include_all_devices=True,
+                include_deleted_devices=True,
+            )
+            if query_map
+            else {}
         )
 
         results = []
@@ -262,7 +330,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 else:
                     result["deleted"] = True
 
-                results.append(result)
+                results.append(("m.device_list_update", result))
 
         return results
 
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index f5c3ed9dc2..a0bc6f2d18 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -248,6 +248,73 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
 
+    def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
+        """Returns a user's cross-signing key.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_id (str): the user whose key is being requested
+            key_type (str): the type of key that is being set: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
+            from_user_id (str): if specified, signatures made by this user on
+                the key will be included in the result
+
+        Returns:
+            dict of the key data or None if not found
+        """
+        sql = (
+            "SELECT keydata "
+            "  FROM e2e_cross_signing_keys "
+            " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
+        )
+        txn.execute(sql, (user_id, key_type))
+        row = txn.fetchone()
+        if not row:
+            return None
+        key = json.loads(row[0])
+
+        device_id = None
+        for k in key["keys"].values():
+            device_id = k
+
+        if from_user_id is not None:
+            sql = (
+                "SELECT key_id, signature "
+                "  FROM e2e_cross_signing_signatures "
+                " WHERE user_id = ? "
+                "   AND target_user_id = ? "
+                "   AND target_device_id = ? "
+            )
+            txn.execute(sql, (from_user_id, user_id, device_id))
+            row = txn.fetchone()
+            if row:
+                key.setdefault("signatures", {}).setdefault(from_user_id, {})[
+                    row[0]
+                ] = row[1]
+
+        return key
+
+    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+        """Returns a user's cross-signing key.
+
+        Args:
+            user_id (str): the user whose self-signing key is being requested
+            key_type (str): the type of cross-signing key to get
+            from_user_id (str): if specified, signatures made by this user on
+                the self-signing key will be included in the result
+
+        Returns:
+            dict of the key data or None if not found
+        """
+        return self.runInteraction(
+            "get_e2e_cross_signing_key",
+            self._get_e2e_cross_signing_key_txn,
+            user_id,
+            key_type,
+            from_user_id,
+        )
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@@ -426,73 +493,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key,
         )
 
-    def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
-        """Returns a user's cross-signing key.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being set: either 'master'
-                for a master key, 'self_signing' for a self-signing key, or
-                'user_signing' for a user-signing key
-            from_user_id (str): if specified, signatures made by this user on
-                the key will be included in the result
-
-        Returns:
-            dict of the key data or None if not found
-        """
-        sql = (
-            "SELECT keydata "
-            "  FROM e2e_cross_signing_keys "
-            " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
-        )
-        txn.execute(sql, (user_id, key_type))
-        row = txn.fetchone()
-        if not row:
-            return None
-        key = json.loads(row[0])
-
-        device_id = None
-        for k in key["keys"].values():
-            device_id = k
-
-        if from_user_id is not None:
-            sql = (
-                "SELECT key_id, signature "
-                "  FROM e2e_cross_signing_signatures "
-                " WHERE user_id = ? "
-                "   AND target_user_id = ? "
-                "   AND target_device_id = ? "
-            )
-            txn.execute(sql, (from_user_id, user_id, device_id))
-            row = txn.fetchone()
-            if row:
-                key.setdefault("signatures", {}).setdefault(from_user_id, {})[
-                    row[0]
-                ] = row[1]
-
-        return key
-
-    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
-        """Returns a user's cross-signing key.
-
-        Args:
-            user_id (str): the user whose self-signing key is being requested
-            key_type (str): the type of cross-signing key to get
-            from_user_id (str): if specified, signatures made by this user on
-                the self-signing key will be included in the result
-
-        Returns:
-            dict of the key data or None if not found
-        """
-        return self.runInteraction(
-            "get_e2e_cross_signing_key",
-            self._get_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            from_user_id,
-        )
-
     def store_e2e_cross_signing_signatures(self, user_id, signatures):
         """Stores cross-signing signatures.