diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index b73401bc26..ed372e2fc4 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -302,6 +302,41 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, stream_id))
+ @defer.inlineCallbacks
+ def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ """Persist that a user has made new signatures
+
+ Args:
+ from_user_id (str): the user who made the signatures
+ user_ids (list[str]): the users who were signed
+ """
+
+ with self._device_list_id_gen.get_next() as stream_id:
+ yield self.runInteraction(
+ "add_user_sig_change_to_streams",
+ self._add_user_signature_change_txn,
+ from_user_id,
+ user_ids,
+ stream_id,
+ )
+ defer.returnValue(stream_id)
+
+ def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ txn.call_after(
+ self._user_signature_stream_cache.entity_has_changed,
+ from_user_id,
+ stream_id,
+ )
+ self._simple_insert_txn(
+ txn,
+ "user_signature_stream",
+ values={
+ "stream_id": stream_id,
+ "from_user_id": from_user_id,
+ "user_ids": json.dumps(user_ids),
+ },
+ )
+
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
@@ -440,6 +475,28 @@ class DeviceWorkerStore(SQLBaseStore):
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
+ @defer.inlineCallbacks
+ def get_users_whose_signatures_changed(self, user_id, from_key):
+ """Get the users who have new cross-signing signatures made by `user_id` since
+ `from_key`.
+
+ Args:
+ user_id (str): the user who made the signatures
+ from_key (str): The device lists stream token
+ """
+ from_key = int(from_key)
+ if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+ sql = """
+ SELECT DISTINCT user_ids FROM user_signature_stream
+ WHERE from_user_id = ? AND stream_id > ?
+ """
+ rows = yield self._execute(
+ "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ )
+ defer.returnValue(set(user for row in rows for user in json.loads(row[0])))
+ else:
+ defer.returnValue(set())
+
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
|