diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..fbf78da9c2 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
import attr
+from synapse.api.constants import AccountDataTypes
from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -495,27 +495,6 @@ class ToDeviceStream(Stream):
)
-class TagAccountDataStream(Stream):
- """Someone added/removed a tag for a room"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class TagAccountDataStreamRow:
- user_id: str
- room_id: str
- data: JsonDict
-
- NAME = "tag_account_data"
- ROW_TYPE = TagAccountDataStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_max_account_data_stream_id),
- store.get_all_updated_tags,
- )
-
-
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
@@ -560,6 +539,19 @@ class AccountDataStream(Stream):
to_token = room_results[-1][0]
limited = True
+ tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+ instance_name,
+ from_token,
+ to_token,
+ limit,
+ )
+
+ # again, if the tag results hit the limit, limit the global results to
+ # the same stream token.
+ if tags_limited:
+ to_token = tag_to_token
+ limited = True
+
# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
@@ -568,11 +560,16 @@ class AccountDataStream(Stream):
if stream_id <= to_token
)
- # we know that the room_results are already limited to `to_token` so no need
- # for a check on `stream_id` here.
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
+ if stream_id <= to_token
+ )
+
+ tag_rows = (
+ (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+ for stream_id, user_id, room_id in tags
+ if stream_id <= to_token
)
# We need to return a sorted list, so merge them together.
@@ -582,7 +579,9 @@ class AccountDataStream(Stream):
# leading to a comparison between the data tuples. The comparison could
# fail due to attempting to compare the `room_id` which results in a
# `TypeError` from comparing a `str` vs `None`.
- updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+ updates = list(
+ heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
+ )
return updates, to_token, limited
|