diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/api/constants.py | 1 | ||||
-rw-r--r-- | synapse/handlers/account_data.py | 7 | ||||
-rw-r--r-- | synapse/handlers/initial_sync.py | 8 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 11 | ||||
-rw-r--r-- | synapse/replication/tcp/client.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/__init__.py | 3 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 49 | ||||
-rw-r--r-- | synapse/storage/databases/main/account_data.py | 6 | ||||
-rw-r--r-- | synapse/storage/databases/main/tags.py | 54 |
10 files changed, 62 insertions, 83 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6a5e7171da..6432d32d83 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -249,6 +249,7 @@ class RoomEncryptionAlgorithms: class AccountDataTypes: DIRECT: Final = "m.direct" IGNORED_USER_LIST: Final = "m.ignored_user_list" + TAG: Final = "m.tag" class HistoryVisibility: diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index aba7315cf7..834006356a 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -16,6 +16,7 @@ import logging import random from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( ReplicationAddRoomAccountDataRestServlet, ReplicationAddTagRestServlet, @@ -335,7 +336,11 @@ class AccountDataEventSource(EventSource[int, JsonDict]): for room_id, room_tags in tags.items(): results.append( - {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + { + "type": AccountDataTypes.TAG, + "content": {"tags": room_tags}, + "room_id": room_id, + } ) ( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9c335e6863..8c2260ad7d 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -239,7 +239,7 @@ class InitialSyncHandler: tags = tags_by_room.get(event.room_id) if tags: account_data_events.append( - {"type": "m.tag", "content": {"tags": tags}} + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} ) account_data = account_data_by_room.get(event.room_id, {}) @@ -326,7 +326,9 @@ class InitialSyncHandler: account_data_events = [] tags = await self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) + account_data_events.append( + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} + ) account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 20ee2f203a..78d488f2b1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,7 +31,12 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -2331,7 +2336,9 @@ class SyncHandler: account_data_events = [] if tags is not None: - account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) + account_data_events.append( + {"type": AccountDataTypes.TAG, "content": {"tags": tags}} + ) for account_data_type, content in account_data.items(): account_data_events.append( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b5e40da533..7263bb2796 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import ( PushersStream, PushRulesStream, ReceiptsStream, - TagAccountDataStream, ToDeviceStream, TypingStream, UnPartialStatedEventStream, @@ -168,7 +167,7 @@ class ReplicationDataHandler: self.notifier.on_new_event( StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows] ) - elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): + elif stream_name in AccountDataStream.NAME: self.notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0f166d16aa..d03a53d764 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import ( PresenceStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -145,7 +144,7 @@ class ReplicationCommandHandler: continue - if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + if isinstance(stream, AccountDataStream): # Only add AccountDataStream and TagAccountDataStream as a source on the # instance in charge of account_data persistence. if hs.get_instance_name() in hs.config.worker.writers.account_data: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 110f10aab9..a7eadfa3c9 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import ( PushRulesStream, ReceiptsStream, Stream, - TagAccountDataStream, ToDeviceStream, TypingStream, UserSignatureStream, @@ -62,7 +61,6 @@ STREAMS_MAP = { DeviceListsStream, ToDeviceStream, FederationStream, - TagAccountDataStream, AccountDataStream, UserSignatureStream, UnPartialStatedRoomStream, @@ -83,7 +81,6 @@ __all__ = [ "CachesStream", "DeviceListsStream", "ToDeviceStream", - "TagAccountDataStream", "AccountDataStream", "UserSignatureStream", "UnPartialStatedRoomStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index e01155ad59..fbf78da9c2 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,8 +28,8 @@ from typing import ( import attr +from synapse.api.constants import AccountDataTypes from synapse.replication.http.streams import ReplicationGetStreamUpdates -from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -495,27 +495,6 @@ class ToDeviceStream(Stream): ) -class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room""" - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class TagAccountDataStreamRow: - user_id: str - room_id: str - data: JsonDict - - NAME = "tag_account_data" - ROW_TYPE = TagAccountDataStreamRow - - def __init__(self, hs: "HomeServer"): - store = hs.get_datastores().main - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_max_account_data_stream_id), - store.get_all_updated_tags, - ) - - class AccountDataStream(Stream): """Global or per room account data was changed""" @@ -560,6 +539,19 @@ class AccountDataStream(Stream): to_token = room_results[-1][0] limited = True + tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags( + instance_name, + from_token, + to_token, + limit, + ) + + # again, if the tag results hit the limit, limit the global results to + # the same stream token. + if tags_limited: + to_token = tag_to_token + limited = True + # convert the global results to the right format, and limit them to the to_token # at the same time global_rows = ( @@ -568,11 +560,16 @@ class AccountDataStream(Stream): if stream_id <= to_token ) - # we know that the room_results are already limited to `to_token` so no need - # for a check on `stream_id` here. room_rows = ( (stream_id, (user_id, room_id, account_data_type)) for stream_id, user_id, room_id, account_data_type in room_results + if stream_id <= to_token + ) + + tag_rows = ( + (stream_id, (user_id, room_id, AccountDataTypes.TAG)) + for stream_id, user_id, room_id in tags + if stream_id <= to_token ) # We need to return a sorted list, so merge them together. @@ -582,7 +579,9 @@ class AccountDataStream(Stream): # leading to a comparison between the data tuples. The comparison could # fail due to attempting to compare the `room_id` which results in a # `TypeError` from comparing a `str` vs `None`. - updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + updates = list( + heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) + ) return updates, to_token, limited diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 86032897f5..881d7089db 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,7 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( DatabasePool, @@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def process_replication_position( self, stream_name: str, instance_name: str, token: int ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index e23c927e02..d5500cdd47 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -17,7 +17,8 @@ import logging from typing import Any, Dict, Iterable, List, Tuple, cast -from synapse.replication.tcp.streams import TagAccountDataStream +from synapse.api.constants import AccountDataTypes +from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore @@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: + ) -> Tuple[List[Tuple[int, str, str]], int, bool]: """Get updates for tags replication stream. Args: @@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore): The token returned can be used in a subsequent call to this function to get further updatees. - The updates are a list of 2-tuples of stream ID and the row data + The updates are a list of tuples of stream ID, user ID and room ID """ if last_id == current_id: @@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore): "get_all_updated_tags", get_all_updated_tags_txn ) - def get_tag_content( - txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]] - ) -> List[Tuple[int, Tuple[str, str, str]]]: - sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" - results = [] - for stream_id, user_id, room_id in tag_ids: - txn.execute(sql, (user_id, room_id)) - tags = [] - for tag, content in txn: - tags.append(json_encoder.encode(tag) + ":" + content) - tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, (user_id, room_id, tag_json))) - - return results - - batch_size = 50 - results = [] - for i in range(0, len(tag_ids), batch_size): - tags = await self.db_pool.runInteraction( - "get_all_updated_tag_content", - get_tag_content, - tag_ids[i : i + batch_size], - ) - results.extend(tags) - limited = False upto_token = current_id - if len(results) >= limit: - upto_token = results[-1][0] + if len(tag_ids) >= limit: + upto_token = tag_ids[-1][0] limited = True - return results, upto_token, limited + return tag_ids, upto_token, limited async def get_updated_tags( self, user_id: str, stream_id: int @@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: + if stream_name == AccountDataStream.NAME: for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) + if row.data_type == AccountDataTypes.TAG: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed( + row.user_id, token + ) super().process_replication_rows(stream_name, instance_name, token, rows) - def process_replication_position( - self, stream_name: str, instance_name: str, token: int - ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - super().process_replication_position(stream_name, instance_name, token) - class TagsStore(TagsWorkerStore): pass |