diff options
Diffstat (limited to 'synapse/replication')
-rw-r--r-- | synapse/replication/tcp/client.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/__init__.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 49 |
4 files changed, 26 insertions, 32 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b5e40da533..7263bb2796 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import ( PushersStream, PushRulesStream, ReceiptsStream, - TagAccountDataStream, ToDeviceStream, TypingStream, UnPartialStatedEventStream, @@ -168,7 +167,7 @@ class ReplicationDataHandler: self.notifier.on_new_event( StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows] ) - elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): + elif stream_name in AccountDataStream.NAME: self.notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0f166d16aa..d03a53d764 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import ( PresenceStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -145,7 +144,7 @@ class ReplicationCommandHandler: continue - if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + if isinstance(stream, AccountDataStream): # Only add AccountDataStream and TagAccountDataStream as a source on the # instance in charge of account_data persistence. if hs.get_instance_name() in hs.config.worker.writers.account_data: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 110f10aab9..a7eadfa3c9 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import ( PushRulesStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, UserSignatureStream, @@ -62,7 +61,6 @@ STREAMS_MAP = { DeviceListsStream, ToDeviceStream, FederationStream, - TagAccountDataStream, AccountDataStream, UserSignatureStream, UnPartialStatedRoomStream, @@ -83,7 +81,6 @@ __all__ = [ "CachesStream", "DeviceListsStream", "ToDeviceStream", - "TagAccountDataStream", "AccountDataStream", "UserSignatureStream", "UnPartialStatedRoomStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index e01155ad59..fbf78da9c2 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,8 +28,8 @@ from typing import ( import attr +from synapse.api.constants import AccountDataTypes from synapse.replication.http.streams import ReplicationGetStreamUpdates -from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -495,27 +495,6 @@ class ToDeviceStream(Stream): ) -class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class TagAccountDataStreamRow: - user_id: str - room_id: str - data: JsonDict - - NAME = "tag_account_data" - ROW_TYPE = TagAccountDataStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_max_account_data_stream_id), - store.get_all_updated_tags, - ) - - class AccountDataStream(Stream): """Global or per room account data was changed""" @@ -560,6 +539,19 @@ class AccountDataStream(Stream): to_token = room_results[-1][0] limited = True + tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags( + instance_name, + from_token, + to_token, + limit, + ) + + # again, if the tag results hit the limit, limit the global results to + # the same stream token. + if tags_limited: + to_token = tag_to_token + limited = True + # convert the global results to the right format, and limit them to the to_token # at the same time global_rows = ( @@ -568,11 +560,16 @@ class AccountDataStream(Stream): if stream_id <= to_token ) - # we know that the room_results are already limited to `to_token` so no need - # for a check on `stream_id` here. room_rows = ( (stream_id, (user_id, room_id, account_data_type)) for stream_id, user_id, room_id, account_data_type in room_results + if stream_id <= to_token + ) + + tag_rows = ( + (stream_id, (user_id, room_id, AccountDataTypes.TAG)) + for stream_id, user_id, room_id in tags + if stream_id <= to_token ) # We need to return a sorted list, so merge them together. @@ -582,7 +579,9 @@ class AccountDataStream(Stream): # leading to a comparison between the data tuples. The comparison could # fail due to attempting to compare the `room_id` which results in a # `TypeError` from comparing a `str` vs `None`. - updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + updates = list( + heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) + ) return updates, to_token, limited |