diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b5e40da533..322d695bc7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import (
PushersStream,
PushRulesStream,
ReceiptsStream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
UnPartialStatedEventStream,
@@ -168,7 +167,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event(
StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
)
- elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+ elif stream_name in AccountDataStream.NAME:
self.notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
@@ -188,7 +187,7 @@ class ReplicationDataHandler:
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
- if row.entity.startswith("@"):
+ if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
self.notifier.on_new_event(
@@ -326,7 +325,7 @@ class ReplicationDataHandler:
# anyway in that case we don't need to wait.
return
- current_position = self._streams[stream_name].current_token(self._instance_name)
+ current_position = self._streams[stream_name].current_token(instance_name)
if position <= current_position:
# We're already past the position
return
@@ -423,7 +422,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ hosts = {
+ row.entity
+ for row in rows
+ if not row.entity.startswith("@") and not row.is_signature
+ }
for host in hosts:
self.federation_sender.send_device_messages(host, immediate=False)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0f166d16aa..d03a53d764 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import (
PresenceStream,
ReceiptsStream,
Stream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@@ -145,7 +144,7 @@ class ReplicationCommandHandler:
continue
- if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+ if isinstance(stream, AccountDataStream):
# Only add AccountDataStream and TagAccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 110f10aab9..9c67f661a3 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -35,10 +35,8 @@ from synapse.replication.tcp.streams._base import (
PushRulesStream,
ReceiptsStream,
Stream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
- UserSignatureStream,
)
from synapse.replication.tcp.streams.events import EventsStream
from synapse.replication.tcp.streams.federation import FederationStream
@@ -62,9 +60,7 @@ STREAMS_MAP = {
DeviceListsStream,
ToDeviceStream,
FederationStream,
- TagAccountDataStream,
AccountDataStream,
- UserSignatureStream,
UnPartialStatedRoomStream,
UnPartialStatedEventStream,
)
@@ -83,9 +79,7 @@ __all__ = [
"CachesStream",
"DeviceListsStream",
"ToDeviceStream",
- "TagAccountDataStream",
"AccountDataStream",
- "UserSignatureStream",
"UnPartialStatedRoomStream",
"UnPartialStatedEventStream",
]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..a4bdb48c0c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
import attr
+from synapse.api.constants import AccountDataTypes
from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -463,18 +463,67 @@ class DeviceListsStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity: str
+ # Indicates that a user has signed their own device with their user-signing key
+ is_signature: bool
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
+ self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(store.get_device_stream_token),
- store.get_all_device_list_changes_for_remotes,
+ current_token_without_instance(self.store.get_device_stream_token),
+ self._update_function,
)
+ async def _update_function(
+ self,
+ instance_name: str,
+ from_token: Token,
+ current_token: Token,
+ target_row_count: int,
+ ) -> StreamUpdateResult:
+ (
+ device_updates,
+ devices_to_token,
+ devices_limited,
+ ) = await self.store.get_all_device_list_changes_for_remotes(
+ instance_name, from_token, current_token, target_row_count
+ )
+
+ (
+ signatures_updates,
+ signatures_to_token,
+ signatures_limited,
+ ) = await self.store.get_all_user_signature_changes_for_remotes(
+ instance_name, from_token, current_token, target_row_count
+ )
+
+ upper_limit_token = current_token
+ if devices_limited:
+ upper_limit_token = min(upper_limit_token, devices_to_token)
+ if signatures_limited:
+ upper_limit_token = min(upper_limit_token, signatures_to_token)
+
+ device_updates = [
+ (stream_id, (entity, False))
+ for stream_id, (entity,) in device_updates
+ if stream_id <= upper_limit_token
+ ]
+
+ signatures_updates = [
+ (stream_id, (entity, True))
+ for stream_id, (entity,) in signatures_updates
+ if stream_id <= upper_limit_token
+ ]
+
+ updates = list(
+ heapq.merge(device_updates, signatures_updates, key=lambda row: row[0])
+ )
+
+ return updates, upper_limit_token, devices_limited or signatures_limited
+
class ToDeviceStream(Stream):
"""New to_device messages for a client"""
@@ -495,27 +544,6 @@ class ToDeviceStream(Stream):
)
-class TagAccountDataStream(Stream):
- """Someone added/removed a tag for a room"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class TagAccountDataStreamRow:
- user_id: str
- room_id: str
- data: JsonDict
-
- NAME = "tag_account_data"
- ROW_TYPE = TagAccountDataStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_max_account_data_stream_id),
- store.get_all_updated_tags,
- )
-
-
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
@@ -560,6 +588,19 @@ class AccountDataStream(Stream):
to_token = room_results[-1][0]
limited = True
+ tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+ instance_name,
+ from_token,
+ to_token,
+ limit,
+ )
+
+ # again, if the tag results hit the limit, limit the global results to
+ # the same stream token.
+ if tags_limited:
+ to_token = tag_to_token
+ limited = True
+
# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
@@ -568,11 +609,16 @@ class AccountDataStream(Stream):
if stream_id <= to_token
)
- # we know that the room_results are already limited to `to_token` so no need
- # for a check on `stream_id` here.
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
+ if stream_id <= to_token
+ )
+
+ tag_rows = (
+ (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+ for stream_id, user_id, room_id in tags
+ if stream_id <= to_token
)
# We need to return a sorted list, so merge them together.
@@ -582,24 +628,7 @@ class AccountDataStream(Stream):
# leading to a comparison between the data tuples. The comparison could
# fail due to attempting to compare the `room_id` which results in a
# `TypeError` from comparing a `str` vs `None`.
- updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
- return updates, to_token, limited
-
-
-class UserSignatureStream(Stream):
- """A user has signed their own device with their user-signing key"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class UserSignatureStreamRow:
- user_id: str
-
- NAME = "user_signature"
- ROW_TYPE = UserSignatureStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_device_stream_token),
- store.get_all_user_signature_changes_for_remotes,
+ updates = list(
+ heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
)
+ return updates, to_token, limited
|