From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/cache.py | 307 ++++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 synapse/storage/databases/main/cache.py (limited to 'synapse/storage/databases/main/cache.py') diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py new file mode 100644 index 0000000000..683afde52b --- /dev/null +++ b/synapse/storage/databases/main/cache.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging +from typing import Any, Iterable, List, Optional, Tuple + +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams import BackfillStream, CachesStream +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == EventsStream.NAME: + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == BackfillStream.NAME: + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) + + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_message_count_for_user.invalidate_many((room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db_pool.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_all_cache_and_stream(self, txn, cache_func): + """Invalidates the entire cache and adds it to the cache stream so slaves + will know to invalidate their caches. + """ + + txn.call_after(cache_func.invalidate_all) + self._send_invalidation_to_replication(txn, cache_func.__name__, None) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication( + self, txn, cache_name: str, keys: Optional[Iterable[Any]] + ): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name + keys: Entry to invalidate. If None will invalidate all. + """ + + if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic current state cache" + ) + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + stream_id = self._cache_id_gen.get_next_txn(txn) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + if keys is not None: + keys = list(keys) + + self.db_pool.simple_insert_txn( + txn, + table="cache_invalidation_stream_by_instance", + values={ + "stream_id": stream_id, + "instance_name": self._instance_name, + "cache_func": cache_name, + "keys": keys, + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_cache_stream_token(self, instance_name): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token(instance_name) + else: + return 0 -- cgit 1.5.1 From 2ffd6783c7af12e3c29e1a44dee4a9deeb83890b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 6 Aug 2020 17:15:35 +0100 Subject: Revert #7736 (#8039) --- changelog.d/7736.feature | 1 - changelog.d/8039.misc | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 - synapse/push/push_tools.py | 17 ++- synapse/rest/client/v2_alpha/sync.py | 1 - synapse/storage/databases/main/cache.py | 1 - synapse/storage/databases/main/events.py | 48 +------ synapse/storage/databases/main/events_worker.py | 86 +---------- .../main/schema/delta/58/12unread_messages.sql | 18 --- tests/rest/client/v1/utils.py | 20 --- tests/rest/client/v2_alpha/test_sync.py | 157 +-------------------- 12 files changed, 19 insertions(+), 339 deletions(-) delete mode 100644 changelog.d/7736.feature create mode 100644 changelog.d/8039.misc delete mode 100644 synapse/storage/databases/main/schema/delta/58/12unread_messages.sql (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature deleted file mode 100644 index feb02be234..0000000000 --- a/changelog.d/7736.feature +++ /dev/null @@ -1 +0,0 @@ -Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). diff --git a/changelog.d/8039.misc b/changelog.d/8039.misc new file mode 100644 index 0000000000..599933c80e --- /dev/null +++ b/changelog.d/8039.misc @@ -0,0 +1 @@ +Revert MSC2654 implementation because of perf issues. Please delete this line when processing the 1.19 changelog. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index ae5e1810fc..a34bdf1830 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -67,7 +67,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url", "count_as_unread"], + "events": ["processed", "outlier", "contains_url"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5a19bac929..c42dac18f5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,7 +103,6 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1887,10 +1886,6 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] - - unread_count = await self.store.get_unread_message_count_for_user( - room_id, sync_config.user.to_string(), - ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1899,7 +1894,6 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, - unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index bc8f71916b..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,13 +21,22 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + badge = len(invites) for room_id in joins: - unread_count = await store.get_unread_message_count_for_user(room_id, user_id) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if unread_count else 0 + if room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[room_id] + + notifs = await ( + store.get_unread_event_push_actions_by_room_for_user( + room_id, user_id, last_unread_event_id + ) + ) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 3f5bf75e59..a5c24fbd63 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,7 +426,6 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 683afde52b..10de446065 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4d8a24ce4b..1a68bf32cb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -53,47 +53,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - def encode_json(json_object): """ @@ -239,10 +198,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -864,9 +819,8 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), - "count_as_unread": should_count_as_unread(event, context), } - for event, context in events_and_contexts + for event, _ in events_and_contexts ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7b7393f6e..755b7a2a85 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db_pool.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? - AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c73..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. -ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 51941f99f9..8933b560d2 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -165,26 +165,6 @@ class RestHelper(object): return channel.json_body - def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") - ) - render(request, self.resource, self.hs.get_reactor()) - - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) - ) - - return channel.json_body - def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import read_marker, sync +from synapse.rest.client.v2_alpha import sync from tests import unittest from tests.server import TimedOutException @@ -324,156 +324,3 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: True): - """Syncs and compares the unread count with the expected value.""" - - request, channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, - ) - self.render(request) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1 From 76d21d14a042756b0c8a8f520dfd9ea09cf092c7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Aug 2020 10:39:31 +0100 Subject: Separate `get_current_token` into two. (#8113) The function is used for two purposes: 1) for subscribers of streams to get a token they can use to get further updates with, and 2) for replication to track position of the writers of the stream. For streams with a single writer the two scenarios produce the same result, however the situation becomes complicated for streams with multiple writers. The current `MultiWriterIdGenerator` does not correctly handle the first case (which is not an issue as its only used for the `caches` stream which nothing subscribes to outside of replication). --- changelog.d/8113.misc | 1 + .../slave/storage/_slaved_id_tracker.py | 8 +++++ synapse/replication/tcp/streams/_base.py | 2 +- synapse/storage/databases/main/cache.py | 4 +-- synapse/storage/util/id_generators.py | 36 ++++++++++++++++------ tests/storage/test_id_generators.py | 16 +++++----- 6 files changed, 47 insertions(+), 20 deletions(-) create mode 100644 changelog.d/8113.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/8113.misc b/changelog.d/8113.misc new file mode 100644 index 0000000000..00bec4f8ef --- /dev/null +++ b/changelog.d/8113.misc @@ -0,0 +1 @@ +Separate `get_current_token` into two since there are two different use cases for it. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 9d1d173b2f..d43eaf3a29 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 7a42de3f7d..1e92d52165 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -405,7 +405,7 @@ class CachesStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, + store.get_cache_stream_token_for_writer, store.get_all_updated_caches, ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 10de446065..1e7637a6f5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): }, ) - def get_cache_stream_token(self, instance_name): + def get_cache_stream_token_for_writer(self, instance_name: str) -> int: if self._cache_id_gen: - return self._cache_id_gen.get_current_token(instance_name) + return self._cache_id_gen.get_current_token_for_writer(instance_name) else: return 0 diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index e2ddd01290..8276a755e5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -158,6 +158,14 @@ class StreamIdGenerator(object): return self._current + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() + class ChainedIdGenerator(object): """Used to generate new stream ids where the stream must be kept in sync @@ -216,6 +224,14 @@ class ChainedIdGenerator(object): "Attempted to advance token on source for table %r", self._table ) + def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() + class MultiWriterIdGenerator: """An ID generator that tracks a stream that can have multiple writers. @@ -298,7 +314,7 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token() < next_id + assert self.get_current_token_for_writer(self._instance_name) < next_id with self._lock: self._unfinished_ids.add(next_id) @@ -344,16 +360,18 @@ class MultiWriterIdGenerator: curr = self._current_positions.get(self._instance_name, 0) self._current_positions[self._instance_name] = max(curr, next_id) - def get_current_token(self, instance_name: str = None) -> int: - """Gets the current position of a named writer (defaults to current - instance). - - Returns 0 if we don't have a position for the named writer (likely due - to it being a new writer). + def get_current_token(self) -> int: + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. """ - if instance_name is None: - instance_name = self._instance_name + # Currently we don't support this operation, as it's not obvious how to + # condense the stream positions of multiple writers into a single int. + raise NotImplementedError() + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + """ with self._lock: return self._current_positions.get(instance_name, 0) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index e845410dae..7a05194653 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_multi_instance(self): """Test that reads and writes from multiple processes are handled @@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen = self._create_id_generator("second") self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token("first"), 3) - self.assertEqual(first_id_gen.get_current_token("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) -- cgit 1.5.1