summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/handlers/account_data.py7
-rw-r--r--synapse/handlers/initial_sync.py8
-rw-r--r--synapse/handlers/sync.py11
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/replication/tcp/handler.py3
-rw-r--r--synapse/replication/tcp/streams/__init__.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py49
-rw-r--r--synapse/storage/databases/main/account_data.py6
-rw-r--r--synapse/storage/databases/main/tags.py54
10 files changed, 62 insertions, 83 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6a5e7171da..6432d32d83 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -249,6 +249,7 @@ class RoomEncryptionAlgorithms:
 class AccountDataTypes:
     DIRECT: Final = "m.direct"
     IGNORED_USER_LIST: Final = "m.ignored_user_list"
+    TAG: Final = "m.tag"
 
 
 class HistoryVisibility:
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index aba7315cf7..834006356a 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -16,6 +16,7 @@ import logging
 import random
 from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
 
+from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.account_data import (
     ReplicationAddRoomAccountDataRestServlet,
     ReplicationAddTagRestServlet,
@@ -335,7 +336,11 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
 
         for room_id, room_tags in tags.items():
             results.append(
-                {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+                {
+                    "type": AccountDataTypes.TAG,
+                    "content": {"tags": room_tags},
+                    "room_id": room_id,
+                }
             )
 
         (
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9c335e6863..8c2260ad7d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig
@@ -239,7 +239,7 @@ class InitialSyncHandler:
                 tags = tags_by_room.get(event.room_id)
                 if tags:
                     account_data_events.append(
-                        {"type": "m.tag", "content": {"tags": tags}}
+                        {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
                     )
 
                 account_data = account_data_by_room.get(event.room_id, {})
@@ -326,7 +326,9 @@ class InitialSyncHandler:
         account_data_events = []
         tags = await self.store.get_tags_for_room(user_id, room_id)
         if tags:
-            account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+            account_data_events.append(
+                {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+            )
 
         account_data = await self.store.get_account_data_for_room(user_id, room_id)
         for account_data_type, content in account_data.items():
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 20ee2f203a..78d488f2b1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,7 +31,12 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+    AccountDataTypes,
+    EventContentFields,
+    EventTypes,
+    Membership,
+)
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -2331,7 +2336,9 @@ class SyncHandler:
 
             account_data_events = []
             if tags is not None:
-                account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+                account_data_events.append(
+                    {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+                )
 
             for account_data_type, content in account_data.items():
                 account_data_events.append(
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b5e40da533..7263bb2796 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import (
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
     UnPartialStatedEventStream,
@@ -168,7 +167,7 @@ class ReplicationDataHandler:
             self.notifier.on_new_event(
                 StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
             )
-        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+        elif stream_name in AccountDataStream.NAME:
             self.notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0f166d16aa..d03a53d764 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import (
     PresenceStream,
     ReceiptsStream,
     Stream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -145,7 +144,7 @@ class ReplicationCommandHandler:
 
                 continue
 
-            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+            if isinstance(stream, AccountDataStream):
                 # Only add AccountDataStream and TagAccountDataStream as a source on the
                 # instance in charge of account_data persistence.
                 if hs.get_instance_name() in hs.config.worker.writers.account_data:
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 110f10aab9..a7eadfa3c9 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import (
     PushRulesStream,
     ReceiptsStream,
     Stream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
     UserSignatureStream,
@@ -62,7 +61,6 @@ STREAMS_MAP = {
         DeviceListsStream,
         ToDeviceStream,
         FederationStream,
-        TagAccountDataStream,
         AccountDataStream,
         UserSignatureStream,
         UnPartialStatedRoomStream,
@@ -83,7 +81,6 @@ __all__ = [
     "CachesStream",
     "DeviceListsStream",
     "ToDeviceStream",
-    "TagAccountDataStream",
     "AccountDataStream",
     "UserSignatureStream",
     "UnPartialStatedRoomStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..fbf78da9c2 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
 
 import attr
 
+from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -495,27 +495,6 @@ class ToDeviceStream(Stream):
         )
 
 
-class TagAccountDataStream(Stream):
-    """Someone added/removed a tag for a room"""
-
-    @attr.s(slots=True, frozen=True, auto_attribs=True)
-    class TagAccountDataStreamRow:
-        user_id: str
-        room_id: str
-        data: JsonDict
-
-    NAME = "tag_account_data"
-    ROW_TYPE = TagAccountDataStreamRow
-
-    def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
-        super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_max_account_data_stream_id),
-            store.get_all_updated_tags,
-        )
-
-
 class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
@@ -560,6 +539,19 @@ class AccountDataStream(Stream):
             to_token = room_results[-1][0]
             limited = True
 
+        tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+            instance_name,
+            from_token,
+            to_token,
+            limit,
+        )
+
+        # again, if the tag results hit the limit, limit the global results to
+        # the same stream token.
+        if tags_limited:
+            to_token = tag_to_token
+            limited = True
+
         # convert the global results to the right format, and limit them to the to_token
         # at the same time
         global_rows = (
@@ -568,11 +560,16 @@ class AccountDataStream(Stream):
             if stream_id <= to_token
         )
 
-        # we know that the room_results are already limited to `to_token` so no need
-        # for a check on `stream_id` here.
         room_rows = (
             (stream_id, (user_id, room_id, account_data_type))
             for stream_id, user_id, room_id, account_data_type in room_results
+            if stream_id <= to_token
+        )
+
+        tag_rows = (
+            (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+            for stream_id, user_id, room_id in tags
+            if stream_id <= to_token
         )
 
         # We need to return a sorted list, so merge them together.
@@ -582,7 +579,9 @@ class AccountDataStream(Stream):
         # leading to a comparison between the data tuples. The comparison could
         # fail due to attempting to compare the `room_id` which results in a
         # `TypeError` from comparing a `str` vs `None`.
-        updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+        updates = list(
+            heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
+        )
         return updates, to_token, limited
 
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 86032897f5..881d7089db 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@ from typing import (
 )
 
 from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import (
     DatabasePool,
@@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     def process_replication_position(
         self, stream_name: str, instance_name: str, token: int
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
         super().process_replication_position(stream_name, instance_name, token)
 
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e23c927e02..d5500cdd47 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
 import logging
 from typing import Any, Dict, Iterable, List, Tuple, cast
 
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_all_updated_tags(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
         """Get updates for tags replication stream.
 
         Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The token returned can be used in a subsequent call to this
             function to get further updatees.
 
-            The updates are a list of 2-tuples of stream ID and the row data
+            The updates are a list of tuples of stream ID, user ID and room ID
         """
 
         if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
-        def get_tag_content(
-            txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
-        ) -> List[Tuple[int, Tuple[str, str, str]]]:
-            sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
-            results = []
-            for stream_id, user_id, room_id in tag_ids:
-                txn.execute(sql, (user_id, room_id))
-                tags = []
-                for tag, content in txn:
-                    tags.append(json_encoder.encode(tag) + ":" + content)
-                tag_json = "{" + ",".join(tags) + "}"
-                results.append((stream_id, (user_id, room_id, tag_json)))
-
-            return results
-
-        batch_size = 50
-        results = []
-        for i in range(0, len(tag_ids), batch_size):
-            tags = await self.db_pool.runInteraction(
-                "get_all_updated_tag_content",
-                get_tag_content,
-                tag_ids[i : i + batch_size],
-            )
-            results.extend(tags)
-
         limited = False
         upto_token = current_id
-        if len(results) >= limit:
-            upto_token = results[-1][0]
+        if len(tag_ids) >= limit:
+            upto_token = tag_ids[-1][0]
             limited = True
 
-        return results, upto_token, limited
+        return tag_ids, upto_token, limited
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
@@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+                if row.data_type == AccountDataTypes.TAG:
+                    self.get_tags_for_user.invalidate((row.user_id,))
+                    self._account_data_stream_cache.entity_has_changed(
+                        row.user_id, token
+                    )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def process_replication_position(
-        self, stream_name: str, instance_name: str, token: int
-    ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        super().process_replication_position(stream_name, instance_name, token)
-
 
 class TagsStore(TagsWorkerStore):
     pass