author    Erik Johnston <erik@matrix.org>  2023-01-13 14:57:43 +0000
committer GitHub <noreply@github.com>  2023-01-13 14:57:43 +0000
commit    73ff493dfba63541a09eaf08587eb8bbd3330967 (patch)
tree      97691d93c6e16922f85f1555577ffec2aece2142 /synapse/storage/databases
parent    Update misleading documentation `user_directory.search_all_users` (#14818) (diff)
download  synapse-73ff493dfba63541a09eaf08587eb8bbd3330967.tar.xz
Merge account data streams (#14826)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/account_data.py |  6
-rw-r--r--  synapse/storage/databases/main/tags.py         | 54
2 files changed, 15 insertions(+), 45 deletions(-)
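
This merge routes room tag changes over the general account data stream instead of a dedicated one; the handlers below key off a per-row data_type to tell tags apart. As a reference point, here is a minimal sketch of the row shape those handlers assume (the real row class lives in synapse.replication.tcp.streams and may carry more detail):

    # Illustrative stand-in for the merged stream's row type; the field
    # names match those the diff reads (row.user_id, row.data_type).
    from typing import Optional

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class AccountDataStreamRow:
        user_id: str
        room_id: Optional[str]  # None for global (non-room) account data
        data_type: str  # e.g. "m.tag" (AccountDataTypes.TAG) for room tags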
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 86032897f5..881d7089db 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@ from typing import (
 )
 
 from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import (
     DatabasePool,
@@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     def process_replication_position(
         self, stream_name: str, instance_name: str, token: int
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
         super().process_replication_position(stream_name, instance_name, token)
 
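Because TagsWorkerStore inherits from AccountDataWorkerStore, this single branch now covers tag traffic too, which is why the tags-side override of process_replication_position is deleted further down in this diff. A toy sketch of why the override becomes redundant (class and stream names here are illustrative only):

    # Toy model: once both update types share one stream name, the parent
    # handler is sufficient and the child override can be dropped.
    class AccountDataWorker:
        def process_replication_position(self, stream_name: str, token: int) -> None:
            if stream_name == "account_data":
                print(f"advancing account data id gen to {token}")

    class TagsWorker(AccountDataWorker):
        pass  # no override: tag rows arrive under "account_data" as well

    TagsWorker().process_replication_position("account_data", 42)
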
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e23c927e02..d5500cdd47 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
 import logging
 from typing import Any, Dict, Iterable, List, Tuple, cast
 
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_all_updated_tags(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
         """Get updates for tags replication stream.
 
         Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The token returned can be used in a subsequent call to this
             function to get further updates.
 
-            The updates are a list of 2-tuples of stream ID and the row data
+            The updates are a list of tuples of stream ID, user ID and room ID
         """
 
         if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
-        def get_tag_content(
-            txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
-        ) -> List[Tuple[int, Tuple[str, str, str]]]:
-            sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
-            results = []
-            for stream_id, user_id, room_id in tag_ids:
-                txn.execute(sql, (user_id, room_id))
-                tags = []
-                for tag, content in txn:
-                    tags.append(json_encoder.encode(tag) + ":" + content)
-                tag_json = "{" + ",".join(tags) + "}"
-                results.append((stream_id, (user_id, room_id, tag_json)))
-
-            return results
-
-        batch_size = 50
-        results = []
-        for i in range(0, len(tag_ids), batch_size):
-            tags = await self.db_pool.runInteraction(
-                "get_all_updated_tag_content",
-                get_tag_content,
-                tag_ids[i : i + batch_size],
-            )
-            results.extend(tags)
-
         limited = False
         upto_token = current_id
-        if len(results) >= limit:
-            upto_token = results[-1][0]
+        if len(tag_ids) >= limit:
+            upto_token = tag_ids[-1][0]
             limited = True
 
-        return results, upto_token, limited
+        return tag_ids, upto_token, limited
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
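After this change the method returns plain (stream_id, user_id, room_id) triples and no longer inlines the serialised tag JSON; consumers fetch tag contents separately when they need them. A hypothetical paging loop over the new shape (the store object and token values are assumptions for illustration):

    # Hypothetical consumer of the slimmed-down return value.
    async def drain_tag_updates(store, instance_name: str, last_id: int, current_id: int) -> None:
        while True:
            updates, upto_token, limited = await store.get_all_updated_tags(
                instance_name, last_id, current_id, limit=100
            )
            for stream_id, user_id, room_id in updates:
                # The row no longer carries the tag JSON; look it up on demand.
                print(stream_id, user_id, room_id)
            if not limited:
                break
            last_id = upto_token  # resume from where the last page stopped
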
@@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+                if row.data_type == AccountDataTypes.TAG:
+                    self.get_tags_for_user.invalidate((row.user_id,))
+                    self._account_data_stream_cache.entity_has_changed(
+                        row.user_id, token
+                    )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def process_replication_position(
-        self, stream_name: str, instance_name: str, token: int
-    ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        super().process_replication_position(stream_name, instance_name, token)
-
 
 class TagsStore(TagsWorkerStore):
     pass
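
The pattern above, matching the shared stream and then filtering on data_type, is what any cache-invalidation hook ported from the old TagAccountDataStream now needs. A minimal free-standing sketch of that filter (store is a placeholder for an object with the cache attributes used in the diff; only the constants and field names from the diff are assumed real):

    # Sketch of picking tag rows out of the merged stream.
    from synapse.api.constants import AccountDataTypes
    from synapse.replication.tcp.streams import AccountDataStream

    def on_replication_rows(store, stream_name: str, token: int, rows) -> None:
        if stream_name == AccountDataStream.NAME:
            for row in rows:
                if row.data_type == AccountDataTypes.TAG:
                    # Tag rows are the only ones that touch the tag cache.
                    store.get_tags_for_user.invalidate((row.user_id,))
                    store._account_data_stream_cache.entity_has_changed(row.user_id, token)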