Diffstat (limited to 'synapse/replication')
-rw-r--r--  synapse/replication/tcp/client.py             3
-rw-r--r--  synapse/replication/tcp/handler.py            3
-rw-r--r--  synapse/replication/tcp/streams/__init__.py   3
-rw-r--r--  synapse/replication/tcp/streams/_base.py     49
4 files changed, 26 insertions, 32 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b5e40da533..7263bb2796 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import (
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
     UnPartialStatedEventStream,
@@ -168,7 +167,7 @@ class ReplicationDataHandler:
             self.notifier.on_new_event(
                 StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
             )
-        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+        elif stream_name == AccountDataStream.NAME:
             self.notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
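
With the two streams merged, room tag updates now arrive on AccountDataStream and wake the same ACCOUNT_DATA stream key as other account data. A minimal standalone sketch of that dispatch, using simplified stand-ins rather than Synapse's real ReplicationDataHandler and Notifier:

from dataclasses import dataclass
from typing import List, Optional

ACCOUNT_DATA = "account_data"  # stand-in for AccountDataStream.NAME


@dataclass
class AccountDataRow:
    # Simplified stand-in for the merged stream's row type.
    user_id: str
    room_id: Optional[str]
    data_type: str  # tag updates arrive as ordinary rows with data_type "m.tag"


def on_rdata(stream_name: str, token: int, rows: List[AccountDataRow]) -> None:
    # Single branch now that tag_account_data no longer exists as a stream.
    if stream_name == ACCOUNT_DATA:
        users = [row.user_id for row in rows]
        print(f"notify ACCOUNT_DATA at token {token} for users {users}")


on_rdata(ACCOUNT_DATA, 42, [AccountDataRow("@alice:example.org", "!room:example.org", "m.tag")])
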
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0f166d16aa..d03a53d764 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import (
     PresenceStream,
     ReceiptsStream,
     Stream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -145,7 +144,7 @@ class ReplicationCommandHandler:
 
                 continue
 
-            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
-                # Only add AccountDataStream and TagAccountDataStream as a source on the
-                # instance in charge of account_data persistence.
+            if isinstance(stream, AccountDataStream):
+                # Only add AccountDataStream as a source on the instance in
+                # charge of account_data persistence.
                 if hs.get_instance_name() in hs.config.worker.writers.account_data:
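
The writer check itself is unchanged in spirit: only the worker named in the account_data writers config registers the (now merged) stream as a source. A simplified, self-contained sketch of that rule, with a made-up config shape standing in for hs.config.worker.writers.account_data:

from typing import Dict, List


def streams_to_register(instance_name: str, writers: Dict[str, List[str]]) -> List[str]:
    """Return the stream names this instance should advertise as a source."""
    registered = []
    # Only the instance in charge of account_data persistence registers the
    # merged account data stream (which now also carries room tag updates).
    if instance_name in writers.get("account_data", []):
        registered.append("account_data")
    return registered


print(streams_to_register("worker1", {"account_data": ["worker1"]}))  # ['account_data']
print(streams_to_register("worker2", {"account_data": ["worker1"]}))  # []
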
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 110f10aab9..a7eadfa3c9 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import (
     PushRulesStream,
     ReceiptsStream,
     Stream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
     UserSignatureStream,
@@ -62,7 +61,6 @@ STREAMS_MAP = {
         DeviceListsStream,
         ToDeviceStream,
         FederationStream,
-        TagAccountDataStream,
         AccountDataStream,
         UserSignatureStream,
         UnPartialStatedRoomStream,
@@ -83,7 +81,6 @@ __all__ = [
     "CachesStream",
     "DeviceListsStream",
     "ToDeviceStream",
-    "TagAccountDataStream",
     "AccountDataStream",
     "UserSignatureStream",
     "UnPartialStatedRoomStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..fbf78da9c2 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
 
 import attr
 
+from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -495,27 +495,6 @@ class ToDeviceStream(Stream):
         )
 
 
-class TagAccountDataStream(Stream):
-    """Someone added/removed a tag for a room"""
-
-    @attr.s(slots=True, frozen=True, auto_attribs=True)
-    class TagAccountDataStreamRow:
-        user_id: str
-        room_id: str
-        data: JsonDict
-
-    NAME = "tag_account_data"
-    ROW_TYPE = TagAccountDataStreamRow
-
-    def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
-        super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_max_account_data_stream_id),
-            store.get_all_updated_tags,
-        )
-
-
 class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
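
Consumers that previously matched TagAccountDataStreamRow now see ordinary AccountDataStream rows whose data type is AccountDataTypes.TAG ("m.tag"). A stand-in sketch of the row shape after the merge (not the exact class from _base.py):

from typing import Optional

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class AccountDataStreamRow:
    # Stand-in for the merged stream's row type.
    user_id: str
    room_id: Optional[str]  # None for global account data
    data_type: str


# What used to be TagAccountDataStreamRow(user_id, room_id, data) is now just
# an account data row whose data_type is "m.tag".
row = AccountDataStreamRow("@alice:example.org", "!room:example.org", "m.tag")
print(row)
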
@@ -560,6 +539,19 @@ class AccountDataStream(Stream):
             to_token = room_results[-1][0]
             limited = True
 
+        tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+            instance_name,
+            from_token,
+            to_token,
+            limit,
+        )
+
+        # again, if the tag results hit the limit, limit the global results to
+        # the same stream token.
+        if tags_limited:
+            to_token = tag_to_token
+            limited = True
+
         # convert the global results to the right format, and limit them to the to_token
         # at the same time
         global_rows = (
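
Each of the three queries (global, room, and now tag account data) can hit the limit independently, and whenever one does, to_token is pulled back so no sub-stream advances past data that was not returned. A standalone, slightly generalised sketch of that clamping rule (the commit clamps sequentially, passing the already-clamped to_token into each later query), with made-up query results:

from typing import List, Tuple

# (stream ids returned, last stream id covered, whether the query hit its limit)
QueryResult = Tuple[List[int], int, bool]


def clamp_to_token(requested_to_token: int, results: List[QueryResult]) -> Tuple[int, bool]:
    to_token = requested_to_token
    limited = False
    for _rows, upto_token, hit_limit in results:
        # A limited query only covers stream ids up to upto_token, so shrink
        # to_token to keep all sub-streams consistent with each other.
        if hit_limit:
            to_token = min(to_token, upto_token)
            limited = True
    return to_token, limited


global_res: QueryResult = ([1, 2, 3], 3, False)
room_res: QueryResult = ([1, 2], 2, True)   # hit the limit at stream id 2
tag_res: QueryResult = ([1, 2, 3], 3, False)
print(clamp_to_token(10, [global_res, room_res, tag_res]))  # (2, True)
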
@@ -568,11 +560,16 @@ class AccountDataStream(Stream):
             if stream_id <= to_token
         )
 
-        # we know that the room_results are already limited to `to_token` so no need
-        # for a check on `stream_id` here.
         room_rows = (
             (stream_id, (user_id, room_id, account_data_type))
             for stream_id, user_id, room_id, account_data_type in room_results
+            if stream_id <= to_token
+        )
+
+        tag_rows = (
+            (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+            for stream_id, user_id, room_id in tags
+            if stream_id <= to_token
         )
 
         # We need to return a sorted list, so merge them together.
@@ -582,7 +579,9 @@ class AccountDataStream(Stream):
         # leading to a comparison between the data tuples. The comparison could
         # fail due to attempting to compare the `room_id` which results in a
         # `TypeError` from comparing a `str` vs `None`.
-        updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+        updates = list(
+            heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
+        )
         return updates, to_token, limited
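
The explicit key in heapq.merge matters because rows with equal stream ids would otherwise be tie-broken by comparing the data tuples, whose room_id slot mixes str and None. A runnable illustration with made-up rows:

import heapq

# Made-up rows in the (stream_id, (user_id, room_id, data_type)) shape.
global_rows = [(1, ("@alice:example.org", None, "m.push_rules"))]
room_rows = [(1, ("@alice:example.org", "!room:example.org", "m.fully_read"))]
tag_rows = [(2, ("@alice:example.org", "!room:example.org", "m.tag"))]

# key=lambda row: row[0] keeps the merge from ever comparing the payload
# tuples, where "!room:example.org" vs None would raise a TypeError.
updates = list(heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]))
print(updates)
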