From 41320a0554716aaf7cec6172da98e002c48344c5 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 4 Aug 2022 15:49:55 +0100 Subject: Optimise async get event lookups (#13435) Still maintains local in memory lookup optimisation, but does any external lookup as part of the deferred that prevents duplicate lookups for the same event at once. This makes the assumption that fetching from an external cache is a non-zero load operation. --- synapse/storage/databases/main/events_worker.py | 75 ++++++++++++++++++++++--- synapse/storage/databases/main/roommember.py | 2 +- 2 files changed, 69 insertions(+), 8 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 29c99c6357..e9ff6cfb34 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -600,7 +600,11 @@ class EventsWorkerStore(SQLBaseStore): Returns: map from event id to result """ - event_entry_map = await self._get_events_from_cache( + # Shortcut: check if we have any events in the *in memory* cache - this function + # may be called repeatedly for the same event so at this point we cannot reach + # out to any external cache for performance reasons. The external cache is + # checked later on in the `get_missing_events_from_cache_or_db` function below. + event_entry_map = self._get_events_from_local_cache( event_ids, ) @@ -632,7 +636,9 @@ class EventsWorkerStore(SQLBaseStore): if missing_events_ids: - async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]: + async def get_missing_events_from_cache_or_db() -> Dict[ + str, EventCacheEntry + ]: """Fetches the events in `missing_event_ids` from the database. Also creates entries in `self._current_event_fetches` to allow @@ -657,10 +663,18 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event # out of the database to check it. # + missing_events = {} try: - missing_events = await self._get_events_from_db( + # Try to fetch from any external cache. We already checked the + # in-memory cache above. + missing_events = await self._get_events_from_external_cache( missing_events_ids, ) + # Now actually fetch any remaining events from the DB + db_missing_events = await self._get_events_from_db( + missing_events_ids - missing_events.keys(), + ) + missing_events.update(db_missing_events) except Exception as e: with PreserveLoggingContext(): fetching_deferred.errback(e) @@ -679,7 +693,7 @@ class EventsWorkerStore(SQLBaseStore): # cancellations, since multiple `_get_events_from_cache_or_db` calls can # reuse the same fetch. missing_events: Dict[str, EventCacheEntry] = await delay_cancellation( - get_missing_events_from_db() + get_missing_events_from_cache_or_db() ) event_entry_map.update(missing_events) @@ -754,7 +768,54 @@ class EventsWorkerStore(SQLBaseStore): async def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: - """Fetch events from the caches. + """Fetch events from the caches, both in memory and any external. + + May return rejected events. + + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics + """ + event_map = self._get_events_from_local_cache( + events, update_metrics=update_metrics + ) + + missing_event_ids = (e for e in events if e not in event_map) + event_map.update( + await self._get_events_from_external_cache( + events=missing_event_ids, + update_metrics=update_metrics, + ) + ) + + return event_map + + async def _get_events_from_external_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, EventCacheEntry]: + """Fetch events from any configured external cache. + + May return rejected events. + + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics + """ + event_map = {} + + for event_id in events: + ret = await self._get_event_cache.get_external( + (event_id,), None, update_metrics=update_metrics + ) + if ret: + event_map[event_id] = ret + + return event_map + + def _get_events_from_local_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, EventCacheEntry]: + """Fetch events from the local, in memory, caches. May return rejected events. @@ -766,7 +827,7 @@ class EventsWorkerStore(SQLBaseStore): for event_id in events: # First check if it's in the event cache - ret = await self._get_event_cache.get( + ret = self._get_event_cache.get_local( (event_id,), None, update_metrics=update_metrics ) if ret: @@ -788,7 +849,7 @@ class EventsWorkerStore(SQLBaseStore): # We add the entry back into the cache as we want to keep # recently queried events in the cache. - await self._get_event_cache.set((event_id,), cache_entry) + self._get_event_cache.set_local((event_id,), cache_entry) return event_map diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e2cccc688c..93ff4816c8 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -896,7 +896,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = await self._get_events_from_cache( + event_map = self._get_events_from_local_cache( member_event_ids, update_metrics=False ) -- cgit 1.5.1 From 96d92156d0f820224f68092e72d6089dceef715a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 4 Aug 2022 17:45:01 +0100 Subject: Update type of `EventContext.rejected` (#13460) --- changelog.d/13460.misc | 1 + synapse/events/snapshot.py | 7 +++---- synapse/storage/databases/main/events.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 changelog.d/13460.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/13460.misc b/changelog.d/13460.misc new file mode 100644 index 0000000000..f9e9de219d --- /dev/null +++ b/changelog.d/13460.misc @@ -0,0 +1 @@ +Update type of `EventContext.rejected`. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index b700cbbfa1..d3c8083e4a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple import attr from frozendict import frozendict -from typing_extensions import Literal from synapse.appservice import ApplicationService from synapse.events import EventBase @@ -33,7 +32,7 @@ class EventContext: Holds information relevant to persisting an event Attributes: - rejected: A rejection reason if the event was rejected, else False + rejected: A rejection reason if the event was rejected, else None _state_group: The ID of the state group for this event. Note that state events are persisted with a state group which includes the new event, so this is @@ -85,7 +84,7 @@ class EventContext: """ _storage: "StorageControllers" - rejected: Union[Literal[False], str] = False + rejected: Optional[str] = None _state_group: Optional[int] = None state_group_before_event: Optional[int] = None _state_delta_due_to_event: Optional[StateMap[str]] = None diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f600f1190..5560b38a48 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1490,7 +1490,7 @@ class PersistEventsStore: event.sender, "url" in event.content and isinstance(event.content["url"], str), event.get_state_key(), - context.rejected or None, + context.rejected, ) for event, context in events_and_contexts ), -- cgit 1.5.1 From ec24813220f9d54108924dc04aecd24555277b99 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 4 Aug 2022 15:24:44 -0400 Subject: Improve comments (& avoid a duplicate query) in push actions processing. (#13455) * Adds docstrings and inline comments. * Formats SQL queries using triple quoted strings. * Minor formatting changes. * Avoid fetching `event_push_summary_stream_ordering` multiple times in the same transactions. --- changelog.d/13455.misc | 1 + .../storage/databases/main/event_push_actions.py | 282 ++++++++++++--------- 2 files changed, 159 insertions(+), 124 deletions(-) create mode 100644 changelog.d/13455.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/13455.misc b/changelog.d/13455.misc new file mode 100644 index 0000000000..17462c56f3 --- /dev/null +++ b/changelog.d/13455.misc @@ -0,0 +1 @@ +Add some comments about how event push actions are stored. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index dd2627037c..5ddddb1cf3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -265,7 +265,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas counts.notify_count += row[1] counts.unread_count += row[2] - # Next we need to count highlights, which aren't summarized + # Next we need to count highlights, which aren't summarised sql = """ SELECT COUNT(*) FROM event_push_actions WHERE user_id = ? @@ -280,7 +280,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Finally we need to count push actions that aren't included in the # summary returned above, e.g. recent events that haven't been - # summarized yet, or the summary is empty due to a recent read receipt. + # summarised yet, or the summary is empty due to a recent read receipt. stream_ordering = max(stream_ordering, summary_stream_ordering) notify_count, unread_count = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering @@ -304,6 +304,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas Does not consult `event_push_summary` table, which may include push actions that have been deleted from `event_push_actions` table. + + Args: + txn: The database transaction. + room_id: The room ID to get unread counts for. + user_id: The user ID to get unread counts for. + stream_ordering: The (exclusive) minimum stream ordering to consider. + max_stream_ordering: The (inclusive) maximum stream ordering to consider. + If this is not given, then no maximum is applied. + + Return: + A tuple of the notif count and unread count in the given range. """ # If there have been no events in the room since the stream ordering, @@ -383,27 +394,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) -> List[Tuple[str, str, int, str, bool]]: # find rooms that have a read receipt in them and return the next # push actions - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, + ep.highlight + FROM ( + SELECT room_id, + MAX(stream_ordering) as stream_ordering + FROM events + INNER JOIN receipts_linearized USING (room_id, event_id) + WHERE receipt_type = 'm.read' AND user_id = ? + GROUP BY room_id + ) AS rl, + event_push_actions AS ep + WHERE + ep.room_id = rl.room_id + AND ep.stream_ordering > rl.stream_ordering + AND ep.user_id = ? + AND ep.stream_ordering > ? + AND ep.stream_ordering <= ? + AND ep.notif = 1 + ORDER BY ep.stream_ordering ASC LIMIT ? + """ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) @@ -418,23 +429,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_no_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool]]: - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, + ep.highlight + FROM event_push_actions AS ep + INNER JOIN events AS e USING (room_id, event_id) + WHERE + ep.room_id NOT IN ( + SELECT room_id FROM receipts_linearized + WHERE receipt_type = 'm.read' AND user_id = ? + GROUP BY room_id + ) + AND ep.user_id = ? + AND ep.stream_ordering > ? + AND ep.stream_ordering <= ? + AND ep.notif = 1 + ORDER BY ep.stream_ordering ASC LIMIT ? + """ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) @@ -490,28 +501,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_after_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, + ep.highlight, e.received_ts + FROM ( + SELECT room_id, + MAX(stream_ordering) as stream_ordering + FROM events + INNER JOIN receipts_linearized USING (room_id, event_id) + WHERE receipt_type = 'm.read' AND user_id = ? + GROUP BY room_id + ) AS rl, + event_push_actions AS ep + INNER JOIN events AS e USING (room_id, event_id) + WHERE + ep.room_id = rl.room_id + AND ep.stream_ordering > rl.stream_ordering + AND ep.user_id = ? + AND ep.stream_ordering > ? + AND ep.stream_ordering <= ? + AND ep.notif = 1 + ORDER BY ep.stream_ordering DESC LIMIT ? + """ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) @@ -526,23 +537,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_no_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, + ep.highlight, e.received_ts + FROM event_push_actions AS ep + INNER JOIN events AS e USING (room_id, event_id) + WHERE + ep.room_id NOT IN ( + SELECT room_id FROM receipts_linearized + WHERE receipt_type = 'm.read' AND user_id = ? + GROUP BY room_id + ) + AND ep.user_id = ? + AND ep.stream_ordering > ? + AND ep.stream_ordering <= ? + AND ep.notif = 1 + ORDER BY ep.stream_ordering DESC LIMIT ? + """ args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) @@ -769,12 +780,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # [10, , 20], we should treat this as being equivalent to # [10, 10, 20]. # - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering <= ?" - " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering <= ? + ORDER BY stream_ordering DESC + LIMIT 1 + """ while range_end - range_start > 0: middle = (range_end + range_start) // 2 @@ -802,14 +813,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self, stream_ordering: int ) -> Optional[int]: def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ? AND notif = 1" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) + sql = """ + SELECT e.received_ts + FROM event_push_actions AS ep + JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id + WHERE ep.stream_ordering > ? AND notif = 1 + ORDER BY ep.stream_ordering ASC + LIMIT 1 + """ txn.execute(sql, (stream_ordering,)) return cast(Optional[Tuple[int]], txn.fetchone()) @@ -858,10 +869,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas Any push actions which predate the user's most recent read receipt are now redundant, so we can remove them from `event_push_actions` and update `event_push_summary`. + + Returns true if all new receipts have been processed. """ limit = 100 + # The (inclusive) receipt stream ID that was previously processed.. min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_last_receipt_stream_id", @@ -871,6 +885,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas max_receipts_stream_id = self._receipts_id_gen.get_current_token() + # The (inclusive) event stream ordering that was previously summarised. + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + sql = """ SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering FROM receipts_linearized AS r @@ -895,13 +917,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) rows = txn.fetchall() - old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) - # For each new read receipt we delete push actions from before it and # recalculate the summary. for _, room_id, user_id, stream_ordering in rows: @@ -920,10 +935,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (room_id, user_id, stream_ordering), ) + # Fetch the notification counts between the stream ordering of the + # latest receipt and what was previously summarised. notif_count, unread_count = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) + # Replace the previous summary with the new counts. self.db_pool.simple_upsert_txn( txn, table="event_push_summary", @@ -956,10 +974,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return len(rows) < limit def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool: - """Archives older notifications into event_push_summary. Returns whether - the archiving process has caught up or not. + """Archives older notifications (from event_push_actions) into event_push_summary. + + Returns whether the archiving process has caught up or not. """ + # The (inclusive) event stream ordering that was previously summarised. old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", @@ -974,7 +994,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas SELECT stream_ordering FROM event_push_actions WHERE stream_ordering > ? ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? - """, + """, (old_rotate_stream_ordering, self._rotate_count), ) stream_row = txn.fetchone() @@ -993,19 +1013,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) - self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + self._rotate_notifs_before_txn( + txn, old_rotate_stream_ordering, rotate_to_stream_ordering + ) return caught_up def _rotate_notifs_before_txn( - self, txn: LoggingTransaction, rotate_to_stream_ordering: int + self, + txn: LoggingTransaction, + old_rotate_stream_ordering: int, + rotate_to_stream_ordering: int, ) -> None: - old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) + """Archives older notifications (from event_push_actions) into event_push_summary. + + Any event_push_actions between old_rotate_stream_ordering (exclusive) and + rotate_to_stream_ordering (inclusive) will be added to the event_push_summary + table. + + Args: + txn: The database transaction. + old_rotate_stream_ordering: The previous maximum event stream ordering. + rotate_to_stream_ordering: The new maximum event stream ordering to summarise. + + Returns whether the archiving process has caught up or not. + """ # Calculate the new counts that should be upserted into event_push_summary sql = """ @@ -1093,9 +1125,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def _remove_old_push_actions_that_have_rotated( self, ) -> None: - """Clear out old push actions that have been summarized.""" + """Clear out old push actions that have been summarised.""" - # We want to clear out anything that older than a day that *has* already + # We want to clear out anything that is older than a day that *has* already # been rotated. rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol( table="event_push_summary_stream_ordering", @@ -1119,7 +1151,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas SELECT stream_ordering FROM event_push_actions WHERE stream_ordering <= ? AND highlight = 0 ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? - """, + """, ( max_stream_ordering_to_delete, batch_size, @@ -1215,16 +1247,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # NB. This assumes event_ids are globally unique since # it makes the query easier to index - sql = ( - "SELECT epa.event_id, epa.room_id," - " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" - " FROM event_push_actions epa, events e" - " WHERE epa.event_id = e.event_id" - " AND epa.user_id = ? %s" - " AND epa.notif = 1" - " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" % (before_clause,) + sql = """ + SELECT epa.event_id, epa.room_id, + epa.stream_ordering, epa.topological_ordering, + epa.actions, epa.highlight, epa.profile_tag, e.received_ts + FROM event_push_actions epa, events e + WHERE epa.event_id = e.event_id + AND epa.user_id = ? %s + AND epa.notif = 1 + ORDER BY epa.stream_ordering DESC + LIMIT ? + """ % ( + before_clause, ) txn.execute(sql, args) return cast( -- cgit 1.5.1 From b6a6bb4027c1a812361ac127b8c5ea1226be295d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Aug 2022 20:38:08 +0100 Subject: Add comments about how event push actions are stored. (#13445) --- changelog.d/13445.misc | 1 + .../storage/databases/main/event_push_actions.py | 61 ++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 changelog.d/13445.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/13445.misc b/changelog.d/13445.misc new file mode 100644 index 0000000000..17462c56f3 --- /dev/null +++ b/changelog.d/13445.misc @@ -0,0 +1 @@ +Add some comments about how event push actions are stored. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5ddddb1cf3..5db70f9a60 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -12,6 +12,67 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Responsible for storing and fetching push actions / notifications. + +There are two main uses for push actions: + 1. Sending out push to a user's device; and + 2. Tracking per-room per-user notification counts (used in sync requests). + +For the former we simply use the `event_push_actions` table, which contains all +the calculated actions for a given user (which were calculated by the +`BulkPushRuleEvaluator`). + +For the latter we could simply count the number of rows in `event_push_actions` +table for a given room/user, but in practice this is *very* heavyweight when +there were a large number of notifications (due to e.g. the user never reading a +room). Plus, keeping all push actions indefinitely uses a lot of disk space. + +To fix these issues, we add a new table `event_push_summary` that tracks +per-user per-room counts of all notifications that happened before a stream +ordering S. Thus, to get the notification count for a user / room we can simply +query a single row in `event_push_summary` and count the number of rows in +`event_push_actions` with a stream ordering larger than S (and as long as S is +"recent", the number of rows needing to be scanned will be small). + +The `event_push_summary` table is updated via a background job that periodically +chooses a new stream ordering S' (usually the latest stream ordering), counts +all notifications in `event_push_actions` between the existing S and S', and +adds them to the existing counts in `event_push_summary`. + +This allows us to delete old rows from `event_push_actions` once those rows have +been counted and added to `event_push_summary` (we call this process +"rotation"). + + +We need to handle when a user sends a read receipt to the room. Again this is +done as a background process. For each receipt we clear the row in +`event_push_summary` and count the number of notifications in +`event_push_actions` that happened after the receipt but before S, and insert +that count into `event_push_summary` (If the receipt happened *after* S then we +simply clear the `event_push_summary`.) + +Note that its possible that if the read receipt is for an old event the relevant +`event_push_actions` rows will have been rotated and we get the wrong count +(it'll be too low). We accept this as a rare edge case that is unlikely to +impact the user much (since the vast majority of read receipts will be for the +latest event). + +The last complication is to handle the race where we request the notifications +counts after a user sends a read receipt into the room, but *before* the +background update handles the receipt (without any special handling the counts +would be outdated). We fix this by including in `event_push_summary` the read +receipt we used when updating `event_push_summary`, and every time we query the +table we check if that matches the most recent read receipt in the room. If yes, +continue as above, if not we simply query the `event_push_actions` table +directly. + +Since read receipts are almost always for recent events, scanning the +`event_push_actions` table in this case is unlikely to be a problem. Even if it +is a problem, it is temporary until the background job handles the new read +receipt. +""" + import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast -- cgit 1.5.1 From ab18441573dc14cea1fe4082b2a89b9d392a4b9f Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Fri, 5 Aug 2022 17:09:33 +0200 Subject: Support stable identifiers for MSC2285: private read receipts. (#13273) This adds support for the stable identifiers of MSC2285 while continuing to support the unstable identifiers behind the configuration flag. These will be removed in a future version. --- changelog.d/13273.feature | 1 + synapse/api/constants.py | 3 +- synapse/config/experimental.py | 2 +- synapse/handlers/initial_sync.py | 11 +-- synapse/handlers/receipts.py | 36 ++++++--- synapse/replication/tcp/client.py | 5 +- synapse/rest/client/notifications.py | 7 +- synapse/rest/client/read_marker.py | 8 +- synapse/rest/client/receipts.py | 10 ++- synapse/rest/client/versions.py | 1 + .../storage/databases/main/event_push_actions.py | 85 ++++++++++++++++++---- tests/handlers/test_receipts.py | 58 +++++++++++---- tests/rest/client/test_sync.py | 58 ++++++++++----- tests/storage/test_receipts.py | 55 +++++++++----- 14 files changed, 246 insertions(+), 94 deletions(-) create mode 100644 changelog.d/13273.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/13273.feature b/changelog.d/13273.feature new file mode 100644 index 0000000000..53110d74e9 --- /dev/null +++ b/changelog.d/13273.feature @@ -0,0 +1 @@ +Add support for stable prefixes for [MSC2285 (private read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 789859e69e..1d46fb0e43 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -257,7 +257,8 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" - READ_PRIVATE: Final = "org.matrix.msc2285.read.private" + READ_PRIVATE: Final = "m.read.private" + UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c2ecd977cd..7d17c958bb 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,7 +32,7 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (private read receipts) + # MSC2285 (unstable private read receipts) self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) # MSC3244 (room version capabilities) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 85b472f250..6484e47e5f 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -143,8 +143,8 @@ class InitialSyncHandler: joined_rooms, to_key=int(now_token.receipt_key), ) - if self.hs.config.experimental.msc2285_enabled: - receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) + + receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -456,11 +456,8 @@ class InitialSyncHandler: ) if not receipts: return [] - if self.hs.config.experimental.msc2285_enabled: - receipts = ReceiptEventSource.filter_out_private_receipts( - receipts, user_id - ) - return receipts + + return ReceiptEventSource.filter_out_private_receipts(receipts, user_id) presence, receipts, (messages, token) = await make_deferred_yieldable( gather_results( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 43d2882b0a..d4a866b346 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -163,7 +163,10 @@ class ReceiptsHandler: if not is_new: return - if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: + if self.federation_sender and receipt_type not in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ): await self.federation_sender.send_read_receipt(receipt) @@ -203,24 +206,38 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id, orig_event_content in room.get("content", {}).items(): event_content = orig_event_content # If there are private read receipts, additional logic is necessary. - if ReceiptTypes.READ_PRIVATE in event_content: + if ( + ReceiptTypes.READ_PRIVATE in event_content + or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content + ): # Make a copy without private read receipts to avoid leaking # other user's private read receipts.. event_content = { receipt_type: receipt_value for receipt_type, receipt_value in event_content.items() - if receipt_type != ReceiptTypes.READ_PRIVATE + if receipt_type + not in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ) } # Copy the current user's private read receipt from the # original content, if it exists. - user_private_read_receipt = orig_event_content[ - ReceiptTypes.READ_PRIVATE - ].get(user_id, None) + user_private_read_receipt = orig_event_content.get( + ReceiptTypes.READ_PRIVATE, {} + ).get(user_id, None) if user_private_read_receipt: event_content[ReceiptTypes.READ_PRIVATE] = { user_id: user_private_read_receipt } + user_unstable_private_read_receipt = orig_event_content.get( + ReceiptTypes.UNSTABLE_READ_PRIVATE, {} + ).get(user_id, None) + if user_unstable_private_read_receipt: + event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = { + user_id: user_unstable_private_read_receipt + } # Include the event if there is at least one non-private read # receipt or the current user has a private read receipt. @@ -256,10 +273,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]): room_ids, from_key=from_key, to_key=to_key ) - if self.config.experimental.msc2285_enabled: - events = ReceiptEventSource.filter_out_private_receipts( - events, user.to_string() - ) + events = ReceiptEventSource.filter_out_private_receipts( + events, user.to_string() + ) return events, to_key diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..1ed7230e32 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,7 +416,10 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: + if receipt.receipt_type in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ): continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 24bc7c9095..a73322a6a4 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -58,7 +58,12 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + user_id, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) notif_event_ids = [pa.event_id for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 8896f2df50..aaad8b233f 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -40,9 +40,13 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} + self._known_receipt_types = { + ReceiptTypes.READ, + ReceiptTypes.FULLY_READ, + ReceiptTypes.READ_PRIVATE, + } if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE) + self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 409bfd43c1..c6108fc5eb 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -44,11 +44,13 @@ class ReceiptRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - self._known_receipt_types = {ReceiptTypes.READ} + self._known_receipt_types = { + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.FULLY_READ, + } if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.update( - (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ) - ) + self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 0366986755..c9a830cbac 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -94,6 +94,7 @@ class VersionsRestServlet(RestServlet): # Supports the busy presence state described in MSC3026. "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 + "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5db70f9a60..161aad0f89 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -80,7 +80,7 @@ import attr from synapse.api.constants import ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -259,7 +259,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn, user_id, room_id, - receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + receipt_types=( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), ) stream_ordering = None @@ -448,6 +452,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ + # find rooms that have a read receipt in them and return the next # push actions def get_after_receipt( @@ -455,7 +460,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) -> List[Tuple[str, str, int, str, bool]]: # find rooms that have a read receipt in them and return the next # push actions - sql = """ + + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight FROM ( @@ -463,10 +479,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas MAX(stream_ordering) as stream_ordering FROM events INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AS rl, - event_push_actions AS ep + event_push_actions AS ep WHERE ep.room_id = rl.room_id AND ep.stream_ordering > rl.stream_ordering @@ -476,7 +492,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering ASC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) @@ -490,7 +508,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_no_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool]]: - sql = """ + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight FROM event_push_actions AS ep @@ -498,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas WHERE ep.room_id NOT IN ( SELECT room_id FROM receipts_linearized - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AND ep.user_id = ? @@ -507,7 +535,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering ASC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) @@ -557,12 +587,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ + # find rooms that have a read receipt in them and return the most recent # push actions def get_after_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = """ + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight, e.received_ts FROM ( @@ -570,7 +611,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas MAX(stream_ordering) as stream_ordering FROM events INNER JOIN receipts_linearized USING (room_id, event_id) - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AS rl, event_push_actions AS ep @@ -584,7 +625,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering DESC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) @@ -598,7 +641,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_no_receipt( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = """ + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ), + ) + + sql = f""" SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep @@ -606,7 +659,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas WHERE ep.room_id NOT IN ( SELECT room_id FROM receipts_linearized - WHERE receipt_type = 'm.read' AND user_id = ? + WHERE {receipt_types_clause} AND user_id = ? GROUP BY room_id ) AND ep.user_id = ? @@ -615,7 +668,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND ep.notif = 1 ORDER BY ep.stream_ordering DESC LIMIT ? """ - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] + args.extend( + (user_id, user_id, min_stream_ordering, max_stream_ordering, limit) + ) txn.execute(sql, args) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index a95868b5c0..5f70a2db79 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,6 +15,8 @@ from copy import deepcopy from typing import List +from parameterized import parameterized + from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict @@ -25,13 +27,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - def test_filters_out_private_receipt(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_private_receipt(self, receipt_type: str) -> None: self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -45,13 +50,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - def test_filters_out_private_receipt_and_ignores_rest(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_private_receipt_and_ignores_rest( + self, receipt_type: str + ) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -84,13 +94,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest( + self, receipt_type: str + ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -125,7 +140,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_handles_empty_event(self): + def test_handles_empty_event(self) -> None: self._test_filters_private( [ { @@ -160,13 +175,18 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest( + self, receipt_type: str + ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -207,7 +227,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_handles_string_data(self): + def test_handles_string_data(self) -> None: """ Tests that an invalid shape for read-receipts is handled. Context: https://github.com/matrix-org/synapse/issues/10603 @@ -242,13 +262,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_leaves_our_private_and_their_public(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@me:server.org": { "ts": 1436451550453, }, @@ -273,7 +296,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@me:server.org": { "ts": 1436451550453, }, @@ -296,13 +319,16 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_we_do_not_mutate(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_we_do_not_mutate(self, receipt_type: str) -> None: """Ensure the input values are not modified.""" events = [ { "content": { "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ_PRIVATE: { + receipt_type: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -320,7 +346,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def _test_filters_private( self, events: List[JsonDict], expected_output: List[JsonDict] - ): + ) -> None: """Tests that the _filter_out_private returns the expected output""" filtered_events = self.event_source.filter_out_private_receipts( events, "@me:server.org" diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index ae16184828..de0dec8539 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) from tests.server import TimedOutException -from tests.unittest import override_config class FilterTestCase(unittest.HomeserverTestCase): @@ -390,6 +389,12 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["experimental_features"] = {"msc2285_enabled": True} + + return self.setup_test_homeserver(config=config) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -408,15 +413,17 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_private_read_receipts(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_private_read_receipts(self, receipt_type: str) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -425,8 +432,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_public_receipt_can_override_private(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_public_receipt_can_override_private(self, receipt_type: str) -> None: """ Sending a public read receipt to the same event which has a private read receipt should cause that receipt to become public. @@ -437,7 +446,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -456,8 +465,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_private_receipt_cannot_override_public(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None: """ Sending a private read receipt to the same event which has a public read receipt should cause no change. @@ -478,7 +489,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -590,7 +601,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - def test_unread_counts(self) -> None: + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_unread_counts(self, receipt_type: str) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -624,7 +638,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", {}, access_token=self.tok, ) @@ -700,7 +714,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(5) res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) - # Make sure both m.read and org.matrix.msc2285.read.private advance + # Make sure both m.read and m.read.private advance channel = self.make_request( "POST", f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", @@ -712,16 +726,22 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # We test for both receipt types that influence notification counts - @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) - def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: + # We test for all three receipt types that influence notification counts + @parameterized.expand( + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ] + ) + def test_read_receipts_only_go_down(self, receipt_type: str) -> None: # Join the new user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @@ -739,11 +759,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # Make sure neither m.read nor org.matrix.msc2285.read.private make the + # Make sure neither m.read nor m.read.private make the # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}", {}, access_token=self.tok, ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index b1a8f8bba7..191c957fb5 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from parameterized import parameterized + from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test" class ReceiptTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor, clock, homeserver) -> None: super().prepare(reactor, clock, homeserver) self.store = homeserver.get_datastores().main @@ -83,10 +85,15 @@ class ReceiptTestCase(HomeserverTestCase): ) ) - def test_return_empty_with_no_data(self): + def test_return_empty_with_no_data(self) -> None: res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) @@ -94,7 +101,11 @@ class ReceiptTestCase(HomeserverTestCase): res = self.get_success( self.store.get_receipts_for_user_with_orderings( OUR_USER_ID, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) @@ -103,12 +114,19 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, None) - def test_get_receipts_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_receipts_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -126,14 +144,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -146,7 +164,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) + self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -169,17 +187,20 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - def test_get_last_receipt_event_id_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -197,7 +218,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) @@ -206,7 +227,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event1_2_id) @@ -222,7 +243,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, self.room_id1, [receipt_type] ) ) self.assertEqual(res, event1_2_id) @@ -248,14 +269,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1