diff --git a/changelog.d/6254.bugfix b/changelog.d/6254.bugfix
new file mode 100644
index 0000000000..3181484b88
--- /dev/null
+++ b/changelog.d/6254.bugfix
@@ -0,0 +1 @@
+Make notification of cross-signing signatures work with workers.
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 61557665a7..de50748c30 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
- result["device_lists"] = self._device_list_id_gen.get_current_token()
+ # The user signature stream uses the same stream ID generator as the
+ # device list stream, so set them both to the device list ID
+ # generator's current token.
+ current_token = self._device_list_id_gen.get_current_token()
+ result[DeviceListsStream.NAME] = current_token
+ result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "device_lists":
+ if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ elif stream_name == UserSignatureStream.NAME:
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c259..9e45429d49 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
@@ -438,3 +439,20 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token
+ self.update_function = store.get_all_user_signature_changes_for_remotes
+
+ super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 60ae01d972..10c940df1e 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -139,7 +139,10 @@ class DataStore(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[("user_signature_stream", "stream_id")],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index a0bc6f2d18..073412a78d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ """Return a list of changes from the user signature stream to notify remotes.
+ Note that the user signature stream represents when a user signs their
+ device with their user-signing key, which is not published to other
+ users or servers, so no `destination` is needed in the returned
+ list. However, this is needed to poke workers.
+
+ Args:
+ from_key (int): the stream ID to start at (exclusive)
+ to_key (int): the stream ID to end at (inclusive)
+
+ Returns:
+ Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+ """
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id
+ """
+ return self._execute(
+ "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ )
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|