diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 86032897f5..881d7089db 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@ from typing import (
)
from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- elif stream_name == AccountDataStream.NAME:
+ if stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e23c927e02..d5500cdd47 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
"""Get updates for tags replication stream.
Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The token returned can be used in a subsequent call to this
function to get further updatees.
- The updates are a list of 2-tuples of stream ID and the row data
+ The updates are a list of tuples of stream ID, user ID and room ID
"""
if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
"get_all_updated_tags", get_all_updated_tags_txn
)
- def get_tag_content(
- txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
- ) -> List[Tuple[int, Tuple[str, str, str]]]:
- sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
- results = []
- for stream_id, user_id, room_id in tag_ids:
- txn.execute(sql, (user_id, room_id))
- tags = []
- for tag, content in txn:
- tags.append(json_encoder.encode(tag) + ":" + content)
- tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, (user_id, room_id, tag_json)))
-
- return results
-
- batch_size = 50
- results = []
- for i in range(0, len(tag_ids), batch_size):
- tags = await self.db_pool.runInteraction(
- "get_all_updated_tag_content",
- get_tag_content,
- tag_ids[i : i + batch_size],
- )
- results.extend(tags)
-
limited = False
upto_token = current_id
- if len(results) >= limit:
- upto_token = results[-1][0]
+ if len(tag_ids) >= limit:
+ upto_token = tag_ids[-1][0]
limited = True
- return results, upto_token, limited
+ return tag_ids, upto_token, limited
async def get_updated_tags(
self, user_id: str, stream_id: int
@@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
+ if stream_name == AccountDataStream.NAME:
for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ if row.data_type == AccountDataTypes.TAG:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
super().process_replication_rows(stream_name, instance_name, token, rows)
- def process_replication_position(
- self, stream_name: str, instance_name: str, token: int
- ) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- super().process_replication_position(stream_name, instance_name, token)
-
class TagsStore(TagsWorkerStore):
pass
|