summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r--synapse/storage/devices.py95
1 files changed, 86 insertions, 9 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ac5239e50a..f7a3542348 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,7 +22,7 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
     get_active_span_text_map,
     set_tag,
@@ -47,7 +49,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 
 class DeviceWorkerStore(SQLBaseStore):
     def get_device(self, user_id, device_id):
-        """Retrieve a device.
+        """Retrieve a device. Only returns devices that are not marked as
+        hidden.
 
         Args:
             user_id (str): The ID of the user which owns the device
@@ -59,14 +62,15 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         return self._simple_select_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
         )
 
     @defer.inlineCallbacks
     def get_devices_by_user(self, user_id):
-        """Retrieve all of a user's registered devices.
+        """Retrieve all of a user's registered devices. Only returns devices
+        that are not marked as hidden.
 
         Args:
             user_id (str):
@@ -77,7 +81,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         devices = yield self._simple_select_list(
             table="devices",
-            keyvalues={"user_id": user_id},
+            keyvalues={"user_id": user_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_devices_by_user",
         )
@@ -324,6 +328,41 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, stream_id))
 
+    @defer.inlineCallbacks
+    def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+        """Persist that a user has made new signatures
+
+        Args:
+            from_user_id (str): the user who made the signatures
+            user_ids (list[str]): the users who were signed
+        """
+
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_user_sig_change_to_streams",
+                self._add_user_signature_change_txn,
+                from_user_id,
+                user_ids,
+                stream_id,
+            )
+        return stream_id
+
+    def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+        txn.call_after(
+            self._user_signature_stream_cache.entity_has_changed,
+            from_user_id,
+            stream_id,
+        )
+        self._simple_insert_txn(
+            txn,
+            "user_signature_stream",
+            values={
+                "stream_id": stream_id,
+                "from_user_id": from_user_id,
+                "user_ids": json.dumps(user_ids),
+            },
+        )
+
     def get_device_stream_token(self):
         return self._device_list_id_gen.get_current_token()
 
@@ -469,6 +508,28 @@ class DeviceWorkerStore(SQLBaseStore):
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
+    @defer.inlineCallbacks
+    def get_users_whose_signatures_changed(self, user_id, from_key):
+        """Get the users who have new cross-signing signatures made by `user_id` since
+        `from_key`.
+
+        Args:
+            user_id (str): the user who made the signatures
+            from_key (str): The device lists stream token
+        """
+        from_key = int(from_key)
+        if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+            sql = """
+                SELECT DISTINCT user_ids FROM user_signature_stream
+                WHERE from_user_id = ? AND stream_id > ?
+            """
+            rows = yield self._execute(
+                "get_users_whose_signatures_changed", None, sql, user_id, from_key
+            )
+            return set(user for row in rows for user in json.loads(row[0]))
+        else:
+            return set()
+
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
         combined list of changes to devices, and which destinations need to be
@@ -592,6 +653,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             defer.Deferred: boolean whether the device was inserted or an
                 existing device existed with that ID.
+        Raises:
+            StoreError: if the device is already in use
         """
         key = (user_id, device_id)
         if self.device_id_exists_cache.get(key, None):
@@ -604,12 +667,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     "user_id": user_id,
                     "device_id": device_id,
                     "display_name": initial_device_display_name,
+                    "hidden": False,
                 },
                 desc="store_device",
                 or_ignore=True,
             )
+            if not inserted:
+                # if the device already exists, check if it's a real device, or
+                # if the device ID is reserved by something else
+                hidden = yield self._simple_select_one_onecol(
+                    "devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    retcol="hidden",
+                )
+                if hidden:
+                    raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
             self.device_id_exists_cache.prefill(key, True)
             return inserted
+        except StoreError:
+            raise
         except Exception as e:
             logger.error(
                 "store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -636,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         """
         yield self._simple_delete_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             desc="delete_device",
         )
 
@@ -656,14 +732,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="devices",
             column="device_id",
             iterable=device_ids,
-            keyvalues={"user_id": user_id},
+            keyvalues={"user_id": user_id, "hidden": False},
             desc="delete_devices",
         )
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
 
     def update_device(self, user_id, device_id, new_display_name=None):
-        """Update a device.
+        """Update a device. Only updates the device if it is not marked as
+        hidden.
 
         Args:
             user_id (str): The ID of the user which owns the device
@@ -682,7 +759,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return defer.succeed(None)
         return self._simple_update_one(
             table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id},
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             updatevalues=updates,
             desc="update_device",
         )