summary refs log tree commit diff
path: root/synapse/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/streams/_base.py')
-rw-r--r--synapse/replication/tcp/streams/_base.py49
1 files changed, 24 insertions, 25 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..fbf78da9c2 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
 
 import attr
 
+from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -495,27 +495,6 @@ class ToDeviceStream(Stream):
         )
 
 
-class TagAccountDataStream(Stream):
-    """Someone added/removed a tag for a room"""
-
-    @attr.s(slots=True, frozen=True, auto_attribs=True)
-    class TagAccountDataStreamRow:
-        user_id: str
-        room_id: str
-        data: JsonDict
-
-    NAME = "tag_account_data"
-    ROW_TYPE = TagAccountDataStreamRow
-
-    def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
-        super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_max_account_data_stream_id),
-            store.get_all_updated_tags,
-        )
-
-
 class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
@@ -560,6 +539,19 @@ class AccountDataStream(Stream):
             to_token = room_results[-1][0]
             limited = True
 
+        tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+            instance_name,
+            from_token,
+            to_token,
+            limit,
+        )
+
+        # again, if the tag results hit the limit, limit the global results to
+        # the same stream token.
+        if tags_limited:
+            to_token = tag_to_token
+            limited = True
+
         # convert the global results to the right format, and limit them to the to_token
         # at the same time
         global_rows = (
@@ -568,11 +560,16 @@ class AccountDataStream(Stream):
             if stream_id <= to_token
         )
 
-        # we know that the room_results are already limited to `to_token` so no need
-        # for a check on `stream_id` here.
         room_rows = (
             (stream_id, (user_id, room_id, account_data_type))
             for stream_id, user_id, room_id, account_data_type in room_results
+            if stream_id <= to_token
+        )
+
+        tag_rows = (
+            (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+            for stream_id, user_id, room_id in tags
+            if stream_id <= to_token
         )
 
         # We need to return a sorted list, so merge them together.
@@ -582,7 +579,9 @@ class AccountDataStream(Stream):
         # leading to a comparison between the data tuples. The comparison could
         # fail due to attempting to compare the `room_id` which results in a
         # `TypeError` from comparing a `str` vs `None`.
-        updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+        updates = list(
+            heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
+        )
         return updates, to_token, limited