From 46bd7f4ed9020bbed459c03a11c26d7f7c3093b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Aug 2022 09:33:17 -0400 Subject: Clarifications for event push action processing. (#13485) * Clarifies comments. * Fixes an erroneous comment (about return type) added in #13455 (ec24813220f9d54108924dc04aecd24555277b99). * Clarifies the name of a variable. * Simplifies logic of pulling out the latest join for the requesting user. --- synapse/storage/databases/main/receipts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 0090c9f225..124c70ad37 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: The receipt types to fetch. Returns: - The latest receipt, if one exists. + The event ID and stream ordering of the latest receipt, if one exists. """ clause, args = make_in_list_sql_clause( -- cgit 1.5.1 From 3d9f82efcb9c337197c9f50a88ec3fb541ee08ff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Sep 2022 07:08:41 -0400 Subject: Use an upsert for `receipts_graph`. (#13752) Instead of a delete, then insert. This was previously done for `receipts_linearized` in 2dc430d36ef793b38d6d79ec8db4ea60588df2ee (#7607). --- changelog.d/13752.misc | 1 + synapse/storage/databases/main/receipts.py | 12 ++++-------- 2 files changed, 5 insertions(+), 8 deletions(-) create mode 100644 changelog.d/13752.misc (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/13752.misc b/changelog.d/13752.misc new file mode 100644 index 0000000000..7624861b9f --- /dev/null +++ b/changelog.d/13752.misc @@ -0,0 +1 @@ +User an additional database query when persisting receipts. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 124c70ad37..3838409519 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -812,7 +812,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) - self.db_pool.simple_delete_txn( + self.db_pool.simple_upsert_txn( txn, table="receipts_graph", keyvalues={ @@ -820,17 +820,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "receipt_type": receipt_type, "user_id": user_id, }, - ) - self.db_pool.simple_insert_txn( - txn, - table="receipts_graph", values={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), }, + # receipts_graph has a unique constraint on + # (user_id, room_id, receipt_type), so no need to lock + lock=False, ) -- cgit 1.5.1 From cdbb6412327b542e0dead792717fe58253291131 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 13 Sep 2022 08:16:37 +0100 Subject: Add receipts event stream ordering (#13703) --- changelog.d/13703.misc | 1 + synapse/_scripts/synapse_port_db.py | 2 + synapse/storage/databases/main/receipts.py | 74 +++++++++++++++++++++- .../delta/72/05receipts_event_stream_ordering.sql | 19 ++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13703.misc create mode 100644 synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/13703.misc b/changelog.d/13703.misc new file mode 100644 index 0000000000..685a29b17d --- /dev/null +++ b/changelog.d/13703.misc @@ -0,0 +1 @@ +Add & populate `event_stream_ordering` column on receipts table for future optimisation of push action processing. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 543bba27c2..30983c47fb 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -67,6 +67,7 @@ from synapse.storage.databases.main.media_repository import ( ) from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, @@ -203,6 +204,7 @@ class Store( PushRuleStore, PusherWorkerStore, PresenceBackgroundUpdateStore, + ReceiptsBackgroundUpdateStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3838409519..719a12b0ae 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -675,6 +675,7 @@ class ReceiptsWorkerStore(SQLBaseStore): values={ "stream_id": stream_id, "event_id": event_id, + "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on @@ -830,5 +831,76 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -class ReceiptsStore(ReceiptsWorkerStore): +class ReceiptsBackgroundUpdateStore(SQLBaseStore): + POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, + self._populate_receipt_event_stream_ordering, + ) + + async def _populate_receipt_event_stream_ordering( + self, progress: JsonDict, batch_size: int + ) -> int: + def _populate_receipt_event_stream_ordering_txn( + txn: LoggingTransaction, + ) -> bool: + + if "max_stream_id" in progress: + max_stream_id = progress["max_stream_id"] + else: + txn.execute("SELECT max(stream_id) FROM receipts_linearized") + res = txn.fetchone() + if res is None or res[0] is None: + return True + else: + max_stream_id = res[0] + + start = progress.get("stream_id", 0) + stop = start + batch_size + + sql = """ + UPDATE receipts_linearized + SET event_stream_ordering = ( + SELECT stream_ordering + FROM events + WHERE event_id = receipts_linearized.event_id + ) + WHERE stream_id >= ? AND stream_id < ? + """ + txn.execute(sql, (start, stop)) + + self.db_pool.updates._background_update_progress_txn( + txn, + self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, + { + "stream_id": stop, + "max_stream_id": max_stream_id, + }, + ) + + return stop > max_stream_id + + finished = await self.db_pool.runInteraction( + "_remove_devices_from_device_inbox_txn", + _populate_receipt_event_stream_ordering_txn, + ) + + if finished: + await self.db_pool.updates._end_background_update( + self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING + ) + + return batch_size + + +class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql b/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql new file mode 100644 index 0000000000..2a822f4509 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/05receipts_event_stream_ordering.sql @@ -0,0 +1,19 @@ +/* Copyright 2022 Beeper + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE receipts_linearized ADD COLUMN event_stream_ordering BIGINT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_event_stream_ordering', '{}'); -- cgit 1.5.1 From 666ae877292d4747b9441105e3df8558f7a335c0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Sep 2022 13:11:16 -0400 Subject: Update event push action and receipt tables to support threads. (#13753) Adds a `thread_id` column to the `event_push_actions`, `event_push_actions_staging`, and `event_push_summary` tables. This will notifications to be segmented by the thread in a future pull request. The `thread_id` column stores the root event ID or the special value `"main"`. The `thread_id` column for `event_push_actions` and `event_push_summary` is backfilled with `"main"` for all existing rows. New entries into `event_push_actions` and `event_push_actions_staging` will get the proper thread ID. `receipts_linearized` and `receipts_graph` also gain a `thread_id` column, which is similar, except `NULL` is a special value meaning the receipt is "unthreaded". See MSC3771 and MSC3773 for where this data will be useful. --- changelog.d/13753.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 29 ++--- .../storage/databases/main/event_push_actions.py | 121 ++++++++++++++++++++- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/receipts.py | 20 ++++ synapse/storage/schema/__init__.py | 6 +- .../main/delta/72/06thread_notifications.sql | 30 +++++ .../main/delta/72/07thread_receipts.sql.postgres | 30 +++++ .../main/delta/72/07thread_receipts.sql.sqlite | 70 ++++++++++++ .../schema/main/delta/72/08thread_receipts.sql | 20 ++++ tests/replication/slave/storage/test_events.py | 1 + 11 files changed, 312 insertions(+), 20 deletions(-) create mode 100644 changelog.d/13753.misc create mode 100644 synapse/storage/schema/main/delta/72/06thread_notifications.sql create mode 100644 synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres create mode 100644 synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite create mode 100644 synapse/storage/schema/main/delta/72/08thread_receipts.sql (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/13753.misc b/changelog.d/13753.misc new file mode 100644 index 0000000000..63de2eb9f9 --- /dev/null +++ b/changelog.d/13753.misc @@ -0,0 +1 @@ +Prepatory work for storing thread IDs for notifications and receipts. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d1caf8a0f7..3846fbc5f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -198,7 +198,7 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level async def _get_mutual_relations( - self, event: EventBase, rules: Iterable[Tuple[PushRule, bool]] + self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] ) -> Dict[str, Set[Tuple[str, str]]]: """ Fetch event metadata for events which related to the same event as the given event. @@ -206,7 +206,7 @@ class BulkPushRuleEvaluator: If the given event has no relation information, returns an empty dictionary. Args: - event_id: The event ID which is targeted by relations. + parent_id: The event ID which is targeted by relations. rules: The push rules which will be processed for this event. Returns: @@ -220,12 +220,6 @@ class BulkPushRuleEvaluator: if not self._relations_match_enabled: return {} - # If the event does not have a relation, then cannot have any mutual - # relations. - relation = relation_from_event(event) - if not relation: - return {} - # Pre-filter to figure out which relation types are interesting. rel_types = set() for rule, enabled in rules: @@ -246,9 +240,7 @@ class BulkPushRuleEvaluator: return {} # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations( - relation.parent_id, rel_types - ) + return await self.store.get_mutual_event_relations(parent_id, rel_types) @measure_func("action_for_event_by_user") async def action_for_event_by_user( @@ -281,9 +273,17 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) - relations = await self._get_mutual_relations( - event, itertools.chain(*rules_by_user.values()) - ) + relation = relation_from_event(event) + # If the event does not have a relation, then cannot have any mutual + # relations or thread ID. + relations = {} + thread_id = "main" + if relation: + relations = await self._get_mutual_relations( + relation.parent_id, itertools.chain(*rules_by_user.values()) + ) + if relation.rel_type == RelationTypes.THREAD: + thread_id = relation.parent_id evaluator = PushRuleEvaluatorForEvent( event, @@ -352,6 +352,7 @@ class BulkPushRuleEvaluator: event.event_id, actions_by_user, count_as_unread, + thread_id, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a3fb8c507..6b8668d2dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -98,6 +98,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas replaces_index="event_push_summary_user_rm", ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_unique_index2", + index_name="event_push_summary_unique_index2", + table="event_push_summary", + columns=["user_id", "room_id", "thread_id"], + unique=True, + ) + + self.db_pool.updates.register_background_update_handler( + "event_push_backfill_thread_id", + self._background_backfill_thread_id, + ) + + async def _background_backfill_thread_id( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Fill in the thread_id field for event_push_actions and event_push_summary. + + This is preparatory so that it can be made non-nullable in the future. + + Because all current (null) data is done in an unthreaded manner this + simply assumes it is on the "main" timeline. Since event_push_actions + are periodically cleared it is not possible to correctly re-calculate + the thread_id. + """ + event_push_actions_done = progress.get("event_push_actions_done", False) + + def add_thread_id_txn( + txn: LoggingTransaction, table_name: str, start_stream_ordering: int + ) -> int: + sql = f""" + SELECT stream_ordering + FROM {table_name} + WHERE + thread_id IS NULL + AND stream_ordering > ? + ORDER BY stream_ordering + LIMIT ? + """ + txn.execute(sql, (start_stream_ordering, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + progress[f"{table_name}_done"] = True + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + return 0 + + # Update the thread ID for any of those rows. + max_stream_ordering = rows[-1][0] + + sql = f""" + UPDATE {table_name} + SET thread_id = 'main' + WHERE stream_ordering <= ? AND thread_id IS NULL + """ + txn.execute(sql, (max_stream_ordering,)) + + # Update progress. + processed_rows = txn.rowcount + progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + # First update the event_push_actions table, then the event_push_summary table. + # + # Note that the event_push_actions_staging table is ignored since it is + # assumed that items in that table will only exist for a short period of + # time. + if not event_push_actions_done: + result = await self.db_pool.runInteraction( + "event_push_backfill_thread_id", + add_thread_id_txn, + "event_push_actions", + progress.get("max_event_push_actions_stream_ordering", 0), + ) + else: + result = await self.db_pool.runInteraction( + "event_push_backfill_thread_id", + add_thread_id_txn, + "event_push_summary", + progress.get("max_event_push_summary_stream_ordering", 0), + ) + + # Only done after the event_push_summary table is done. + if not result: + await self.db_pool.updates._end_background_update( + "event_push_backfill_thread_id" + ) + + return result + @cached(tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( self, @@ -670,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_id: str, user_id_actions: Dict[str, Collection[Union[Mapping, str]]], count_as_unread: bool, + thread_id: str, ) -> None: """Add the push actions for the event to the push action staging area. @@ -678,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id_actions: A mapping of user_id to list of push actions, where an action can either be a string or dict. count_as_unread: Whether this event should increment unread counts. + thread_id: The thread this event is parent of, if applicable. """ if not user_id_actions: return @@ -686,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( user_id: str, actions: Collection[Union[Mapping, str]] - ) -> Tuple[str, str, str, int, int, int]: + ) -> Tuple[str, str, str, int, int, int, str]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 return ( @@ -696,11 +797,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas notif, # notif column is_highlight, # highlight column int(count_as_unread), # unread column + thread_id, # thread_id column ) await self.db_pool.simple_insert_many( "event_push_actions_staging", - keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"), + keys=( + "event_id", + "user_id", + "actions", + "notif", + "highlight", + "unread", + "thread_id", + ), values=[ _gen_entry(user_id, actions) for user_id, actions in user_id_actions.items() @@ -981,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) # Replace the previous summary with the new counts. + # + # TODO(threads): Upsert per-thread instead of setting them all to main. self.db_pool.simple_upsert_txn( txn, table="event_push_summary", @@ -990,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "unread_count": unread_count, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, + "thread_id": "main", }, ) @@ -1138,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", key_names=("user_id", "room_id"), key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering"), + value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, + "main", ) for summary in summaries.values() ], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a4010ee28d..c0b4080e4b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2192,9 +2192,9 @@ class PersistEventsStore: sql = """ INSERT INTO event_push_actions ( room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight, unread + topological_ordering, notif, highlight, unread, thread_id ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id FROM event_push_actions_staging WHERE event_id = ? """ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 719a12b0ae..ddb8e80b69 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,6 +113,24 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) + self.db_pool.updates.register_background_index_update( + "receipts_linearized_unique_index", + index_name="receipts_linearized_unique_index", + table="receipts_linearized", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + + self.db_pool.updates.register_background_index_update( + "receipts_graph_unique_index", + index_name="receipts_graph_unique_index", + table="receipts_graph", + columns=["room_id", "receipt_type", "user_id"], + where_clause="thread_id IS NULL", + unique=True, + ) + def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -677,6 +695,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), + "thread_id": None, }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -824,6 +843,7 @@ class ReceiptsWorkerStore(SQLBaseStore): values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), + "thread_id": None, }, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 32cda5e3ba..38c9532bfd 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 72 # remember to update the list below when updating +SCHEMA_VERSION = 73 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -77,6 +77,10 @@ Changes in SCHEMA_VERSION = 72: - Tables related to groups are dropped. - Unused column application_services_state.last_txn is dropped - Cache invalidation stream id sequence now begins at 2 to match code expectation. + +Changes in SCHEMA_VERSION = 73; + - thread_id column is added to event_push_actions, event_push_actions_staging + event_push_summary, receipts_linearized, and receipts_graph. """ diff --git a/synapse/storage/schema/main/delta/72/06thread_notifications.sql b/synapse/storage/schema/main/delta/72/06thread_notifications.sql new file mode 100644 index 0000000000..2f4f5dac7a --- /dev/null +++ b/synapse/storage/schema/main/delta/72/06thread_notifications.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a nullable column for thread ID to the event push actions tables; this +-- will be filled in with a default value for any previously existing rows. +-- +-- After migration this can be made non-nullable. + +ALTER TABLE event_push_actions_staging ADD COLUMN thread_id TEXT; +ALTER TABLE event_push_actions ADD COLUMN thread_id TEXT; +ALTER TABLE event_push_summary ADD COLUMN thread_id TEXT; + +-- Update the unique index for `event_push_summary`. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7006, 'event_push_summary_unique_index2', '{}'); + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7006, 'event_push_backfill_thread_id', '{}', 'event_push_summary_unique_index2'); diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres new file mode 100644 index 0000000000..55fff9e278 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a nullable column for thread ID to the receipts table; this allows a +-- receipt per user, per room, as well as an unthreaded receipt (corresponding +-- to a null thread ID). + +ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT; +ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT; + +-- Rebuild the unique constraint with the thread_id. +ALTER TABLE receipts_linearized + ADD CONSTRAINT receipts_linearized_uniqueness_thread + UNIQUE (room_id, receipt_type, user_id, thread_id); + +ALTER TABLE receipts_graph + ADD CONSTRAINT receipts_graph_uniqueness_thread + UNIQUE (room_id, receipt_type, user_id, thread_id); diff --git a/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite new file mode 100644 index 0000000000..232f67deb4 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/07thread_receipts.sql.sqlite @@ -0,0 +1,70 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Allow multiple receipts per user per room via a nullable thread_id column. +-- +-- SQLite doesn't support modifying constraints to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE receipts_linearized_new ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + thread_id TEXT, + event_stream_ordering BIGINT, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +CREATE TABLE receipts_graph_new ( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + thread_id TEXT, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id), + CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +-- Drop the old indexes. +DROP INDEX IF EXISTS receipts_linearized_id; +DROP INDEX IF EXISTS receipts_linearized_room_stream; +DROP INDEX IF EXISTS receipts_linearized_user; + +-- Copy the data. +INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data) + SELECT stream_id, room_id, receipt_type, user_id, event_id, event_stream_ordering, data + FROM receipts_linearized; +INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data) + SELECT room_id, receipt_type, user_id, event_ids, data + FROM receipts_graph; + +-- Drop the old tables. +DROP TABLE receipts_linearized; +DROP TABLE receipts_graph; + +-- Rename the tables. +ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized; +ALTER TABLE receipts_graph_new RENAME TO receipts_graph; + +-- Create the indices. +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); diff --git a/synapse/storage/schema/main/delta/72/08thread_receipts.sql b/synapse/storage/schema/main/delta/72/08thread_receipts.sql new file mode 100644 index 0000000000..e35b021f31 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/08thread_receipts.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7007, 'receipts_linearized_unique_index', '{}'); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7007, 'receipts_graph_unique_index', '{}'); diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 531a0db2d0..49a21e2e85 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event.event_id, {user_id: actions for user_id, actions in push_actions}, False, + "main", ) ) return event, context -- cgit 1.5.1 From efd108b45d1706526416bc9a6f89463b5ff4506a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 23 Sep 2022 10:33:28 -0400 Subject: Accept & store thread IDs for receipts (implement MSC3771). (#13782) Updates the `/receipts` endpoint and receipt EDU handler to parse a `thread_id` from the body and insert it in the database. --- changelog.d/13782.feature | 1 + synapse/config/experimental.py | 2 + synapse/handlers/receipts.py | 23 ++++++- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/streams/_base.py | 1 + synapse/rest/client/read_marker.py | 2 + synapse/rest/client/receipts.py | 14 ++++- synapse/rest/client/versions.py | 2 + synapse/storage/database.py | 2 + synapse/storage/databases/main/receipts.py | 87 +++++++++++++++++++------- synapse/types.py | 1 + tests/federation/test_federation_sender.py | 21 ++++++- tests/handlers/test_appservice.py | 1 + tests/replication/slave/storage/test_events.py | 2 +- tests/replication/tcp/streams/test_receipts.py | 15 ++++- tests/storage/test_event_push_actions.py | 1 + tests/storage/test_receipts.py | 36 ++++++++--- 17 files changed, 173 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13782.feature (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/13782.feature b/changelog.d/13782.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13782.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 595eb007a5..933779c23a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -83,6 +83,8 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + # MSC3771: Thread read receipts + self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index afaf3261df..4768a34c07 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,6 +63,8 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() + self._msc3771_enabled = hs.config.experimental.msc3771_enabled + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -91,13 +93,23 @@ class ReceiptsHandler: ) continue + # Check if these receipts apply to a thread. + thread_id = None + data = user_values.get("data", {}) + if self._msc3771_enabled and isinstance(data, dict): + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None + receipts.append( ReadReceipt( room_id=room_id, receipt_type=receipt_type, user_id=user_id, event_ids=user_values["event_ids"], - data=user_values.get("data", {}), + thread_id=thread_id, + data=data, ) ) @@ -114,6 +126,7 @@ class ReceiptsHandler: receipt.receipt_type, receipt.user_id, receipt.event_ids, + receipt.thread_id, receipt.data, ) @@ -146,7 +159,12 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -156,6 +174,7 @@ class ReceiptsHandler: receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], + thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cf9cd6833b..b2522f98ca 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -427,7 +427,8 @@ class FederationSenderHandler: receipt.receipt_type, receipt.user_id, [receipt.event_id], - receipt.data, + thread_id=receipt.thread_id, + data=receipt.data, ) await self.federation_sender.send_read_receipt(receipt_info) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 398bebeaa6..e01155ad59 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -361,6 +361,7 @@ class ReceiptsStream(Stream): receipt_type: str user_id: str event_id: str + thread_id: Optional[str] data: dict NAME = "receipts" diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 5e53096539..852838515c 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -83,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + # Setting the thread ID is not possible with the /read_markers endpoint. + thread_id=None, ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 5b7fad7402..f3ff156abe 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,6 +49,7 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } + self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -61,7 +62,17 @@ class ReceiptRestServlet(RestServlet): f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - parse_json_object_from_request(request, allow_empty_body=False) + body = parse_json_object_from_request(request) + + # Pull the thread ID, if one exists. + thread_id = None + if self._msc3771_enabled: + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, "thread_id field must be a non-empty string" + ) await self.presence_handler.bump_presence_active_time(requester.user) @@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + thread_id=thread_id, ) return 200, {} diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index b3917a5abc..c95b0d6f19 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,6 +103,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above + # Support for thread read receipts. + "org.matrix.msc3771": self.config.experimental.msc3771_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 921cd4dc5e..9d116f6925 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -95,6 +95,8 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", "event_push_summary": "event_push_summary_unique_index", + "receipts_linearized": "receipts_linearized_unique_index", + "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ddb8e80b69..52fe0db924 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool + ]: """Get updates for receipts replication stream. Args: @@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_all_updated_receipts_txn( txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + int, + bool, + ]: sql = """ - SELECT stream_id, room_id, receipt_type, user_id, event_id, data + SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) updates = cast( - List[Tuple[int, list]], - [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn], ) limited = False @@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_id: str, + thread_id: Optional[str], data: JsonDict, stream_id: int, ) -> Optional[int]: @@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore): # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + if thread_id is None: + thread_clause = "r.thread_id IS NULL" + thread_args: Tuple[str, ...] = () + else: + thread_clause = "r.thread_id = ?" + thread_args = (thread_id,) + + sql = f""" + SELECT stream_ordering, event_id FROM events + INNER JOIN receipts_linearized AS r USING (event_id, room_id) + WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause} + """ + txn.execute( + sql, + ( + room_id, + receipt_type, + user_id, + ) + + thread_args, ) - txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: @@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "stream_id": stream_id, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, @@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: dict, ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. @@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, linearized_event_id, + thread_id, data, stream_id=stream_id, # Read committed is actually beneficial here because we check for a receipt with @@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore): now = self._clock.time_msec() logger.debug( - "RR for event %s in %s (%i ms old)", + "Receipt %s for event %s in %s (%i ms old)", + receipt_type, linearized_event_id, room_id, now - event_ts, @@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, event_ids, + thread_id, data, ) @@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: JsonDict, ) -> None: assert self._can_write_to_receipts @@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, diff --git a/synapse/types.py b/synapse/types.py index ec44601f54..773f0438d5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -835,6 +835,7 @@ class ReadReceipt: receipt_type: str user_id: str event_ids: List[str] + thread_id: Optional[str] data: JsonDict diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5aa500ef8..f1e357764f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): # send the second RR receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["other_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b17af2725b..af24c4984d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipt_type="m.read", user_id=self.local_user, event_ids=[f"$eventid_{i}"], + thread_id=None, data={}, ) ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 49a21e2e85..efd92793c0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -171,7 +171,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if send_receipt: self.get_success( self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} ) ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index eb00117845..ede6d0c118 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} + "!room:blue", + "m.read", + USER_ID, + ["$event:blue"], + thread_id=None, + data={"a": 1}, ) ) self.replicate() @@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) + self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) # Now let's disconnect and insert some data. @@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + "!room2:blue", + "m.read", + USER_ID, + ["$event2:foo"], + thread_id=None, + data={"a": 2}, ) ) self.replicate() diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index fc43d7edd1..08c74b93e3 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -106,6 +106,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "m.read", user_id=user_id, event_ids=[event_id], + thread_id=None, data={}, ) ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index c89bfff241..9459ee1705 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -131,13 +131,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -164,7 +169,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -180,7 +185,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( @@ -202,13 +212,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -241,7 +256,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -259,7 +274,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( -- cgit 1.5.1 From 568016929f3d22f632cb9145429fa45754a8d59f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Sep 2022 07:07:31 -0400 Subject: Clarify that a method returns only unthreaded receipts. (#13937) By renaming it and updating the docstring. Additionally, refactors a method which is used only by tests. --- changelog.d/13937.feature | 1 + .../storage/databases/main/event_push_actions.py | 12 +--- synapse/storage/databases/main/receipts.py | 36 ++--------- tests/storage/test_receipts.py | 74 +++++++++++----------- 4 files changed, 47 insertions(+), 76 deletions(-) create mode 100644 changelog.d/13937.feature (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/13937.feature b/changelog.d/13937.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13937.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f4cdc2e399..7e0ffef7d3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, ) -> NotifCounts: # Get the stream ordering of the user's latest receipt in the room. - result = self.get_last_receipt_for_user_txn( + result = self.get_last_unthreaded_receipt_for_user_txn( txn, user_id, room_id, - receipt_types=( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ), + receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) if result: @@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ), + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) sql = f""" diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 52fe0db924..246f78ac1f 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() - async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_types: Collection[str] - ) -> Optional[str]: - """ - Fetch the event ID for the latest receipt in a room with one of the given receipt types. - - Args: - user_id: The user to fetch receipts for. - room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. - - Returns: - The latest receipt, if one exists. - """ - result = await self.db_pool.runInteraction( - "get_last_receipt_event_id_for_user", - self.get_last_receipt_for_user_txn, - user_id, - room_id, - receipt_types, - ) - if not result: - return None - - event_id, _ = result - return event_id - - def get_last_receipt_for_user_txn( + def get_last_unthreaded_receipt_for_user_txn( self, txn: LoggingTransaction, user_id: str, @@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ - Fetch the event ID and stream_ordering for the latest receipt in a room - with one of the given receipt types. + Fetch the event ID and stream_ordering for the latest unthreaded receipt + in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. + receipt_types: The receipt types to fetch. Returns: The event ID and stream ordering of the latest receipt, if one exists. @@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore): WHERE {clause} AND user_id = ? AND room_id = ? + AND thread_id IS NULL ORDER BY stream_ordering DESC LIMIT 1 """ diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 9459ee1705..81253d0361 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, Optional from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase): ) ) + def get_last_unthreaded_receipt( + self, receipt_types: Collection[str], room_id: Optional[str] = None + ) -> Optional[str]: + """ + Fetch the event ID for the latest unthreaded receipt in the test room for the test user. + + Args: + receipt_types: The receipt types to fetch. + + Returns: + The latest receipt, if one exists. + """ + result = self.get_success( + self.store.db_pool.runInteraction( + "get_last_receipt_event_id_for_user", + self.store.get_last_unthreaded_receipt_for_user_txn, + OUR_USER_ID, + room_id or self.room_id1, + receipt_types, + ) + ) + if not result: + return None + + event_id, _ = result + return event_id + def test_return_empty_with_no_data(self) -> None: res = self.get_success( self.store.get_receipts_for_user( @@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase): ) self.assertEqual(res, {}) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id1, - [ - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) + self.assertEqual(res, None) def test_get_receipts_for_user(self) -> None: @@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase): ) # Test we get the latest event when we want both private and public receipts - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) self.assertEqual(res, event1_2_id) # Test we get the older event when we want only public receipt - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ]) self.assertEqual(res, event1_1_id) # Test we get the latest event when we want only the private receipt - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE]) self.assertEqual(res, event1_2_id) # Test receipt updating @@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase): self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ]) self.assertEqual(res, event1_2_id) # Send some events into the second room @@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase): {}, ) ) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2 ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From e6e876b9b158f47811b6dfedd8783f658ce960a4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 12:18:34 -0400 Subject: Return the thread ID properly down sync. (#14159) A receipt's thread ID, if one exists, should be added to the body of a receipt. --- changelog.d/14159.feature | 1 + synapse/storage/databases/main/receipts.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/14159.feature (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/14159.feature b/changelog.d/14159.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14159.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 246f78ac1f..b04026c21b 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -416,6 +416,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] receipt_type[row["user_id"]] = db_to_json(row["data"]) -- cgit 1.5.1 From 7d59a515bb97dc4f8253aa9a5a560221a0ef4702 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 12:15:41 -0400 Subject: Properly return the thread ID down sync. (#14159) Fix a broken conflict in e6e876b9b158f47811b6dfedd8783f658ce960a4, by not stomping over a field right after creating it. --- synapse/storage/databases/main/receipts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index b04026c21b..dc6989527e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -416,10 +416,10 @@ class ReceiptsWorkerStore(SQLBaseStore): # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) - if row["thread_id"]: - receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] receipt_type[row["user_id"]] = db_to_json(row["data"]) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] results = { room_id: [results[room_id]] if room_id in results else [] -- cgit 1.5.1 From 36097e88c4da51fce6556a58c49bd675f4cf20ab Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 14 Nov 2022 17:31:36 +0000 Subject: Remove slaved id tracker (#14376) This matches the multi instance writer ID generator class which can both handle advancing the current token over replication and by calling the database. --- changelog.d/14376.misc | 1 + synapse/replication/slave/__init__.py | 13 ------ synapse/replication/slave/storage/__init__.py | 13 ------ .../slave/storage/_slaved_id_tracker.py | 50 ---------------------- synapse/storage/databases/main/account_data.py | 30 +++++-------- synapse/storage/databases/main/devices.py | 36 ++++++---------- synapse/storage/databases/main/events_worker.py | 35 ++++++--------- synapse/storage/databases/main/push_rule.py | 17 ++++---- synapse/storage/databases/main/pusher.py | 24 ++++------- synapse/storage/databases/main/receipts.py | 18 ++++---- synapse/storage/util/id_generators.py | 13 ++++-- 11 files changed, 74 insertions(+), 176 deletions(-) create mode 100644 changelog.d/14376.misc delete mode 100644 synapse/replication/slave/__init__.py delete mode 100644 synapse/replication/slave/storage/__init__.py delete mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index aa58c2adc3..3e5c16b15b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a79091952a..7a003ab88f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dc6989527e..64519587f8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..1af0af1266 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step -- cgit 1.5.1 From d63814fd736fed5d3d45ff3af5e6d3bfae50c439 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Nov 2022 13:50:07 +0000 Subject: Revert "Remove slaved id tracker (#14376)" (#14463) This reverts commit 36097e88c4da51fce6556a58c49bd675f4cf20ab. --- changelog.d/14376.misc | 1 - synapse/replication/slave/__init__.py | 13 ++++++ synapse/replication/slave/storage/__init__.py | 13 ++++++ .../slave/storage/_slaved_id_tracker.py | 50 ++++++++++++++++++++++ synapse/storage/databases/main/account_data.py | 30 ++++++++----- synapse/storage/databases/main/devices.py | 36 ++++++++++------ synapse/storage/databases/main/events_worker.py | 35 +++++++++------ synapse/storage/databases/main/push_rule.py | 17 ++++---- synapse/storage/databases/main/pusher.py | 24 +++++++---- synapse/storage/databases/main/receipts.py | 18 ++++---- synapse/storage/util/id_generators.py | 13 ++---- 11 files changed, 176 insertions(+), 74 deletions(-) delete mode 100644 changelog.d/14376.misc create mode 100644 synapse/replication/slave/__init__.py create mode 100644 synapse/replication/slave/storage/__init__.py create mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc deleted file mode 100644 index 2ca326fea6..0000000000 --- a/changelog.d/14376.misc +++ /dev/null @@ -1 +0,0 @@ -Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 0000000000..f43a360a80 --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..f43a360a80 --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 0000000000..8f3f953ed4 --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,50 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple + +from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id + + +class SlavedIdTracker(AbstractStreamIdTracker): + """Tracks the "current" stream ID of a stream with a single writer. + + See `AbstractStreamIdTracker` for more details. + + Note that this class does not work correctly when there are multiple + writers. + """ + + def __init__( + self, + db_conn: LoggingDatabaseConnection, + table: str, + column: str, + extra_tables: Optional[List[Tuple[str, str]]] = None, + step: int = 1, + ): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + if extra_tables: + for table, column in extra_tables: + self.advance(None, _load_current_id(db_conn, table, column)) + + def advance(self, instance_name: Optional[str], new_id: int) -> None: + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self) -> int: + return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 282687ebce..c38b8a9e5a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -67,11 +68,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) if isinstance(database.engine, PostgresEngine): + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) + self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -93,13 +95,21 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - is_writer=self._instance_name in hs.config.worker.writers.account_data, - ) + if self._instance_name in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + else: + self._account_data_id_gen = SlavedIdTracker( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3e5c16b15b..aa58c2adc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,6 +38,7 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -85,19 +86,28 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - is_writer=hs.config.worker.worker_app is None, - ) + if hs.config.worker.worker_app is None: + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) + else: + self._device_list_id_gen = SlavedIdTracker( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 296e50d661..467d20253d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,6 +59,7 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -212,20 +213,26 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - is_writer=hs.get_instance_name() in hs.config.worker.writers.events, - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - is_writer=hs.get_instance_name() in hs.config.worker.writers.events, - ) + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 12ad44dbb3..8ae10f6127 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,6 +30,7 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -110,14 +111,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "push_rules_stream", - "stream_id", - is_writer=hs.config.worker.worker_app is None, - ) + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, "push_rules_stream", "stream_id" + ) + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index fee37b9ce4..4a01562d45 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -58,15 +59,20 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - is_writer=hs.config.worker.worker_app is None, - ) + if hs.config.worker.worker_app is None: + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) + else: + self._pushers_id_gen = SlavedIdTracker( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 64519587f8..dc6989527e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.api.constants import EduTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -60,9 +61,6 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() - - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -89,12 +87,14 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - self._receipts_id_gen = StreamIdGenerator( - db_conn, - "receipts_linearized", - "stream_id", - is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, - ) + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + else: + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1af0af1266..2dfe4c0b66 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,13 +186,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, - is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) - self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -206,11 +204,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # Advance should never be called on a writer instance, only over replication - if self._is_writer: - raise Exception("Replication is not supported by writer StreamIdGenerator") - - self._current = (max if self._step > 0 else min)(self._current, new_id) + # `StreamIdGenerator` should only be used when there is a single writer, + # so replication should never happen. + raise Exception("Replication is not supported by StreamIdGenerator") def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -253,9 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: - if self._is_writer: - return self._current - with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step -- cgit 1.5.1 From 882277008c7b43ab26e3445ab94a38aa25ad0965 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:01:22 +0000 Subject: Fix background updates failing to add unique indexes on receipts (#14453) As part of the database migration to support threaded receipts, there is a possible window in between `73/08thread_receipts_non_null.sql.postgres` removing the original unique constraints on `receipts_linearized` and `receipts_graph` and the `reeipts_linearized_unique_index` and `receipts_graph_unique_index` background updates from `72/08thread_receipts.sql` completing where the unique constraints on `receipts_linearized` and `receipts_graph` are missing. Any emulated upserts on these tables must therefore be performed with a lock held, otherwise duplicate rows can end up in the tables when there are concurrent emulated upserts. Fix the missing lock. Note that emulated upserts no longer happen by default on sqlite, since the minimum supported version of sqlite supports native upserts by default now. Finally, clean up any duplicate receipts that may have crept in before trying to create the `receipts_graph_unique_index` and `receipts_linearized_unique_index` unique indexes. Signed-off-by: Sean Quah --- changelog.d/14453.bugfix | 1 + synapse/storage/databases/main/receipts.py | 171 ++++++++++++++++++--- tests/storage/databases/main/test_receipts.py | 209 ++++++++++++++++++++++++++ 3 files changed, 357 insertions(+), 24 deletions(-) create mode 100644 changelog.d/14453.bugfix create mode 100644 tests/storage/databases/main/test_receipts.py (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/14453.bugfix b/changelog.d/14453.bugfix new file mode 100644 index 0000000000..4969e5450c --- /dev/null +++ b/changelog.d/14453.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where the background updates to add non-thread unique indexes on receipts could fail when upgrading from 1.67.0 or earlier. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dc6989527e..fbf27497ec 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,24 +113,6 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - self.db_pool.updates.register_background_index_update( - "receipts_linearized_unique_index", - index_name="receipts_linearized_unique_index", - table="receipts_linearized", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - - self.db_pool.updates.register_background_index_update( - "receipts_graph_unique_index", - index_name="receipts_graph_unique_index", - table="receipts_graph", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -702,9 +684,6 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_linearized has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) return rx_ts @@ -862,14 +841,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_graph has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) class ReceiptsBackgroundUpdateStore(SQLBaseStore): POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index" + RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index" def __init__( self, @@ -883,6 +861,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, self._populate_receipt_event_stream_ordering, ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_linearized_unique_index, + ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_graph_unique_index, + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int @@ -938,6 +924,143 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _create_receipts_index(self, index_name: str, table: str) -> None: + """Adds a unique index on `(room_id, receipt_type, user_id)` to the given + receipts table, for non-thread receipts.""" + + def _create_index(conn: LoggingDatabaseConnection) -> None: + conn.rollback() + + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=True) + + try: + c = conn.cursor() + + # Now that the duplicates are gone, we can create the index. + concurrently = ( + "CONCURRENTLY" + if isinstance(self.database_engine, PostgresEngine) + else "" + ) + sql = f""" + CREATE UNIQUE INDEX {concurrently} {index_name} + ON {table}(room_id, receipt_type, user_id) + WHERE thread_id IS NULL + """ + c.execute(sql) + finally: + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=False) + + await self.db_pool.runWithConnection(_create_index) + + async def _background_receipts_linearized_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT MAX(stream_id), room_id, receipt_type, user_id + FROM receipts_linearized + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) + + # Then remove duplicate receipts, keeping the one with the highest + # `stream_id`. There should only be a single receipt with any given + # `stream_id`. + for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id < ? + """ + txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_linearized_unique_index", + "receipts_linearized", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + + async def _background_receipts_graph_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT room_id, receipt_type, user_id FROM receipts_graph + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[str, str, str]], list(txn)) + + # Then remove all duplicate receipts. + # We could be clever and try to keep the latest receipt out of every set of + # duplicates, but it's far simpler to remove them all. + for room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_graph + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL + """ + txn.execute(sql, (room_id, receipt_type, user_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_graph_unique_index", + "receipts_graph", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py new file mode 100644 index 0000000000..c4f12d81d7 --- /dev/null +++ b/tests/storage/databases/main/test_receipts.py @@ -0,0 +1,209 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Sequence, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store = hs.get_datastores().main + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + self.other_room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def _test_background_receipts_unique_index( + self, + update_name: str, + index_name: str, + table: str, + receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], + expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], + ): + """Test that the background update to uniqueify non-thread receipts in + the given receipts table works properly. + + Args: + update_name: The name of the background update to test. + index_name: The name of the index that the background update creates. + table: The table of receipts that the background update fixes. + receipts: The test data containing duplicate receipts. + A list of receipt rows to insert, grouped by + `(room_id, receipt_type, user_id)`. + expected_unique_receipts: A dictionary of `(room_id, receipt_type, user_id)` + keys and expected receipt key-values after duplicate receipts have been + removed. + """ + # First, undo the background update. + def drop_receipts_unique_index(txn: LoggingTransaction) -> None: + txn.execute(f"DROP INDEX IF EXISTS {index_name}") + + self.get_success( + self.store.db_pool.runInteraction( + "drop_receipts_unique_index", + drop_receipts_unique_index, + ) + ) + + # Populate the receipts table, including duplicates. + for (room_id, receipt_type, user_id), rows in receipts.items(): + for row in rows: + self.get_success( + self.store.db_pool.simple_insert( + table, + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "thread_id": None, + "data": "{}", + **row, + }, + ) + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": update_name, + "progress_json": "{}", + }, + ) + ) + + self.store.db_pool.updates._all_done = False + + self.wait_for_background_updates() + + # Check that the remaining receipts match expectations. + for ( + room_id, + receipt_type, + user_id, + ), expected_row in expected_unique_receipts.items(): + # Include the receipt key in the returned columns, for more informative + # assertion messages. + columns = ["room_id", "receipt_type", "user_id"] + if expected_row is not None: + columns += expected_row.keys() + + rows = self.get_success( + self.store.db_pool.simple_select_list( + table=table, + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + # `simple_select_onecol` does not support NULL filters, + # so skip the filter on `thread_id`. + }, + retcols=columns, + desc="get_receipt", + ) + ) + + if expected_row is not None: + self.assertEqual( + len(rows), + 1, + f"Background update did not leave behind latest receipt in {table}", + ) + self.assertEqual( + rows[0], + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + **expected_row, + }, + ) + else: + self.assertEqual( + len(rows), + 0, + f"Background update did not remove all duplicate receipts from {table}", + ) + + def test_background_receipts_linearized_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_linearized` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_linearized_unique_index", + "receipts_linearized_unique_index", + "receipts_linearized", + receipts={ + (self.room_id, "m.read", self.user_id): [ + {"stream_id": 5, "event_id": "$some_event"}, + {"stream_id": 6, "event_id": "$some_event"}, + ], + (self.other_room_id, "m.read", self.user_id): [ + {"stream_id": 7, "event_id": "$some_event"} + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): {"stream_id": 6}, + (self.other_room_id, "m.read", self.user_id): {"stream_id": 7}, + }, + ) + + def test_background_receipts_graph_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_graph` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_graph_unique_index", + "receipts_graph_unique_index", + "receipts_graph", + receipts={ + (self.room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + }, + { + "event_ids": '["$some_event"]', + }, + ], + (self.other_room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + } + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): None, + (self.other_room_id, "m.read", self.user_id): { + "event_ids": '["$some_event"]' + }, + }, + ) -- cgit 1.5.1 From 115f0eb2334b13665e5c112bd87f95ea393c9047 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 16 Nov 2022 22:16:46 +0000 Subject: Reintroduce #14376, with bugfix for monoliths (#14468) * Add tests for StreamIdGenerator * Drive-by: annotate all defs * Revert "Revert "Remove slaved id tracker (#14376)" (#14463)" This reverts commit d63814fd736fed5d3d45ff3af5e6d3bfae50c439, which in turn reverted 36097e88c4da51fce6556a58c49bd675f4cf20ab. This restores the latter. * Fix StreamIdGenerator not handling unpersisted IDs Spotted by @erikjohnston. Closes #14456. * Changelog Co-authored-by: Nick Mills-Barrett Co-authored-by: Erik Johnston --- changelog.d/14376.misc | 1 + changelog.d/14468.misc | 1 + mypy.ini | 3 + synapse/replication/slave/__init__.py | 13 -- synapse/replication/slave/storage/__init__.py | 13 -- .../slave/storage/_slaved_id_tracker.py | 50 ------- synapse/storage/databases/main/account_data.py | 30 ++-- synapse/storage/databases/main/devices.py | 36 ++--- synapse/storage/databases/main/events_worker.py | 35 ++--- synapse/storage/databases/main/push_rule.py | 17 +-- synapse/storage/databases/main/pusher.py | 24 ++- synapse/storage/databases/main/receipts.py | 18 +-- synapse/storage/util/id_generators.py | 13 +- tests/storage/test_id_generators.py | 162 +++++++++++++++++++-- 14 files changed, 230 insertions(+), 186 deletions(-) create mode 100644 changelog.d/14376.misc create mode 100644 changelog.d/14468.misc delete mode 100644 synapse/replication/slave/__init__.py delete mode 100644 synapse/replication/slave/storage/__init__.py delete mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'synapse/storage/databases/main/receipts.py') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/changelog.d/14468.misc b/changelog.d/14468.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14468.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/mypy.ini b/mypy.ini index 8f1141a239..53512b2584 100644 --- a/mypy.ini +++ b/mypy.ini @@ -117,6 +117,9 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True +[mypy-tests.storage.test_id_generators] +disallow_untyped_defs = True + [mypy-tests.storage.test_profile] disallow_untyped_defs = True diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e114c733d1..57230df5ae 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8a104f7e93..01e935edef 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index fbf27497ec..a580e4bdda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..0d7108f01b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if not self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 2d8d1f860f..d6a2b8d274 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -16,15 +16,157 @@ from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import IncorrectDatabaseSetup -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS +class StreamIdGeneratorTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") + + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + table="foobar", + column="stream_id", + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. + await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("master", 3) # Now we add a row *without* updating the stream ID - def _insert(txn): + def _insert(txn: Cursor) -> None: txn.execute("INSERT INTO foobar VALUES (26, 'master')") self.get_success(self.db_pool.runInteraction("_insert", _insert)) @@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name: str, number: int, update_stream_table: bool = True, - ): + ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ -- cgit 1.5.1