From b5facbac0f2d5f6f0e83d7cac43f8de02ce6742f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 21 May 2024 16:48:20 +0100 Subject: Improve perf of sync device lists (#17216) Re-introduces #17191, and includes #17197 and #17214 The basic idea is to stop calling `get_rooms_for_user` everywhere, and instead use the table `device_lists_changes_in_room`. Commits reviewable one-by-one. --- synapse/storage/databases/main/devices.py | 89 +++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 21 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8dbcb3f5a0..f4410b5c02 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -70,10 +70,7 @@ from synapse.types import ( from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import ( - AllEntitiesChangedResult, - StreamChangeCache, -) +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): prefilled_cache=device_list_prefill, ) + device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict( + db_conn, + "device_lists_changes_in_room", + entity_column="room_id", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_room_stream_cache = StreamChangeCache( + "DeviceListRoomStreamChangeCache", + min_device_list_room_id, + prefilled_cache=device_list_room_prefill, + ) + ( user_signature_stream_prefill, user_signature_stream_list_id, @@ -209,6 +220,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): row.entity, token ) + def device_lists_in_rooms_have_changed( + self, room_ids: StrCollection, token: int + ) -> None: + "Record that device lists have changed in rooms" + for room_id in room_ids: + self._device_list_room_stream_cache.entity_has_changed(room_id, token) + def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -832,16 +850,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) return {device[0]: db_to_json(device[1]) for device in devices} - def get_cached_device_list_changes( - self, - from_key: int, - ) -> AllEntitiesChangedResult: - """Get set of users whose devices have changed since `from_key`, or None - if that information is not in our cache. - """ - - return self._device_list_stream_cache.get_all_entities_changed(from_key) - @cancellable async def get_all_devices_changed( self, @@ -1457,7 +1465,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cancellable async def get_device_list_changes_in_rooms( - self, room_ids: Collection[str], from_id: int + self, room_ids: Collection[str], from_id: int, to_id: int ) -> Optional[Set[str]]: """Return the set of users whose devices have changed in the given rooms since the given stream ID. 
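The core trick of this commit is visible in the hunks above: the new `DeviceListRoomStreamChangeCache` lets sync ask, entirely in memory, which of a user's rooms *may* have had device list changes before any query is issued against `device_lists_changes_in_room`. A minimal sketch of that pre-filter, using only the `StreamChangeCache` calls that appear in this diff (the room IDs and stream positions are invented for illustration):

```python
from synapse.util.caches.stream_change_cache import StreamChangeCache

# The cache cannot answer questions about positions before 5
# (cf. the prefill in __init__ above).
cache = StreamChangeCache("DeviceListRoomStreamChangeCache", 5)

# What device_lists_in_rooms_have_changed(["!a:x"], 7) does per room.
cache.entity_has_changed("!a:x", 7)

# Which of these rooms may have changed since stream position 6?
maybe_changed = cache.get_entities_changed(["!a:x", "!b:x"], 6)
assert set(maybe_changed) == {"!a:x"}  # "!b:x" is provably unchanged; no SQL needed
```

When `from_id` falls behind the cache's horizon, `get_entities_changed` conservatively returns every room it was asked about, so correctness never depends on the cache being warm.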
@@ -1473,9 +1481,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if min_stream_id > from_id: return None + changed_room_ids = self._device_list_room_stream_cache.get_entities_changed( + room_ids, from_id + ) + if not changed_room_ids: + return set() + sql = """ SELECT DISTINCT user_id FROM device_lists_changes_in_room - WHERE {clause} AND stream_id >= ? + WHERE {clause} AND stream_id > ? AND stream_id <= ? """ def _get_device_list_changes_in_rooms_txn( @@ -1487,11 +1501,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return {user_id for user_id, in txn} changes = set() - for chunk in batch_iter(room_ids, 1000): + for chunk in batch_iter(changed_room_ids, 1000): clause, args = make_in_list_sql_clause( self.database_engine, "room_id", chunk ) args.append(from_id) + args.append(to_id) changes |= await self.db_pool.runInteraction( "get_device_list_changes_in_rooms", @@ -1502,6 +1517,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return changes + async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]: + """Return the set of rooms where devices have changed since the given + stream ID. + + Will raise an exception if the given stream ID is too old. + """ + + min_stream_id = await self._get_min_device_lists_changes_in_room() + + if min_stream_id > from_id: + raise Exception("stream ID is too old") + + sql = """ + SELECT DISTINCT room_id FROM device_lists_changes_in_room + WHERE stream_id > ? AND stream_id <= ? + """ + + def _get_all_device_list_changes_txn( + txn: LoggingTransaction, + ) -> Set[str]: + txn.execute(sql, (from_id, to_id)) + return {room_id for room_id, in txn} + + return await self.db_pool.runInteraction( + "get_all_device_list_changes", + _get_all_device_list_changes_txn, + ) + async def get_device_list_changes_in_room( self, room_id: str, min_stream_id: int ) -> Collection[Tuple[str, str]]: @@ -1962,8 +2005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): async def add_device_change_to_streams( self, user_id: str, - device_ids: Collection[str], - room_ids: Collection[str], + device_ids: StrCollection, + room_ids: StrCollection, ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. 
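Note that the `WHERE` clauses this commit adds use the half-open interval convention Synapse stream tokens rely on: `from_id` is exclusive, `to_id` inclusive, so consecutive sync polls tile the stream with no gaps and no duplicates. A self-contained illustration (the values are hypothetical):

```python
stream_ids = [3, 4, 5, 6, 7]  # positions of stored device list changes

def poll(from_id, to_id):
    # Mirrors `WHERE stream_id > ? AND stream_id <= ?` above.
    return [s for s in stream_ids if from_id < s <= to_id]

# Adjacent token ranges cover everything exactly once.
assert poll(2, 5) + poll(5, 7) == stream_ids
```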
@@ -2122,8 +2165,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Iterable[str], - room_ids: Collection[str], + device_ids: StrCollection, + room_ids: StrCollection, stream_ids: List[int], context: Dict[str, str], ) -> None: @@ -2161,6 +2204,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) + txn.call_after( + self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids) + ) + async def get_uncoverted_outbound_room_pokes( self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: -- cgit 1.5.1 From b71d2774388c90a68d71dd8d805556c8f62c92a1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 May 2024 13:55:18 +0100 Subject: Reduce work of calculating outbound device pokes (#17211) --- changelog.d/17211.misc | 1 + synapse/handlers/device.py | 7 +++++++ synapse/storage/databases/main/devices.py | 24 ++++++++++++++++++++++++ 3 files changed, 32 insertions(+) create mode 100644 changelog.d/17211.misc (limited to 'synapse/storage') diff --git a/changelog.d/17211.misc b/changelog.d/17211.misc new file mode 100644 index 0000000000..144db03a40 --- /dev/null +++ b/changelog.d/17211.misc @@ -0,0 +1 @@ +Reduce work of calculating outbound device lists updates. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 55842e7c7b..0432d97109 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -906,6 +906,13 @@ class DeviceHandler(DeviceWorkerHandler): context=opentracing_context, ) + await self.store.mark_redundant_device_lists_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + converted_upto_stream_id=stream_id, + ) + # Notify replication that we've updated the device list stream. self.notifier.notify_replication() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f4410b5c02..48384e238c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -2161,6 +2161,30 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): }, ) + async def mark_redundant_device_lists_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + converted_upto_stream_id: int, + ) -> None: + """If we've calculated the outbound pokes for a given room/device list + update, mark any subsequent changes as already converted""" + + sql = """ + UPDATE device_lists_changes_in_room + SET converted_to_destinations = true + WHERE stream_id > ? AND user_id = ? AND device_id = ? + AND room_id = ? AND NOT converted_to_destinations + """ + + def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None: + txn.execute(sql, (converted_upto_stream_id, user_id, device_id, room_id)) + + return await self.db_pool.runInteraction( + "mark_redundant_device_lists_pokes", mark_redundant_device_lists_pokes_txn + ) + def _add_device_outbound_room_poke_txn( self, txn: LoggingTransaction, -- cgit 1.5.1 From 967b6948b0d738bc685d433d44e82631fd2ad232 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2024 12:04:13 +0100 Subject: Change allow_unsafe_locale to also apply on new databases (#17238) We relax this as there are use cases where this is safe, though it is still highly recommended that people avoid using it. 
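For operators, the escape hatch lives in the `database` section of `homeserver.yaml`. A hedged example — the surrounding keys are illustrative, only `allow_unsafe_locale` is what this commit changes:

```yaml
database:
  name: psycopg2
  allow_unsafe_locale: true  # skip the COLLATE/CTYPE safety check
  args:
    user: synapse_user
    database: synapse
    host: localhost
```

As the updated docs below stress, dumping the database and recreating it with a `C` locale is still the recommended fix; the flag merely silences the check.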
--- changelog.d/17238.misc | 1 + docs/postgres.md | 11 +++++------ synapse/storage/engines/postgres.py | 8 +++++++- 3 files changed, 13 insertions(+), 7 deletions(-) create mode 100644 changelog.d/17238.misc (limited to 'synapse/storage') diff --git a/changelog.d/17238.misc b/changelog.d/17238.misc new file mode 100644 index 0000000000..261467e55c --- /dev/null +++ b/changelog.d/17238.misc @@ -0,0 +1 @@ +Change the `allow_unsafe_locale` config option to also apply when setting up new databases. diff --git a/docs/postgres.md b/docs/postgres.md index ae34f7689b..4b2ba38275 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -242,12 +242,11 @@ host all all ::1/128 ident ### Fixing incorrect `COLLATE` or `CTYPE` -Synapse will refuse to set up a new database if it has the wrong values of -`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values -of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the -`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from -underneath the database, or if a different version of the locale is used on any -replicas. +Synapse will refuse to start when using a database with incorrect values of +`COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the +`database` section of the config, is set to true. Using different locales can +cause issues if the locale library is updated from underneath the database, or +if a different version of the locale is used on any replicas. If you have a database with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with the correct locale parameter (as shown above). It is also possible to change the diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index b9168ee074..90641d5a18 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -142,6 +142,10 @@ class PostgresEngine( apply stricter checks on new databases versus existing database. """ + allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) + if allow_unsafe_locale: + return + collation, ctype = self.get_db_locale(txn) errors = [] @@ -155,7 +159,9 @@ class PostgresEngine( if errors: raise IncorrectDatabaseSetup( "Database is incorrectly configured:\n\n%s\n\n" - "See docs/postgres.md for more information." % ("\n".join(errors)) + "See docs/postgres.md for more information. 
You can override this check by" + "setting 'allow_unsafe_locale' to true in the database config.", + "\n".join(errors), ) def convert_param_style(self, sql: str) -> str: -- cgit 1.5.1 From 726006cdf2dfea3bcac9f6e0e912646b1751bdb7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2024 12:57:10 +0100 Subject: Don't invalidate all `get_relations_for_event` on history purge (#17083) This is a tree cache already, so may as well move the room ID to the front and use that --- changelog.d/17083.misc | 1 + synapse/handlers/relations.py | 2 +- synapse/storage/databases/main/cache.py | 18 +++++++++++++--- synapse/storage/databases/main/events.py | 7 ++++++- .../storage/databases/main/events_bg_updates.py | 24 +++++++++++++++------- synapse/storage/databases/main/relations.py | 2 +- 6 files changed, 41 insertions(+), 13 deletions(-) create mode 100644 changelog.d/17083.misc (limited to 'synapse/storage') diff --git a/changelog.d/17083.misc b/changelog.d/17083.misc new file mode 100644 index 0000000000..7c7cebea4e --- /dev/null +++ b/changelog.d/17083.misc @@ -0,0 +1 @@ +Improve DB usage when fetching related events. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index c5cee8860b..de092f8623 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -393,9 +393,9 @@ class RelationsHandler: # Attempt to find another event to use as the latest event. potential_events, _ = await self._main_store.get_relations_for_event( + room_id, event_id, event, - room_id, RelationTypes.THREAD, direction=Direction.FORWARDS, ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index bfd492d95d..c6787faea0 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined] # Caches which might leak edits must be invalidated for the event being # redacted. 
- self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) + self._attempt_to_invalidate_cache( + "get_relations_for_event", + ( + room_id, + redacts, + ), + ) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) @@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) if relates_to: - self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache( + "get_relations_for_event", + ( + room_id, + relates_to, + ), + ) self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) @@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_unread_event_push_actions_by_room_for_user", (room_id,) ) + self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,)) self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) - self._attempt_to_invalidate_cache("get_relations_for_event", None) self._attempt_to_invalidate_cache("get_applicable_edit", None) self._attempt_to_invalidate_cache("get_thread_id", None) self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 990698aa5c..fd7167904d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1923,7 +1923,12 @@ class PersistEventsStore: # Any relation information for the related event must be cleared. self.store._invalidate_cache_and_stream( - txn, self.store.get_relations_for_event, (redacted_relates_to,) + txn, + self.store.get_relations_for_event, + ( + room_id, + redacted_relates_to, + ), ) if rel_type == RelationTypes.REFERENCE: self.store._invalidate_cache_and_stream( diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 6c979f9f2c..64d303e330 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): results = list(txn) # (event_id, parent_id, rel_type) for each relation - relations_to_insert: List[Tuple[str, str, str]] = [] + relations_to_insert: List[Tuple[str, str, str, str]] = [] for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) @@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if not isinstance(parent_id, str): continue - relations_to_insert.append((event_id, parent_id, rel_type)) + room_id = event_json["room_id"] + relations_to_insert.append((room_id, event_id, parent_id, rel_type)) # Insert the missing data, note that we upsert here in case the event # has already been processed. @@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn=txn, table="event_relations", key_names=("event_id",), - key_values=[(r[0],) for r in relations_to_insert], + key_values=[(r[1],) for r in relations_to_insert], value_names=("relates_to_id", "relation_type"), - value_values=[r[1:] for r in relations_to_insert], + value_values=[r[2:] for r in relations_to_insert], ) # Iterate the parent IDs and invalidate caches. 
- cache_tuples = {(r[1],) for r in relations_to_insert} self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined] - txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined] + txn, + self.get_relations_for_event, # type: ignore[attr-defined] + { + ( + r[0], # room_id + r[2], # parent_id + ) + for r in relations_to_insert + }, ) self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined] - txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined] + txn, + self.get_thread_summary, # type: ignore[attr-defined] + {(r[1],) for r in relations_to_insert}, ) if results: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 77f3641525..29a001ff92 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore): @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, + room_id: str, event_id: str, event: EventBase, - room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, -- cgit 1.5.1 From 466f344547fc6bea2c43257dd65286380fbb512d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 May 2024 13:19:10 +0100 Subject: Move towards using `MultiWriterIdGenerator` everywhere (#17226) There is a problem with `StreamIdGenerator` where it can go backwards over restarts when a stream ID is requested but then not inserted into the DB. This is problematic if we want to land #17215, and is generally a potential cause for all sorts of nastiness. Instead of trying to fix `StreamIdGenerator`, we may as well move to `MultiWriterIdGenerator` that does not suffer from this problem (the latest positions are stored in `stream_positions` table). This involves adding SQLite support to the class. This only changes id generators that were already using `MultiWriterIdGenerator` under postgres, a separate PR will move the rest of the uses of `StreamIdGenerator` over. --- changelog.d/17226.misc | 1 + synapse/storage/database.py | 21 +- synapse/storage/databases/main/account_data.py | 47 +--- synapse/storage/databases/main/deviceinbox.py | 46 ++-- synapse/storage/databases/main/events_worker.py | 101 +++---- synapse/storage/databases/main/presence.py | 27 +- synapse/storage/databases/main/receipts.py | 43 +-- synapse/storage/databases/main/room.py | 34 +-- synapse/storage/util/id_generators.py | 49 +++- tests/storage/test_id_generators.py | 351 +++++++++++++----------- 10 files changed, 341 insertions(+), 379 deletions(-) create mode 100644 changelog.d/17226.misc (limited to 'synapse/storage') diff --git a/changelog.d/17226.misc b/changelog.d/17226.misc new file mode 100644 index 0000000000..7c023a5759 --- /dev/null +++ b/changelog.d/17226.misc @@ -0,0 +1 @@ +Move towards using `MultiWriterIdGenerator` everywhere. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d9c85e411e..569f618193 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2461,7 +2461,11 @@ class DatabasePool: def make_in_list_sql_clause( - database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any] + database_engine: BaseDatabaseEngine, + column: str, + iterable: Collection[Any], + *, + negative: bool = False, ) -> Tuple[str, list]: """Returns an SQL clause that checks the given column is in the iterable. 
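The new `negative` flag generalises this helper to exclusion filters, and its expansion has to differ per engine because SQLite lacks the `ANY`/`ALL` array comparisons. A sketch of what the two branches produce, assuming `engine` is the store's `BaseDatabaseEngine`:

```python
from synapse.storage.database import make_in_list_sql_clause

writers = ["worker1", "worker2"]
clause, args = make_in_list_sql_clause(engine, "instance_name", writers, negative=True)

# Postgres: clause == "instance_name != ALL(?)"     args == [["worker1", "worker2"]]
# SQLite:   clause == "instance_name NOT IN (?,?)"  args == ["worker1", "worker2"]
sql = f"DELETE FROM stream_positions WHERE stream_name = ? AND {clause}"
# ...executed as txn.execute(sql, [stream_name] + args), exactly as the
# id_generators.py hunk further down does.
```

This is what lets `MultiWriterIdGenerator` drop its Postgres-only `!= ALL(?)` query and run under SQLite too.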
@@ -2474,6 +2478,7 @@ def make_in_list_sql_clause( database_engine column: Name of the column iterable: The values to check the column against. + negative: Whether we should check for inequality, i.e. `NOT IN` Returns: A tuple of SQL query and the args @@ -2482,9 +2487,19 @@ def make_in_list_sql_clause( if database_engine.supports_using_any_list: # This should hopefully be faster, but also makes postgres query # stats easier to understand. - return "%s = ANY(?)" % (column,), [list(iterable)] + if not negative: + clause = f"{column} = ANY(?)" + else: + clause = f"{column} != ALL(?)" + + return clause, [list(iterable)] else: - return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) + params = ",".join("?" for _ in iterable) + if not negative: + clause = f"{column} IN ({params})" + else: + clause = f"{column} NOT IN ({params})" + return clause, list(iterable) # These overloads ensure that `columns` and `iterable` values have the same length. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 563450a97e..9611a84932 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -43,11 +43,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore -from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, MultiWriterIdGenerator, - StreamIdGenerator, ) from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder @@ -75,37 +73,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._account_data_id_gen: AbstractStreamIdGenerator - if isinstance(database.engine, PostgresEngine): - self._account_data_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="account_data", - instance_name=self._instance_name, - tables=[ - ("room_account_data", "instance_name", "stream_id"), - ("room_tags_revisions", "instance_name", "stream_id"), - ("account_data", "instance_name", "stream_id"), - ], - sequence_name="account_data_sequence", - writers=hs.config.worker.writers.account_data, - ) - else: - # Multiple writers are not supported for SQLite. - # - # We shouldn't be running in worker mode with SQLite, but its useful - # to support it for unit tests. 
- self._account_data_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "room_account_data", - "stream_id", - extra_tables=[ - ("account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - is_writer=self._instance_name in hs.config.worker.writers.account_data, - ) + self._account_data_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="account_data", + instance_name=self._instance_name, + tables=[ + ("room_account_data", "instance_name", "stream_id"), + ("room_tags_revisions", "instance_name", "stream_id"), + ("account_data", "instance_name", "stream_id"), + ], + sequence_name="account_data_sequence", + writers=hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index e17821ff6e..25023b5e7a 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -50,11 +50,9 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, MultiWriterIdGenerator, - StreamIdGenerator, ) from synapse.types import JsonDict from synapse.util import json_encoder @@ -89,35 +87,23 @@ class DeviceInboxWorkerStore(SQLBaseStore): expiry_ms=30 * 60 * 1000, ) - if isinstance(database.engine, PostgresEngine): - self._can_write_to_device = ( - self._instance_name in hs.config.worker.writers.to_device - ) + self._can_write_to_device = ( + self._instance_name in hs.config.worker.writers.to_device + ) - self._to_device_msg_id_gen: AbstractStreamIdGenerator = ( - MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="to_device", - instance_name=self._instance_name, - tables=[ - ("device_inbox", "instance_name", "stream_id"), - ("device_federation_outbox", "instance_name", "stream_id"), - ], - sequence_name="device_inbox_sequence", - writers=hs.config.worker.writers.to_device, - ) - ) - else: - self._can_write_to_device = True - self._to_device_msg_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "device_inbox", - "stream_id", - extra_tables=[("device_federation_outbox", "stream_id")], - ) + self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="to_device", + instance_name=self._instance_name, + tables=[ + ("device_inbox", "instance_name", "stream_id"), + ("device_federation_outbox", "instance_name", "stream_id"), + ], + sequence_name="device_inbox_sequence", + writers=hs.config.worker.writers.to_device, + ) max_device_inbox_id = self._to_device_msg_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e39d4b9624..426df2a9d2 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -75,12 +75,10 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from 
synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, MultiWriterIdGenerator, - StreamIdGenerator, ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id @@ -195,51 +193,28 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen: AbstractStreamIdGenerator self._backfill_id_gen: AbstractStreamIdGenerator - if isinstance(database.engine, PostgresEngine): - # If we're using Postgres than we can use `MultiWriterIdGenerator` - # regardless of whether this process writes to the streams or not. - self._stream_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="events", - instance_name=hs.get_instance_name(), - tables=[("events", "instance_name", "stream_ordering")], - sequence_name="events_stream_seq", - writers=hs.config.worker.writers.events, - ) - self._backfill_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="backfill", - instance_name=hs.get_instance_name(), - tables=[("events", "instance_name", "stream_ordering")], - sequence_name="events_backfill_stream_seq", - positive=False, - writers=hs.config.worker.writers.events, - ) - else: - # Multiple writers are not supported for SQLite. - # - # We shouldn't be running in worker mode with SQLite, but its useful - # to support it for unit tests. - self._stream_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "events", - "stream_ordering", - is_writer=hs.get_instance_name() in hs.config.worker.writers.events, - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - is_writer=hs.get_instance_name() in hs.config.worker.writers.events, - ) + + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="events", + instance_name=hs.get_instance_name(), + tables=[("events", "instance_name", "stream_ordering")], + sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, + ) + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="backfill", + instance_name=hs.get_instance_name(), + tables=[("events", "instance_name", "stream_ordering")], + sequence_name="events_backfill_stream_seq", + positive=False, + writers=hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( @@ -309,27 +284,17 @@ class EventsWorkerStore(SQLBaseStore): self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator - if isinstance(database.engine, PostgresEngine): - self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="un_partial_stated_event_stream", - instance_name=hs.get_instance_name(), - tables=[ - ("un_partial_stated_event_stream", "instance_name", "stream_id") - ], - sequence_name="un_partial_stated_event_stream_sequence", - # TODO(faster_joins, multiple writers) Support multiple writers. 
- writers=["master"], - ) - else: - self._un_partial_stated_events_stream_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "un_partial_stated_event_stream", - "stream_id", - ) + self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="un_partial_stated_event_stream", + instance_name=hs.get_instance_name(), + tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")], + sequence_name="un_partial_stated_event_stream_sequence", + # TODO(faster_joins, multiple writers) Support multiple writers. + writers=["master"], + ) def get_un_partial_stated_events_token(self, instance_name: str) -> int: return ( diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 567c2d30bd..923e764491 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -40,13 +40,11 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import IsolationLevel from synapse.storage.types import Connection from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, MultiWriterIdGenerator, - StreamIdGenerator, ) from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -91,21 +89,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) self._instance_name in hs.config.worker.writers.presence ) - if isinstance(database.engine, PostgresEngine): - self._presence_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="presence_stream", - instance_name=self._instance_name, - tables=[("presence_stream", "instance_name", "stream_id")], - sequence_name="presence_stream_sequence", - writers=hs.config.worker.writers.presence, - ) - else: - self._presence_id_gen = StreamIdGenerator( - db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id" - ) + self._presence_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="presence_stream", + instance_name=self._instance_name, + tables=[("presence_stream", "instance_name", "stream_id")], + sequence_name="presence_stream_sequence", + writers=hs.config.worker.writers.presence, + ) self.hs = hs self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 13387a3839..8432560a89 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -44,12 +44,10 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, MultiWriterIdGenerator, - StreamIdGenerator, ) from synapse.types import ( JsonDict, @@ -80,35 +78,20 @@ class ReceiptsWorkerStore(SQLBaseStore): # class below that is used on the main process. 
self._receipts_id_gen: AbstractStreamIdGenerator - if isinstance(database.engine, PostgresEngine): - self._can_write_to_receipts = ( - self._instance_name in hs.config.worker.writers.receipts - ) + self._can_write_to_receipts = ( + self._instance_name in hs.config.worker.writers.receipts + ) - self._receipts_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="receipts", - instance_name=self._instance_name, - tables=[("receipts_linearized", "instance_name", "stream_id")], - sequence_name="receipts_sequence", - writers=hs.config.worker.writers.receipts, - ) - else: - self._can_write_to_receipts = True - - # Multiple writers are not supported for SQLite. - # - # We shouldn't be running in worker mode with SQLite, but its useful - # to support it for unit tests. - self._receipts_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "receipts_linearized", - "stream_id", - is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, - ) + self._receipts_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="receipts", + instance_name=self._instance_name, + tables=[("receipts_linearized", "instance_name", "stream_id")], + sequence_name="receipts_sequence", + writers=hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 8205109548..616c941687 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -58,13 +58,11 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, IdGenerator, MultiWriterIdGenerator, - StreamIdGenerator, ) from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID from synapse.util import json_encoder @@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator - if isinstance(database.engine, PostgresEngine): - self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - notifier=hs.get_replication_notifier(), - stream_name="un_partial_stated_room_stream", - instance_name=self._instance_name, - tables=[ - ("un_partial_stated_room_stream", "instance_name", "stream_id") - ], - sequence_name="un_partial_stated_room_stream_sequence", - # TODO(faster_joins, multiple writers) Support multiple writers. - writers=["master"], - ) - else: - self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "un_partial_stated_room_stream", - "stream_id", - ) + self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="un_partial_stated_room_stream", + instance_name=self._instance_name, + tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")], + sequence_name="un_partial_stated_room_stream_sequence", + # TODO(faster_joins, multiple writers) Support multiple writers. 
+ writers=["master"], + ) def process_replication_position( self, stream_name: str, instance_name: str, token: int diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index fadc75cc80..0cf5851ad7 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -53,9 +53,11 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor -from synapse.storage.util.sequence import PostgresSequenceGenerator +from synapse.storage.util.sequence import build_sequence_generator if TYPE_CHECKING: from synapse.notifier import ReplicationNotifier @@ -432,7 +434,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # no active writes in progress. self._max_position_of_local_instance = self._max_seen_allocated_stream_id - self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, tables) + + self._sequence_gen = build_sequence_generator( + db_conn=db_conn, + database_engine=db.engine, + get_first_callback=lambda _: self._persisted_upto_position, + sequence_name=sequence_name, + # We only need to set the below if we want it to call + # `check_consistency`, but we do that ourselves below so we can + # leave them blank. + table=None, + id_column=None, + stream_name=None, + positive=positive, + ) # We check that the table and sequence haven't diverged. for table, _, id_column in tables: @@ -444,9 +461,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): positive=positive, ) - # This goes and fills out the above state from the database. - self._load_current_ids(db_conn, tables) - self._max_seen_allocated_stream_id = max( self._current_positions.values(), default=1 ) @@ -480,13 +494,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # important if we add back a writer after a long time; we want to # consider that a "new" writer, rather than using the old stale # entry here. - sql = """ + clause, args = make_in_list_sql_clause( + self._db.engine, "instance_name", self._writers, negative=True + ) + + sql = f""" DELETE FROM stream_positions WHERE stream_name = ? - AND instance_name != ALL(?) + AND {clause} """ - cur.execute(sql, (self._stream_name, self._writers)) + cur.execute(sql, [self._stream_name] + args) sql = """ SELECT instance_name, stream_id FROM stream_positions @@ -508,12 +526,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # We add a GREATEST here to ensure that the result is always # positive. (This can be a problem for e.g. backfill streams where # the server has never backfilled). + greatest_func = ( + "GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX" + ) max_stream_id = 1 for table, _, id_column in tables: sql = """ - SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1) FROM %(table)s """ % { + "greatest_func": greatest_func, "id": id_column, "table": table, "agg": "MAX" if self._positive else "-MIN", @@ -913,6 +935,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): # We upsert the value, ensuring on conflict that we always increase the # value (or decrease if stream goes backwards). 
+ if isinstance(self._db.engine, PostgresEngine): + agg = "GREATEST" if self._positive else "LEAST" + else: + agg = "MAX" if self._positive else "MIN" + sql = """ INSERT INTO stream_positions (stream_name, instance_name, stream_id) VALUES (?, ?, ?) @@ -920,10 +947,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): DO UPDATE SET stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) """ % { - "agg": "GREATEST" if self._positive else "LEAST", + "agg": agg, } - pos = (self.get_current_token_for_writer(self._instance_name),) + pos = self.get_current_token_for_writer(self._instance_name) txn.execute(sql, (self._stream_name, self._instance_name, pos)) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 409d856ab9..fad9511cea 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -31,6 +31,11 @@ from synapse.storage.database import ( from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.types import Cursor from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.sequence import ( + LocalSequenceGenerator, + PostgresSequenceGenerator, + SequenceGenerator, +) from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -175,18 +180,22 @@ class StreamIdGeneratorTestCase(HomeserverTestCase): self.get_success(test_gen_next()) -class MultiWriterIdGeneratorTestCase(HomeserverTestCase): - if not USE_POSTGRES_FOR_TESTS: - skip = "Requires Postgres" - +class MultiWriterIdGeneratorBase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + if USE_POSTGRES_FOR_TESTS: + self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq") + else: + self.seq_gen = LocalSequenceGenerator(lambda _: 0) + def _setup_db(self, txn: LoggingTransaction) -> None: - txn.execute("CREATE SEQUENCE foobar_seq") + if USE_POSTGRES_FOR_TESTS: + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( """ CREATE TABLE foobar ( @@ -221,44 +230,27 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def _insert(txn: LoggingTransaction) -> None: for _ in range(number): + next_val = self.seq_gen.get_next_id_txn(txn) txn.execute( - "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", - (instance_name,), + "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)", + ( + next_val, + instance_name, + ), ) + txn.execute( """ - INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) - ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? """, - (instance_name,), + (instance_name, next_val, next_val), ) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: - """Insert one row as the given instance with given stream_id, updating - the postgres sequence position to match. 
- """ - - def _insert(txn: LoggingTransaction) -> None: - txn.execute( - "INSERT INTO foobar VALUES (?, ?)", - ( - stream_id, - instance_name, - ), - ) - txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) - txn.execute( - """ - INSERT INTO stream_positions VALUES ('test_stream', ?, ?) - ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? - """, - (instance_name, stream_id, stream_id), - ) - - self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) +class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): def test_empty(self) -> None: """Test an ID generator against an empty database gives sensible current positions. @@ -347,137 +339,106 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) - def test_multi_instance(self) -> None: - """Test that reads and writes from multiple processes are handled - correctly. - """ - self._insert_rows("first", 3) - self._insert_rows("second", 4) + def test_get_next_txn(self) -> None: + """Test that the `get_next_txn` function works correctly.""" - first_id_gen = self._create_id_generator("first", writers=["first", "second"]) - second_id_gen = self._create_id_generator("second", writers=["first", "second"]) + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) - self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) + id_gen = self._create_id_generator() - self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async() -> None: - async with first_id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 8) - - self.assertEqual( - first_id_gen.get_positions(), {"first": 3, "second": 7} - ) - self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - - self.get_success(_get_next_async()) - - self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) - - # However the ID gen on the second instance won't have seen the update - self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) - - # ... 
but calling `get_next` on the second instance should give a unique - # stream ID + def _get_next_txn(txn: LoggingTransaction) -> None: + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) - async def _get_next_async2() -> None: - async with second_id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 9) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - self.assertEqual( - second_id_gen.get_positions(), {"first": 3, "second": 7} - ) + self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) - self.get_success(_get_next_async2()) + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + def test_restart_during_out_of_order_persistence(self) -> None: + """Test that restarting a process while another process is writing out + of order updates are handled correctly. + """ - # If the second ID gen gets told about the first, it correctly updates - second_id_gen.advance("first", 8) - self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) - def test_multi_instance_empty_row(self) -> None: - """Test that reads and writes from multiple processes are handled - correctly, when one of the writers starts without any rows. - """ - # Insert some rows for two out of three of the ID gens. - self._insert_rows("first", 3) - self._insert_rows("second", 4) + id_gen = self._create_id_generator() - first_id_gen = self._create_id_generator( - "first", writers=["first", "second", "third"] - ) - second_id_gen = self._create_id_generator( - "second", writers=["first", "second", "third"] - ) - third_id_gen = self._create_id_generator( - "third", writers=["first", "second", "third"] - ) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - self.assertEqual( - first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} - ) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7) + # Persist two rows at once + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() - self.assertEqual( - second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} - ) - self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7) + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) - # Try allocating a new ID gen and check that we only see position - # advanced after we leave the context manager. 
+ self.assertEqual(s1, 8) + self.assertEqual(s2, 9) - async def _get_next_async() -> None: - async with third_id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 8) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - self.assertEqual( - third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} - ) - self.assertEqual(third_id_gen.get_persisted_upto_position(), 7) + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) - self.get_success(_get_next_async()) + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") - self.assertEqual( - third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8} - ) + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) - def test_get_next_txn(self) -> None: - """Test that the `get_next_txn` function works correctly.""" + # Now if we persist the first row then both instances should jump ahead + # correctly. + self.get_success(ctx1.__aexit__(None, None, None)) - # Prefill table with 7 rows written by 'master' - self._insert_rows("master", 7) + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) - id_gen = self._create_id_generator() - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) +class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" - # Try allocating a new ID gen and check that we only see position - # advanced after we leave the context manager. + def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: + """Insert one row as the given instance with given stream_id, updating + the postgres sequence position to match. + """ - def _get_next_txn(txn: LoggingTransaction) -> None: - stream_id = id_gen.get_next_txn(txn) - self.assertEqual(stream_id, 8) + def _insert(txn: LoggingTransaction) -> None: + txn.execute( + "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)", + ( + stream_id, + instance_name, + ), + ) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) - self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, stream_id, stream_id), + ) - self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) def test_get_persisted_upto_position(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to @@ -548,49 +509,111 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). 
- def test_restart_during_out_of_order_persistence(self) -> None: - """Test that restarting a process while another process is writing out - of order updates are handled correctly. + def test_multi_instance(self) -> None: + """Test that reads and writes from multiple processes are handled + correctly. """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) - # Prefill table with 7 rows written by 'master' - self._insert_rows("master", 7) + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - id_gen = self._create_id_generator() + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - # Persist two rows at once - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. - s1 = self.get_success(ctx1.__aenter__()) - s2 = self.get_success(ctx2.__aenter__()) + async def _get_next_async() -> None: + async with first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) - self.assertEqual(s1, 8) - self.assertEqual(s2, 9) + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + self.get_success(_get_next_async()) - # We finish persisting the second row before restart - self.get_success(ctx2.__aexit__(None, None, None)) + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) - # We simulate a restart of another worker by just creating a new ID gen. - id_gen_worker = self._create_id_generator("worker") + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) - # Restarted worker should not see the second persisted row - self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) - self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + # ... but calling `get_next` on the second instance should give a unique + # stream ID - # Now if we persist the first row then both instances should jump ahead - # correctly. 
- self.get_success(ctx1.__aexit__(None, None, None)) + async def _get_next_async2() -> None: + async with second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) - self.assertEqual(id_gen.get_positions(), {"master": 9}) - id_gen_worker.advance("master", 9) - self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async2()) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + + def test_multi_instance_empty_row(self) -> None: + """Test that reads and writes from multiple processes are handled + correctly, when one of the writers starts without any rows. + """ + # Insert some rows for two out of three of the ID gens. + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator( + "first", writers=["first", "second", "third"] + ) + second_id_gen = self._create_id_generator( + "second", writers=["first", "second", "third"] + ) + third_id_gen = self._create_id_generator( + "third", writers=["first", "second", "third"] + ) + + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} + ) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7) + + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} + ) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async() -> None: + async with third_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual( + third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} + ) + self.assertEqual(third_id_gen.get_persisted_upto_position(), 7) + + self.get_success(_get_next_async()) + + self.assertEqual( + third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8} + ) def test_writer_config_change(self) -> None: """Test that changing the writer config correctly works.""" -- cgit 1.5.1
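A closing note on why the last commit matters: `StreamIdGenerator` recovers its position on startup from `MAX(stream_ordering)` over the stream's tables, so an ID that was handed out but never inserted (say, the transaction failed) is forgotten, and the same position can be allocated again after a restart. `MultiWriterIdGenerator` persists positions in `stream_positions` instead. A toy model of the hazard — not Synapse code:

```python
# Toy model: deriving the restart position from MAX(table) can go backwards.
committed = [1, 2, 3]   # stream ids whose rows actually reached the table
allocated = 4           # handed out, but the write never happened

restart_pos = max(committed)      # StreamIdGenerator-style recovery: 3
assert restart_pos < allocated    # position 4 may be handed out a second time

stream_positions = {"master": allocated}  # MultiWriterIdGenerator-style recovery
assert stream_positions["master"] >= allocated  # never rewinds past 4
```

This is why the PR moves these streams over wholesale — with the remaining `StreamIdGenerator` users to follow in a separate PR, as the commit message notes — rather than trying to patch the old class.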