summary refs log tree commit diff
path: root/synapse/storage/data_stores
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-10-22 10:41:18 +0100
committerErik Johnston <erik@matrix.org>2019-10-22 10:41:18 +0100
commitbb6264be0b13314123b4fd047fe180616e8f0efa (patch)
treee06f393f38b4aea65101fbcdc0637e719b285727 /synapse/storage/data_stores
parentFix packaging (diff)
parentDelete format_tap.py (#6219) (diff)
downloadsynapse-bb6264be0b13314123b4fd047fe180616e8f0efa.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/refactor_stores
Diffstat (limited to 'synapse/storage/data_stores')
-rw-r--r--synapse/storage/data_stores/main/__init__.py6
-rw-r--r--synapse/storage/data_stores/main/devices.py95
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py172
3 files changed, 262 insertions, 11 deletions
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index d29135588f..b185ba0b3e 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -141,6 +141,9 @@ class DataStore(
         self._device_list_id_gen = StreamIdGenerator(
             db_conn, "device_lists_stream", "stream_id"
         )
+        self._cross_signing_id_gen = StreamIdGenerator(
+            db_conn, "e2e_cross_signing_keys", "stream_id"
+        )
 
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
@@ -212,6 +215,9 @@ class DataStore(
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
         )
+        self._user_signature_stream_cache = StreamChangeCache(
+            "UserSignatureStreamChangeCache", device_list_max
+        )
         self._device_list_federation_stream_cache = StreamChangeCache(
             "DeviceListFederationStreamChangeCache", device_list_max
         )
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index ac5239e50a..f7a3542348 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,7 +22,7 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
     get_active_span_text_map,
     set_tag,
@@ -47,7 +49,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 
 class DeviceWorkerStore(SQLBaseStore):
     def get_device(self, user_id, device_id):
-        """Retrieve a device.
+        """Retrieve a device. Only returns devices that are not marked as
+        hidden.
 
         Args:
             user_id (str): The ID of the user which owns the device
@@ -59,14 +62,15 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         return self._simple_select_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
         )
 
     @defer.inlineCallbacks
     def get_devices_by_user(self, user_id):
-        """Retrieve all of a user's registered devices.
+        """Retrieve all of a user's registered devices. Only returns devices
+        that are not marked as hidden.
 
         Args:
             user_id (str):
@@ -77,7 +81,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         devices = yield self._simple_select_list(
             table="devices",
-            keyvalues={"user_id": user_id},
+            keyvalues={"user_id": user_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_devices_by_user",
         )
@@ -324,6 +328,41 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, stream_id))
 
+    @defer.inlineCallbacks
+    def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+        """Persist that a user has made new signatures
+
+        Args:
+            from_user_id (str): the user who made the signatures
+            user_ids (list[str]): the users who were signed
+        """
+
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_user_sig_change_to_streams",
+                self._add_user_signature_change_txn,
+                from_user_id,
+                user_ids,
+                stream_id,
+            )
+        return stream_id
+
+    def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+        txn.call_after(
+            self._user_signature_stream_cache.entity_has_changed,
+            from_user_id,
+            stream_id,
+        )
+        self._simple_insert_txn(
+            txn,
+            "user_signature_stream",
+            values={
+                "stream_id": stream_id,
+                "from_user_id": from_user_id,
+                "user_ids": json.dumps(user_ids),
+            },
+        )
+
     def get_device_stream_token(self):
         return self._device_list_id_gen.get_current_token()
 
@@ -469,6 +508,28 @@ class DeviceWorkerStore(SQLBaseStore):
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
+    @defer.inlineCallbacks
+    def get_users_whose_signatures_changed(self, user_id, from_key):
+        """Get the users who have new cross-signing signatures made by `user_id` since
+        `from_key`.
+
+        Args:
+            user_id (str): the user who made the signatures
+            from_key (str): The device lists stream token
+        """
+        from_key = int(from_key)
+        if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+            sql = """
+                SELECT DISTINCT user_ids FROM user_signature_stream
+                WHERE from_user_id = ? AND stream_id > ?
+            """
+            rows = yield self._execute(
+                "get_users_whose_signatures_changed", None, sql, user_id, from_key
+            )
+            return set(user for row in rows for user in json.loads(row[0]))
+        else:
+            return set()
+
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
         combined list of changes to devices, and which destinations need to be
@@ -592,6 +653,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             defer.Deferred: boolean whether the device was inserted or an
                 existing device existed with that ID.
+        Raises:
+            StoreError: if the device is already in use
         """
         key = (user_id, device_id)
         if self.device_id_exists_cache.get(key, None):
@@ -604,12 +667,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     "user_id": user_id,
                     "device_id": device_id,
                     "display_name": initial_device_display_name,
+                    "hidden": False,
                 },
                 desc="store_device",
                 or_ignore=True,
             )
+            if not inserted:
+                # if the device already exists, check if it's a real device, or
+                # if the device ID is reserved by something else
+                hidden = yield self._simple_select_one_onecol(
+                    "devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    retcol="hidden",
+                )
+                if hidden:
+                    raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
             self.device_id_exists_cache.prefill(key, True)
             return inserted
+        except StoreError:
+            raise
         except Exception as e:
             logger.error(
                 "store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -636,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         """
         yield self._simple_delete_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             desc="delete_device",
         )
 
@@ -656,14 +732,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="devices",
             column="device_id",
             iterable=device_ids,
-            keyvalues={"user_id": user_id},
+            keyvalues={"user_id": user_id, "hidden": False},
             desc="delete_devices",
         )
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
 
     def update_device(self, user_id, device_id, new_display_name=None):
-        """Update a device.
+        """Update a device. Only updates the device if it is not marked as
+        hidden.
 
         Args:
             user_id (str): The ID of the user which owns the device
@@ -682,7 +759,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return defer.succeed(None)
         return self._simple_update_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             updatevalues=updates,
             desc="update_device",
         )
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 7e44f41046..5e4c86025e 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +16,7 @@
 # limitations under the License.
 from six import iteritems
 
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 
@@ -101,7 +103,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "    k.key_json"
             " FROM devices d"
             "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
-            " WHERE %s"
+            " WHERE %s AND NOT d.hidden"
         ) % (
             "LEFT" if include_all_devices else "INNER",
             " OR ".join("(" + q + ")" for q in query_clauses),
@@ -320,3 +322,169 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         return self.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
+
+    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+        """Set a user's cross-signing key.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_id (str): the user to set the signing key for
+            key_type (str): the type of key that is being set: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
+            key (dict): the key data
+        """
+        # the cross-signing keys need to occupy the same namespace as devices,
+        # since signatures are identified by device ID.  So add an entry to the
+        # device table to make sure that we don't have a collision with device
+        # IDs
+
+        # the 'key' dict will look something like:
+        # {
+        #   "user_id": "@alice:example.com",
+        #   "usage": ["self_signing"],
+        #   "keys": {
+        #     "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
+        #   },
+        #   "signatures": {
+        #     "@alice:example.com": {
+        #       "ed25519:base64+master+public+key": "base64+signature"
+        #     }
+        #   }
+        # }
+        # The "keys" property must only have one entry, which will be the public
+        # key, so we just grab the first value in there
+        pubkey = next(iter(key["keys"].values()))
+        self._simple_insert_txn(
+            txn,
+            "devices",
+            values={
+                "user_id": user_id,
+                "device_id": pubkey,
+                "display_name": key_type + " signing key",
+                "hidden": True,
+            },
+        )
+
+        # and finally, store the key itself
+        with self._cross_signing_id_gen.get_next() as stream_id:
+            self._simple_insert_txn(
+                txn,
+                "e2e_cross_signing_keys",
+                values={
+                    "user_id": user_id,
+                    "keytype": key_type,
+                    "keydata": json.dumps(key),
+                    "stream_id": stream_id,
+                },
+            )
+
+    def set_e2e_cross_signing_key(self, user_id, key_type, key):
+        """Set a user's cross-signing key.
+
+        Args:
+            user_id (str): the user to set the user-signing key for
+            key_type (str): the type of cross-signing key to set
+            key (dict): the key data
+        """
+        return self.runInteraction(
+            "add_e2e_cross_signing_key",
+            self._set_e2e_cross_signing_key_txn,
+            user_id,
+            key_type,
+            key,
+        )
+
+    def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
+        """Returns a user's cross-signing key.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_id (str): the user whose key is being requested
+            key_type (str): the type of key that is being set: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
+            from_user_id (str): if specified, signatures made by this user on
+                the key will be included in the result
+
+        Returns:
+            dict of the key data or None if not found
+        """
+        sql = (
+            "SELECT keydata "
+            "  FROM e2e_cross_signing_keys "
+            " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
+        )
+        txn.execute(sql, (user_id, key_type))
+        row = txn.fetchone()
+        if not row:
+            return None
+        key = json.loads(row[0])
+
+        device_id = None
+        for k in key["keys"].values():
+            device_id = k
+
+        if from_user_id is not None:
+            sql = (
+                "SELECT key_id, signature "
+                "  FROM e2e_cross_signing_signatures "
+                " WHERE user_id = ? "
+                "   AND target_user_id = ? "
+                "   AND target_device_id = ? "
+            )
+            txn.execute(sql, (from_user_id, user_id, device_id))
+            row = txn.fetchone()
+            if row:
+                key.setdefault("signatures", {}).setdefault(from_user_id, {})[
+                    row[0]
+                ] = row[1]
+
+        return key
+
+    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+        """Returns a user's cross-signing key.
+
+        Args:
+            user_id (str): the user whose self-signing key is being requested
+            key_type (str): the type of cross-signing key to get
+            from_user_id (str): if specified, signatures made by this user on
+                the self-signing key will be included in the result
+
+        Returns:
+            dict of the key data or None if not found
+        """
+        return self.runInteraction(
+            "get_e2e_cross_signing_key",
+            self._get_e2e_cross_signing_key_txn,
+            user_id,
+            key_type,
+            from_user_id,
+        )
+
+    def store_e2e_cross_signing_signatures(self, user_id, signatures):
+        """Stores cross-signing signatures.
+
+        Args:
+            user_id (str): the user who made the signatures
+            signatures (iterable[(str, str, str, str)]): signatures to add - each
+                a tuple of (key_id, target_user_id, target_device_id, signature),
+                where key_id is the ID of the key (including the signature
+                algorithm) that made the signature, target_user_id and
+                target_device_id indicate the device being signed, and signature
+                is the signature of the device
+        """
+        return self._simple_insert_many(
+            "e2e_cross_signing_signatures",
+            [
+                {
+                    "user_id": user_id,
+                    "key_id": key_id,
+                    "target_user_id": target_user_id,
+                    "target_device_id": target_device_id,
+                    "signature": signature,
+                }
+                for (key_id, target_user_id, target_device_id, signature) in signatures
+            ],
+            "add_e2e_signing_key",
+        )