summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r--synapse/storage/devices.py72
1 files changed, 67 insertions, 5 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index f7716302f0..7b4213f20b 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
     get_active_span_text_map,
+    set_tag,
     trace,
     whitelisted_homeserver,
 )
@@ -98,7 +99,7 @@ class DeviceWorkerStore(SQLBaseStore):
             destination, int(from_stream_id)
         )
         if not has_changed:
-            return (now_stream_id, [])
+            return now_stream_id, []
 
         # We retrieve n+1 devices from the list of outbound pokes where n is
         # our outbound device update limit. We then check if the very last
@@ -121,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         # Return an empty list if there are no updates
         if not updates:
-            return (now_stream_id, [])
+            return now_stream_id, []
 
         # if we have exceeded the limit, we need to exclude any results with the
         # same stream_id as the last row.
@@ -171,13 +172,13 @@ class DeviceWorkerStore(SQLBaseStore):
         # skip that stream_id and return an empty list, and continue with the next
         # stream_id next time.
         if not query_map:
-            return (stream_id_cutoff, [])
+            return stream_id_cutoff, []
 
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
 
-        return (now_stream_id, results)
+        return now_stream_id, results
 
     def _get_devices_by_remote_txn(
         self, txn, destination, from_stream_id, now_stream_id, limit
@@ -322,9 +323,45 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, stream_id))
 
+    @defer.inlineCallbacks
+    def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+        """Persist that a user has made new signatures
+
+        Args:
+            from_user_id (str): the user who made the signatures
+            user_ids (list[str]): the users who were signed
+        """
+
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_user_sig_change_to_streams",
+                self._add_user_signature_change_txn,
+                from_user_id,
+                user_ids,
+                stream_id,
+            )
+        return stream_id
+
+    def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+        txn.call_after(
+            self._user_signature_stream_cache.entity_has_changed,
+            from_user_id,
+            stream_id,
+        )
+        self._simple_insert_txn(
+            txn,
+            "user_signature_stream",
+            values={
+                "stream_id": stream_id,
+                "from_user_id": from_user_id,
+                "user_ids": json.dumps(user_ids),
+            },
+        )
+
     def get_device_stream_token(self):
         return self._device_list_id_gen.get_current_token()
 
+    @trace
     @defer.inlineCallbacks
     def get_user_devices_from_cache(self, query_list):
         """Get the devices (and keys if any) for remote users from the cache.
@@ -356,7 +393,10 @@ class DeviceWorkerStore(SQLBaseStore):
             else:
                 results[user_id] = yield self._get_cached_devices_for_user(user_id)
 
-        return (user_ids_not_in_cache, results)
+        set_tag("in_cache", results)
+        set_tag("not_in_cache", user_ids_not_in_cache)
+
+        return user_ids_not_in_cache, results
 
     @cachedInlineCallbacks(num_args=2, tree=True)
     def _get_cached_user_device(self, user_id, device_id):
@@ -460,6 +500,28 @@ class DeviceWorkerStore(SQLBaseStore):
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
+    @defer.inlineCallbacks
+    def get_users_whose_signatures_changed(self, user_id, from_key):
+        """Get the users who have new cross-signing signatures made by `user_id` since
+        `from_key`.
+
+        Args:
+            user_id (str): the user who made the signatures
+            from_key (str): The device lists stream token
+        """
+        from_key = int(from_key)
+        if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+            sql = """
+                SELECT DISTINCT user_ids FROM user_signature_stream
+                WHERE from_user_id = ? AND stream_id > ?
+            """
+            rows = yield self._execute(
+                "get_users_whose_signatures_changed", None, sql, user_id, from_key
+            )
+            return set(user for row in rows for user in json.loads(row[0]))
+        else:
+            return set()
+
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
         combined list of changes to devices, and which destinations need to be