diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6a5e7171da..6432d32d83 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -249,6 +249,7 @@ class RoomEncryptionAlgorithms:
class AccountDataTypes:
DIRECT: Final = "m.direct"
IGNORED_USER_LIST: Final = "m.ignored_user_list"
+ TAG: Final = "m.tag"
class HistoryVisibility:
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index aba7315cf7..834006356a 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -16,6 +16,7 @@ import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
+from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
ReplicationAddRoomAccountDataRestServlet,
ReplicationAddTagRestServlet,
@@ -335,7 +336,11 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
for room_id, room_tags in tags.items():
results.append(
- {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+ {
+ "type": AccountDataTypes.TAG,
+ "content": {"tags": room_tags},
+ "room_id": room_id,
+ }
)
(
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9c335e6863..8c2260ad7d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
@@ -239,7 +239,7 @@ class InitialSyncHandler:
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append(
- {"type": "m.tag", "content": {"tags": tags}}
+ {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
)
account_data = account_data_by_room.get(event.room_id, {})
@@ -326,7 +326,9 @@ class InitialSyncHandler:
account_data_events = []
tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
- account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+ account_data_events.append(
+ {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+ )
account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 20ee2f203a..78d488f2b1 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,7 +31,12 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ EventContentFields,
+ EventTypes,
+ Membership,
+)
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -2331,7 +2336,9 @@ class SyncHandler:
account_data_events = []
if tags is not None:
- account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+ account_data_events.append(
+ {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+ )
for account_data_type, content in account_data.items():
account_data_events.append(
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b5e40da533..7263bb2796 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import (
PushersStream,
PushRulesStream,
ReceiptsStream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
UnPartialStatedEventStream,
@@ -168,7 +167,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event(
StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
)
- elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+        elif stream_name == AccountDataStream.NAME:
self.notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0f166d16aa..d03a53d764 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import (
PresenceStream,
ReceiptsStream,
Stream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
@@ -145,7 +144,7 @@ class ReplicationCommandHandler:
continue
- if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+ if isinstance(stream, AccountDataStream):
-            # Only add AccountDataStream and TagAccountDataStream as a source on the
+            # Only add AccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 110f10aab9..a7eadfa3c9 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import (
PushRulesStream,
ReceiptsStream,
Stream,
- TagAccountDataStream,
ToDeviceStream,
TypingStream,
UserSignatureStream,
@@ -62,7 +61,6 @@ STREAMS_MAP = {
DeviceListsStream,
ToDeviceStream,
FederationStream,
- TagAccountDataStream,
AccountDataStream,
UserSignatureStream,
UnPartialStatedRoomStream,
@@ -83,7 +81,6 @@ __all__ = [
"CachesStream",
"DeviceListsStream",
"ToDeviceStream",
- "TagAccountDataStream",
"AccountDataStream",
"UserSignatureStream",
"UnPartialStatedRoomStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index e01155ad59..fbf78da9c2 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -28,8 +28,8 @@ from typing import (
import attr
+from synapse.api.constants import AccountDataTypes
from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -495,27 +495,6 @@ class ToDeviceStream(Stream):
)
-class TagAccountDataStream(Stream):
- """Someone added/removed a tag for a room"""
-
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class TagAccountDataStreamRow:
- user_id: str
- room_id: str
- data: JsonDict
-
- NAME = "tag_account_data"
- ROW_TYPE = TagAccountDataStreamRow
-
- def __init__(self, hs: "HomeServer"):
- store = hs.get_datastores().main
- super().__init__(
- hs.get_instance_name(),
- current_token_without_instance(store.get_max_account_data_stream_id),
- store.get_all_updated_tags,
- )
-
-
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
@@ -560,6 +539,19 @@ class AccountDataStream(Stream):
to_token = room_results[-1][0]
limited = True
+ tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+ instance_name,
+ from_token,
+ to_token,
+ limit,
+ )
+
+ # again, if the tag results hit the limit, limit the global results to
+ # the same stream token.
+ if tags_limited:
+ to_token = tag_to_token
+ limited = True
+
# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
@@ -568,11 +560,16 @@ class AccountDataStream(Stream):
if stream_id <= to_token
)
- # we know that the room_results are already limited to `to_token` so no need
- # for a check on `stream_id` here.
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
+ if stream_id <= to_token
+ )
+
+ tag_rows = (
+ (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+ for stream_id, user_id, room_id in tags
+ if stream_id <= to_token
)
# We need to return a sorted list, so merge them together.
@@ -582,7 +579,9 @@ class AccountDataStream(Stream):
# leading to a comparison between the data tuples. The comparison could
# fail due to attempting to compare the `room_id` which results in a
# `TypeError` from comparing a `str` vs `None`.
- updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+ updates = list(
+ heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
+ )
return updates, to_token, limited
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 86032897f5..881d7089db 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -27,7 +27,7 @@ from typing import (
)
from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- elif stream_name == AccountDataStream.NAME:
+ if stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e23c927e02..d5500cdd47 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,7 +17,8 @@
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
"""Get updates for tags replication stream.
Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The token returned can be used in a subsequent call to this
function to get further updatees.
- The updates are a list of 2-tuples of stream ID and the row data
+ The updates are a list of tuples of stream ID, user ID and room ID
"""
if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
"get_all_updated_tags", get_all_updated_tags_txn
)
- def get_tag_content(
- txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
- ) -> List[Tuple[int, Tuple[str, str, str]]]:
- sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
- results = []
- for stream_id, user_id, room_id in tag_ids:
- txn.execute(sql, (user_id, room_id))
- tags = []
- for tag, content in txn:
- tags.append(json_encoder.encode(tag) + ":" + content)
- tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, (user_id, room_id, tag_json)))
-
- return results
-
- batch_size = 50
- results = []
- for i in range(0, len(tag_ids), batch_size):
- tags = await self.db_pool.runInteraction(
- "get_all_updated_tag_content",
- get_tag_content,
- tag_ids[i : i + batch_size],
- )
- results.extend(tags)
-
limited = False
upto_token = current_id
- if len(results) >= limit:
- upto_token = results[-1][0]
+ if len(tag_ids) >= limit:
+ upto_token = tag_ids[-1][0]
limited = True
- return results, upto_token, limited
+ return tag_ids, upto_token, limited
async def get_updated_tags(
self, user_id: str, stream_id: int
@@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
+ if stream_name == AccountDataStream.NAME:
for row in rows:
- self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+ if row.data_type == AccountDataTypes.TAG:
+ self.get_tags_for_user.invalidate((row.user_id,))
+ self._account_data_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
super().process_replication_rows(stream_name, instance_name, token, rows)
- def process_replication_position(
- self, stream_name: str, instance_name: str, token: int
- ) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- super().process_replication_position(stream_name, instance_name, token)
-
class TagsStore(TagsWorkerStore):
pass
|