From 1e453053cb12ff084fdcdc2f75c08ced274dff21 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 31 May 2022 13:17:50 +0100 Subject: Rename storage classes (#12913) --- changelog.d/12913.misc | 1 + synapse/events/snapshot.py | 10 +- synapse/federation/federation_server.py | 1 - synapse/handlers/admin.py | 12 +- synapse/handlers/device.py | 4 +- synapse/handlers/events.py | 4 +- synapse/handlers/federation.py | 30 +- synapse/handlers/federation_event.py | 27 +- synapse/handlers/initial_sync.py | 17 +- synapse/handlers/message.py | 30 +- synapse/handlers/pagination.py | 17 +- synapse/handlers/relations.py | 7 +- synapse/handlers/room.py | 11 +- synapse/handlers/room_batch.py | 4 +- synapse/handlers/search.py | 14 +- synapse/handlers/sync.py | 26 +- synapse/notifier.py | 4 +- synapse/push/httppusher.py | 6 +- synapse/push/mailer.py | 14 +- synapse/push/push_tools.py | 4 +- synapse/replication/http/federation.py | 4 +- synapse/replication/http/send_event.py | 6 +- synapse/server.py | 7 +- synapse/state/__init__.py | 51 +- synapse/storage/__init__.py | 35 +- synapse/storage/controllers/__init__.py | 46 + synapse/storage/controllers/persist_events.py | 1124 ++++++++++++++++++++++ synapse/storage/controllers/purge_events.py | 112 +++ synapse/storage/controllers/state.py | 351 +++++++ synapse/storage/persist_events.py | 1124 ---------------------- synapse/storage/purge_events.py | 112 --- synapse/storage/state.py | 320 ------ synapse/visibility.py | 10 +- tests/events/test_snapshot.py | 4 +- tests/handlers/test_federation.py | 6 +- tests/handlers/test_federation_event.py | 9 +- tests/handlers/test_message.py | 14 +- tests/handlers/test_user_directory.py | 2 +- tests/replication/slave/storage/_base.py | 2 +- tests/replication/slave/storage/test_events.py | 10 +- tests/replication/slave/storage/test_receipts.py | 12 +- tests/rest/admin/test_user.py | 4 +- tests/rest/client/test_retention.py | 4 +- tests/rest/client/test_room_batch.py | 6 +- tests/storage/test_event_chain.py | 3 +- tests/storage/test_events.py | 12 +- tests/storage/test_purge.py | 14 +- tests/storage/test_redaction.py | 14 +- tests/storage/test_room.py | 4 +- tests/storage/test_room_search.py | 4 +- tests/storage/test_state.py | 2 +- tests/test_state.py | 6 +- tests/test_utils/event_injection.py | 2 +- tests/test_visibility.py | 46 +- tests/utils.py | 2 +- 55 files changed, 1942 insertions(+), 1785 deletions(-) create mode 100644 changelog.d/12913.misc create mode 100644 synapse/storage/controllers/__init__.py create mode 100644 synapse/storage/controllers/persist_events.py create mode 100644 synapse/storage/controllers/purge_events.py create mode 100644 synapse/storage/controllers/state.py delete mode 100644 synapse/storage/persist_events.py delete mode 100644 synapse/storage/purge_events.py diff --git a/changelog.d/12913.misc b/changelog.d/12913.misc new file mode 100644 index 0000000000..a2bc940557 --- /dev/null +++ b/changelog.d/12913.misc @@ -0,0 +1 @@ +Rename storage classes. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 7a91544119..b700cbbfa1 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -22,7 +22,7 @@ from synapse.events import EventBase from synapse.types import JsonDict, StateMap if TYPE_CHECKING: - from synapse.storage import Storage + from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore from synapse.storage.state import StateFilter @@ -84,7 +84,7 @@ class EventContext: incomplete state. """ - _storage: "Storage" + _storage: "StorageControllers" rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None @@ -97,7 +97,7 @@ class EventContext: @staticmethod def with_state( - storage: "Storage", + storage: "StorageControllers", state_group: Optional[int], state_group_before_event: Optional[int], state_delta_due_to_event: Optional[StateMap[str]], @@ -117,7 +117,7 @@ class EventContext: @staticmethod def for_outlier( - storage: "Storage", + storage: "StorageControllers", ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" return EventContext(storage=storage) @@ -147,7 +147,7 @@ class EventContext: } @staticmethod - def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": + def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext": """Converts a dict that was produced by `serialize` back into a EventContext. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5b227b85fd..3ecede22d9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -109,7 +109,6 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() - self.storage = hs.get_storage() self._spam_checker = hs.get_spam_checker() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 50e34743b7..d4fe7df533 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,8 +30,8 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -197,7 +197,9 @@ class AdminHandler: from_key = events[-1].internal_metadata.after - events = await filter_events_for_client(self.storage, user_id, events) + events = await filter_events_for_client( + self._storage_controllers, user_id, events + ) writer.write_events(room_id, events) @@ -233,7 +235,9 @@ class AdminHandler: for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = await self.state_storage.get_state_for_event(event_id) + state = await self._state_storage_controller.get_state_for_event( + event_id + ) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2a56473dc6..72faf2ee38 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -71,7 +71,7 @@ class DeviceWorkerHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() - self.state_storage = hs.get_storage().state + self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname @@ -204,7 +204,7 @@ class DeviceWorkerHandler: continue # mapping from event_id -> state_dict - prev_state_ids = await self.state_storage.get_state_ids_for_events( + prev_state_ids = await self._state_storage.get_state_ids_for_events( event_ids ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index cb7e0ca7a8..ac13340d3a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -139,7 +139,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() async def get_event( self, @@ -177,7 +177,7 @@ class EventHandler: is_peeking = user.to_string() not in users filtered = await filter_events_for_client( - self.storage, user.to_string(), [event], is_peeking=is_peeking + self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c8233270d7..80ee7e7b4e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,8 +125,8 @@ class FederationHandler: self.hs = hs self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -324,7 +324,7 @@ class FederationHandler: # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = await filter_events_for_server( - self.storage, + self._storage_controllers, self.server_name, events_to_check, redact=False, @@ -660,7 +660,7 @@ class FederationHandler: # in the invitee's sync stream. It is stripped out for all other local users. event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -849,7 +849,7 @@ class FederationHandler: ) ) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -878,7 +878,7 @@ class FederationHandler: await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1027,7 +1027,7 @@ class FederationHandler: if event.internal_metadata.outlier: raise NotFoundError("State not known at event %s" % (event_id,)) - state_groups = await self.state_storage.get_state_groups_ids( + state_groups = await self._state_storage_controller.get_state_groups_ids( room_id, [event_id] ) @@ -1078,7 +1078,9 @@ class FederationHandler: ], ) - events = await filter_events_for_server(self.storage, origin, events) + events = await filter_events_for_server( + self._storage_controllers, origin, events + ) return events @@ -1109,7 +1111,9 @@ class FederationHandler: if not in_room: raise AuthError(403, "Host not in room.") - events = await filter_events_for_server(self.storage, origin, [event]) + events = await filter_events_for_server( + self._storage_controllers, origin, [event] + ) event = events[0] return event else: @@ -1138,7 +1142,7 @@ class FederationHandler: ) missing_events = await filter_events_for_server( - self.storage, origin, missing_events + self._storage_controllers, origin, missing_events ) return missing_events @@ -1480,9 +1484,11 @@ class FederationHandler: # clear the lazy-loading flag. logger.info("Updating current state for %s", room_id) assert ( - self.storage.persistence is not None + self._storage_controllers.persistence is not None ), "TODO(faster_joins): support for workers" - await self.storage.persistence.update_current_state(room_id) + await self._storage_controllers.persistence.update_current_state( + room_id + ) logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index a1361af272..b908674529 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -98,8 +98,8 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main - self._storage = hs.get_storage() - self._state_storage = self._storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._state_handler = hs.get_state_handler() self._event_creation_handler = hs.get_event_creation_handler() @@ -535,7 +535,9 @@ class FederationEventHandler: ) return await self._store.update_state_for_partial_state_event(event, context) - self._state_storage.notify_event_un_partial_stated(event.event_id) + self._state_storage_controller.notify_event_un_partial_stated( + event.event_id + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] @@ -835,7 +837,9 @@ class FederationEventHandler: try: # Get the state of the events we know about - ours = await self._state_storage.get_state_groups_ids(room_id, seen) + ours = await self._state_storage_controller.get_state_groups_ids( + room_id, seen + ) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -1436,7 +1440,7 @@ class FederationEventHandler: # we're not bothering about room state, so flag the event as an outlier. event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage) + context = EventContext.for_outlier(self._storage_controllers) try: validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth) @@ -1613,7 +1617,7 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_storage.get_state_groups_ids( + state_sets_d = await self._state_storage_controller.get_state_groups_ids( event.room_id, extrem_ids ) state_sets: List[StateMap[str]] = list(state_sets_d.values()) @@ -1885,7 +1889,7 @@ class FederationEventHandler: # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = await self._state_storage.store_state_group( + state_group = await self._state_storage_controller.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -1894,7 +1898,7 @@ class FederationEventHandler: ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group, state_group_before_event=context.state_group_before_event, state_delta_due_to_event=state_updates, @@ -1984,11 +1988,14 @@ class FederationEventHandler: ) return result["max_stream_id"] else: - assert self._storage.persistence + assert self._storage_controllers.persistence # Note that this returns the events that were persisted, which may not be # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self._storage.persistence.persist_events( + ( + events, + max_stream_token, + ) = await self._storage_controllers.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index fbdbeeedfd..d2b489e816 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -67,8 +67,8 @@ class InitialSyncHandler: ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def snapshot_all_rooms( self, @@ -198,7 +198,8 @@ class InitialSyncHandler: event.stream_ordering, ) deferred_room_state = run_in_background( - self.state_storage.get_state_for_events, [event.event_id] + self._state_storage_controller.get_state_for_events, + [event.event_id], ).addCallback( lambda states: cast(StateMap[EventBase], states[event.event_id]) ) @@ -218,7 +219,7 @@ class InitialSyncHandler: ).addErrback(unwrapFirstError) messages = await filter_events_for_client( - self.storage, user_id, messages + self._storage_controllers, user_id, messages ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) @@ -355,7 +356,9 @@ class InitialSyncHandler: member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_storage.get_state_for_event(member_event_id) + room_state = await self._state_storage_controller.get_state_for_event( + member_event_id + ) limit = pagin_config.limit if pagin_config else None if limit is None: @@ -369,7 +372,7 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) @@ -474,7 +477,7 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 38b71a2c96..f377769071 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -84,8 +84,8 @@ class MessageHandler: self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages @@ -132,7 +132,7 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state = await self.state_storage.get_state_for_events( + room_state = await self._state_storage_controller.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -193,7 +193,7 @@ class MessageHandler: # check whether the user is in the room at that time to determine # whether they should be treated as peeking. - state_map = await self.state_storage.get_state_for_event( + state_map = await self._state_storage_controller.get_state_for_event( last_event.event_id, StateFilter.from_types([(EventTypes.Member, user_id)]), ) @@ -206,7 +206,7 @@ class MessageHandler: is_peeking = not joined visible_events = await filter_events_for_client( - self.storage, + self._storage_controllers, user_id, [last_event], filter_send_to_client=False, @@ -214,8 +214,10 @@ class MessageHandler: ) if visible_events: - room_state_events = await self.state_storage.get_state_for_events( - [last_event.event_id], state_filter=state_filter + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [last_event.event_id], state_filter=state_filter + ) ) room_state: Mapping[Any, EventBase] = room_state_events[ last_event.event_id @@ -244,8 +246,10 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state_events = await self.state_storage.get_state_for_events( - [membership_event_id], state_filter=state_filter + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [membership_event_id], state_filter=state_filter + ) ) room_state = room_state_events[membership_event_id] @@ -402,7 +406,7 @@ class EventCreationHandler: self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -1032,7 +1036,7 @@ class EventCreationHandler: # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids @@ -1445,7 +1449,7 @@ class EventCreationHandler: """ extra_users = extra_users or [] - assert self.storage.persistence is not None + assert self._storage_controllers.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id ) @@ -1679,7 +1683,7 @@ class EventCreationHandler: event, event_pos, max_stream_token, - ) = await self.storage.persistence.persist_event( + ) = await self._storage_controllers.persistence.persist_event( event, context=context, backfilled=backfilled ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 35afe6b855..6262a35822 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -129,8 +129,8 @@ class PaginationHandler: self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() @@ -352,7 +352,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: async with self.pagination_lock.write(room_id): - await self.storage.purge_events.purge_history( + await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -414,7 +414,7 @@ class PaginationHandler: if joined: raise SynapseError(400, "Users are still joined to this room") - await self.storage.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) async def get_messages( self, @@ -529,7 +529,10 @@ class PaginationHandler: events = await event_filter.filter(events) events = await filter_events_for_client( - self.storage, user_id, events, is_peeking=(member_event_id is None) + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) # if after the filter applied there are no more events @@ -550,7 +553,7 @@ class PaginationHandler: (EventTypes.Member, event.sender) for event in events ) - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) @@ -664,7 +667,7 @@ class PaginationHandler: 400, "Users are still joined to this room" ) - await self.storage.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) logger.info("complete") self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ab7e54857d..9a1cc11bb3 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -69,7 +69,7 @@ class BundledAggregations: class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() @@ -143,7 +143,10 @@ class RelationsHandler: ) events = await filter_events_for_client( - self._storage, user_id, events, is_peeking=(member_event_id is None) + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) now = self._clock.time_msec() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e2775b34f1..5c91d33f58 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1192,8 +1192,8 @@ class RoomContextHandler: self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._relations_handler = hs.get_relations_handler() async def get_event_context( @@ -1236,7 +1236,10 @@ class RoomContextHandler: if use_admin_priviledge: return events return await filter_events_for_client( - self.storage, user.to_string(), events, is_peeking=is_peeking + self._storage_controllers, + user.to_string(), + events, + is_peeking=is_peeking, ) event = await self.store.get_event( @@ -1293,7 +1296,7 @@ class RoomContextHandler: # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = await self.state_storage.get_state_for_events( + state = await self._state_storage_controller.get_state_for_events( [last_event_id], state_filter=state_filter ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 7ce32f2e9c..1414e575d6 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -17,7 +17,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -141,7 +141,7 @@ class RoomBatchHandler: ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id assert most_recent_event_id is not None - prev_state_map = await self.state_storage.get_state_ids_for_event( + prev_state_map = await self._state_storage_controller.get_state_ids_for_event( most_recent_event_id ) # List of state event ID's diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index e02c915248..659f99f7e2 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -55,8 +55,8 @@ class SearchHandler: self.hs = hs self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.auth = hs.get_auth() async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: @@ -460,7 +460,7 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -559,7 +559,7 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) room_events.extend(events) @@ -644,11 +644,11 @@ class SearchHandler: ) events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before + self._storage_controllers, user.to_string(), res.events_before ) events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after + self._storage_controllers, user.to_string(), res.events_after ) context: JsonDict = { @@ -677,7 +677,7 @@ class SearchHandler: [(EventTypes.Member, sender) for sender in senders] ) - state = await self.state_storage.get_state_for_event( + state = await self._state_storage_controller.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c5c538e0c3..b5859dcb28 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -238,8 +238,8 @@ class SyncHandler: self.clock = hs.get_clock() self.state = hs.get_state_handler() self.auth = hs.get_auth() - self.storage = hs.get_storage() - self.state_storage = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token @@ -512,7 +512,7 @@ class SyncHandler: current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( - self.storage, + self._storage_controllers, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -580,7 +580,7 @@ class SyncHandler: current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( - self.storage, + self._storage_controllers, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -630,7 +630,7 @@ class SyncHandler: event: event of interest state_filter: The state filter used to fetch state from the database. """ - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): @@ -710,7 +710,7 @@ class SyncHandler: return None last_event = last_events[-1] - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -889,13 +889,15 @@ class SyncHandler: if full_state: if batch: current_state_ids = ( - await self.state_storage.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) ) - state_ids = await self.state_storage.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: @@ -915,7 +917,7 @@ class SyncHandler: elif batch.limited: if batch: state_at_timeline_start = ( - await self.state_storage.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) ) @@ -950,7 +952,7 @@ class SyncHandler: if batch: current_state_ids = ( - await self.state_storage.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) ) @@ -982,7 +984,7 @@ class SyncHandler: # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = await self.state_storage.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( diff --git a/synapse/notifier.py b/synapse/notifier.py index c2b66eec62..1100434b3f 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -221,7 +221,7 @@ class Notifier: self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] @@ -623,7 +623,7 @@ class Notifier: if name == "room": new_events = await filter_events_for_client( - self.storage, + self._storage_controllers, user.to_string(), new_events, is_peeking=is_peeking, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index d5603596c0..e96fb45e9f 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,7 +65,7 @@ class HttpPusher(Pusher): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): super().__init__(hs, pusher_config) - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name self.pushkey_ts = pusher_config.ts @@ -343,7 +343,9 @@ class HttpPusher(Pusher): } return d - ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) + ctx = await push_tools.get_context_for_event( + self._storage_controllers, event, self.user_id + ) d = { "notification": { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 84124af965..63aefd07f5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -114,10 +114,10 @@ class Mailer: self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastores().main - self.state_storage = self.hs.get_storage().state + self._state_storage_controller = self.hs.get_storage_controllers().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.app_name = app_name self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects @@ -456,7 +456,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results.events_before + self._storage_controllers, user_id, results.events_before ) the_events.append(notif_event) @@ -494,7 +494,7 @@ class Mailer: ) else: # Attempt to check the historical state for the room. - historical_state = await self.state_storage.get_state_for_event( + historical_state = await self._state_storage_controller.get_state_for_event( event.event_id, StateFilter.from_types((type_state_key,)) ) sender_state_event = historical_state.get(type_state_key) @@ -767,8 +767,10 @@ class Mailer: member_event_ids.append(sender_state_event_id) else: # Attempt to check the historical state for the room. - historical_state = await self.state_storage.get_state_for_event( - event_id, StateFilter.from_types((type_state_key,)) + historical_state = ( + await self._state_storage_controller.get_state_for_event( + event_id, StateFilter.from_types((type_state_key,)) + ) ) sender_state_event = historical_state.get(type_state_key) if sender_state_event: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a1bf5b20dd..8397229ccb 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,7 +16,7 @@ from typing import Dict from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event -from synapse.storage import Storage +from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore @@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - async def get_context_for_event( - storage: Storage, ev: EventBase, user_id: str + storage: StorageControllers, ev: EventBase, user_id: str ) -> Dict[str, str]: ctx = {} diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 3e7300b4a1..eed29cd597 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super().__init__(hs) self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( - self.storage, event_payload["context"] + self._storage_controllers, event_payload["context"] ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index ce78176836..c2b2588ea5 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() @staticmethod @@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event.internal_metadata.outlier = content["outlier"] requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.storage, content["context"]) + context = EventContext.deserialize( + self._storage_controllers, content["context"] + ) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/server.py b/synapse/server.py index 3fd23aaf52..a66ec228db 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, Storage +from synapse.storage import Databases +from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta): return PasswordPolicyHandler(self) @cache_in_self - def get_storage(self) -> Storage: - return Storage(self, self.get_datastores()) + def get_storage_controllers(self) -> StorageControllers: + return StorageControllers(self, self.get_datastores()) @cache_in_self def get_replication_streamer(self) -> ReplicationStreamer: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9c9d946f38..bf09f5128a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,10 +127,10 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() @overload async def get_current_state( @@ -337,12 +337,14 @@ class StateHandler: # if not state_group_before_event: - state_group_before_event = await self.state_storage.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) # Assign the new state group to the cached state entry. @@ -359,7 +361,7 @@ class StateHandler: if not event.is_state(): return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group_before_event=state_group_before_event, state_group=state_group_before_event, state_delta_due_to_event={}, @@ -382,16 +384,18 @@ class StateHandler: state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = await self.state_storage.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=state_ids_after_event, + ) ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group_after_event, state_group_before_event=state_group_before_event, state_delta_due_to_event=delta_ids, @@ -416,7 +420,9 @@ class StateHandler: """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_storage.get_state_group_for_events(event_ids) + state_groups = await self._state_storage_controller.get_state_group_for_events( + event_ids + ) state_group_ids = state_groups.values() @@ -424,8 +430,13 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_storage.get_state_for_groups(state_group_ids_set) - prev_group, delta_ids = await self.state_storage.get_state_group_delta( + state = await self._state_storage_controller.get_state_for_groups( + state_group_ids_set + ) + ( + prev_group, + delta_ids, + ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) return _StateCacheEntry( @@ -439,7 +450,7 @@ class StateHandler: room_version = await self.store.get_room_version_id(room_id) - state_to_resolve = await self.state_storage.get_state_for_groups( + state_to_resolve = await self._state_storage_controller.get_state_for_groups( state_group_ids_set ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 105e4e1fec..bac21ecf9c 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple databases). The `DatabasePool` class represents connections to a single physical database. The `databases` are classes that talk directly to a `DatabasePool` -instance and have associated schemas, background updates, etc. On top of those -there are classes that provide high level interfaces that combine calls to -multiple `databases`. +instance and have associated schemas, background updates, etc. + +On top of the databases are the StorageControllers, located in the +`synapse.storage.controllers` module. These classes provide high level +interfaces that combine calls to multiple `databases`. They are bundled into the +`StorageControllers` singleton for ease of use, and exposed via +`HomeServer.get_storage_controllers()`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStorage -from synapse.storage.purge_events import PurgeEventsStorage -from synapse.storage.state import StateGroupStorage - -if TYPE_CHECKING: - from synapse.server import HomeServer - __all__ = ["Databases", "DataStore"] - - -class Storage: - """The high level interfaces for talking to various storage layers.""" - - def __init__(self, hs: "HomeServer", stores: Databases): - # We include the main data store here mainly so that we don't have to - # rewrite all the existing code to split it into high vs low level - # interfaces. - self.main = stores.main - - self.purge_events = PurgeEventsStorage(hs, stores) - self.state = StateGroupStorage(hs, stores) - - self.persistence = None - if stores.persist_events: - self.persistence = EventsPersistenceStorage(hs, stores) diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py new file mode 100644 index 0000000000..992261d07b --- /dev/null +++ b/synapse/storage/controllers/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from synapse.storage.controllers.persist_events import ( + EventsPersistenceStorageController, +) +from synapse.storage.controllers.purge_events import PurgeEventsStorageController +from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +__all__ = ["Databases", "DataStore"] + + +class StorageControllers: + """The high level interfaces for talking to various storage controller layers.""" + + def __init__(self, hs: "HomeServer", stores: Databases): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.purge_events = PurgeEventsStorageController(hs, stores) + self.state = StateGroupStorageController(hs, stores) + + self.persistence = None + if stores.persist_events: + self.persistence = EventsPersistenceStorageController(hs, stores) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py new file mode 100644 index 0000000000..ef8c135b12 --- /dev/null +++ b/synapse/storage/controllers/persist_events.py @@ -0,0 +1,1124 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections import deque +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Collection, + Deque, + Dict, + Generator, + Generic, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, +) + +import attr +from prometheus_client import Counter, Histogram + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.logging import opentracing +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases import Databases +from synapse.storage.databases.main.events import DeltaState +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + PersistedEventPosition, + RoomStreamToken, + StateMap, + get_domain_from_id, +) +from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") + +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "" +) + +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "" +) + +# The number of forward extremities for each new event. +forward_extremities_counter = Histogram( + "synapse_storage_events_forward_extremities_persisted", + "Number of forward extremities for each new event", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +# The number of stale forward extremities for each new event. Stale extremities +# are those that were in the previous set of extremities as well as the new. +stale_forward_extremities_counter = Histogram( + "synapse_storage_events_stale_forward_extremities_persisted", + "Number of unchanged forward extremities for each new event", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +state_resolutions_during_persistence = Counter( + "synapse_storage_events_state_resolutions_during_persistence", + "Number of times we had to do state res to calculate new current state", +) + +potential_times_prune_extremities = Counter( + "synapse_storage_events_potential_times_prune_extremities", + "Number of times we might be able to prune extremities", +) + +times_pruned_extremities = Counter( + "synapse_storage_events_times_pruned_extremities", + "Number of times we were actually be able to prune extremities", +) + + +@attr.s(auto_attribs=True, slots=True) +class _EventPersistQueueItem: + events_and_contexts: List[Tuple[EventBase, EventContext]] + backfilled: bool + deferred: ObservableDeferred + + parent_opentracing_span_contexts: List = attr.ib(factory=list) + """A list of opentracing spans waiting for this batch""" + + opentracing_span_context: Any = None + """The opentracing span under which the persistence actually happened""" + + +_PersistResult = TypeVar("_PersistResult") + + +class _EventPeristenceQueue(Generic[_PersistResult]): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + def __init__( + self, + per_item_callback: Callable[ + [List[Tuple[EventBase, EventContext]], bool], + Awaitable[_PersistResult], + ], + ): + """Create a new event persistence queue + + The per_item_callback will be called for each item added via add_to_queue, + and its result will be returned via the Deferreds returned from add_to_queue. + """ + self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {} + self._currently_persisting_rooms: Set[str] = set() + self._per_item_callback = per_item_callback + + async def add_to_queue( + self, + room_id: str, + events_and_contexts: Iterable[Tuple[EventBase, EventContext]], + backfilled: bool, + ) -> _PersistResult: + """Add events to the queue, with the given persist_event options. + + If we are not already processing events in this room, starts off a background + process to to so, calling the per_item_callback for each item. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + the result returned by the `_per_item_callback` passed to + `__init__`. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. + if queue and queue[-1].backfilled == backfilled: + end_item = queue[-1] + else: + # need to make a new queue item + deferred: ObservableDeferred[_PersistResult] = ObservableDeferred( + defer.Deferred(), consumeErrors=True + ) + + end_item = _EventPersistQueueItem( + events_and_contexts=[], + backfilled=backfilled, + deferred=deferred, + ) + queue.append(end_item) + + # add our events to the queue item + end_item.events_and_contexts.extend(events_and_contexts) + + # also add our active opentracing span to the item so that we get a link back + span = opentracing.active_span() + if span: + end_item.parent_opentracing_span_contexts.append(span.context) + + # start a processor for the queue, if there isn't one already + self._handle_queue(room_id) + + # wait for the queue item to complete + res = await make_deferred_yieldable(end_item.deferred.observe()) + + # add another opentracing span which links to the persist trace. + with opentracing.start_active_span_follows_from( + "persist_event_batch_complete", (end_item.opentracing_span_context,) + ): + pass + + return res + + def _handle_queue(self, room_id: str) -> None: + """Attempts to handle the queue for a room if not already being handled. + + The queue's callback will be invoked with for each item in the queue, + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferreds as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + async def handle_queue_loop() -> None: + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + with opentracing.start_active_span_follows_from( + "persist_event_batch", + item.parent_opentracing_span_contexts, + inherit_force_tracing=True, + ) as scope: + if scope: + item.opentracing_span_context = scope.span.context + + ret = await self._per_item_callback( + item.events_and_contexts, item.backfilled + ) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: + with PreserveLoggingContext(): + item.deferred.callback(ret) + finally: + remaining_queue = self._event_persist_queues.pop(room_id, None) + if remaining_queue: + self._event_persist_queues[room_id] = remaining_queue + self._currently_persisting_rooms.discard(room_id) + + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) + + def _get_drainining_queue( + self, room_id: str + ) -> Generator[_EventPersistQueueItem, None, None]: + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +class EventsPersistenceStorageController: + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + + def __init__(self, hs: "HomeServer", stores: Databases): + # We ultimately want to split out the state store from the main store, + # so we use separate variables here even though they point to the same + # store for now. + self.main_store = stores.main + self.state_store = stores.state + + assert stores.persist_events + self.persist_events_store = stores.persist_events + + self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() + self.is_mine_id = hs.is_mine_id + self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch) + self._state_resolution_handler = hs.get_state_resolution_handler() + + @opentracing.trace + async def persist_events( + self, + events_and_contexts: Iterable[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> Tuple[List[EventBase], RoomStreamToken]: + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled: Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + List of events persisted, the current position room stream position. + The list of events persisted may not be the same as those passed in + if they were deduplicated due to an event already existing that + matched the transaction ID; the existing event is returned in such + a case. + """ + partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + async def enqueue( + item: Tuple[str, List[Tuple[EventBase, EventContext]]] + ) -> Dict[str, str]: + room_id, evs_ctxs = item + return await self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, backfilled=backfilled + ) + + ret_vals = await yieldable_gather_results(enqueue, partitioned.items()) + + # Each call to add_to_queue returns a map from event ID to existing event ID if + # the event was deduplicated. (The dict may also include other entries if + # the event was persisted in a batch with other events). + # + # Since we use `yieldable_gather_results` we need to merge the returned list + # of dicts into one. + replaced_events: Dict[str, str] = {} + for d in ret_vals: + replaced_events.update(d) + + events = [] + for event, _ in events_and_contexts: + existing_event_id = replaced_events.get(event.event_id) + if existing_event_id: + events.append(await self.main_store.get_event(existing_event_id)) + else: + events.append(event) + + return ( + events, + self.main_store.get_room_max_token(), + ) + + @opentracing.trace + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: + """ + Returns: + The event, stream ordering of `event`, and the stream ordering of the + latest persisted event. The returned event may not match the given + event if it was deduplicated due to an existing event matching the + transaction ID. + """ + # add_to_queue returns a map from event ID to existing event ID if the + # event was deduplicated. (The dict may also include other entries if + # the event was persisted in a batch with other events.) + replaced_events = await self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) + replaced_event = replaced_events.get(event.event_id) + if replaced_event: + event = await self.main_store.get_event(replaced_event) + + event_stream_id = event.internal_metadata.stream_ordering + # stream ordering should have been assigned by now + assert event_stream_id + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return event, pos, self.main_store.get_room_max_token() + + async def update_current_state(self, room_id: str) -> None: + """Recalculate the current state for a room, and persist it""" + state = await self._calculate_current_state(room_id) + delta = await self._calculate_state_delta(room_id, state) + + # TODO(faster_joins): get a real stream ordering, to make this work correctly + # across workers. + # + # TODO(faster_joins): this can race against event persistence, in which case we + # will end up with incorrect state. Perhaps we should make this a job we + # farm out to the event persister, somehow. + stream_id = self.main_store.get_room_max_stream_ordering() + await self.persist_events_store.update_current_state(room_id, delta, stream_id) + + async def _calculate_current_state(self, room_id: str) -> StateMap[str]: + """Calculate the current state of a room, based on the forward extremities + + Args: + room_id: room for which to calculate current state + + Returns: + map from (type, state_key) to event id for the current state in the room + """ + latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) + state_groups = set( + ( + await self.main_store._get_state_group_for_events(latest_event_ids) + ).values() + ) + + state_maps_by_state_group = await self.state_store._get_state_for_groups( + state_groups + ) + + if len(state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_maps_by_state_group[state_groups.pop()] + + # Ok, we need to defer to the state handler to resolve our state sets. + logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + + room_version = await self.main_store.get_room_version_id(room_id) + res = await self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_maps_by_state_group, + event_map=None, + state_res_store=StateResolutionStore(self.main_store), + ) + + return res.state + + async def _persist_event_batch( + self, + events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> Dict[str, str]: + """Callback for the _event_persist_queue + + Calculates the change to current state and forward extremities, and + persists the given events and with those updates. + + Returns: + A dictionary of event ID to event ID we didn't persist as we already + had another event persisted with the same TXN ID. + """ + replaced_events: Dict[str, str] = {} + if not events_and_contexts: + return replaced_events + + # Check if any of the events have a transaction ID that has already been + # persisted, and if so we don't persist it again. + # + # We should have checked this a long time before we get here, but it's + # possible that different send event requests race in such a way that + # they both pass the earlier checks. Checking here isn't racey as we can + # have only one `_persist_events` per room being called at a time. + replaced_events = await self.main_store.get_already_persisted_events( + (event for event, _ in events_and_contexts) + ) + + if replaced_events: + events_and_contexts = [ + (e, ctx) + for e, ctx in events_and_contexts + if e.event_id not in replaced_events + ] + + if not events_and_contexts: + return replaced_events + + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + + # map room_id->set[event_ids] giving the new forward + # extremities in each room + new_forward_extremities: Dict[str, Set[str]] = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room: Dict[str, DeltaState] = {} + + # Set of remote users which were in rooms the server has left. We + # should check if we still share any rooms and if not we mark their + # device lists as stale. + potentially_left_users: Set[str] = set() + + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in events_by_room.items(): + latest_event_ids = set( + await self.main_store.get_latest_event_ids_in_room(room_id) + ) + new_latest_event_ids = await self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremities[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.debug("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = await self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids, new_latest_event_ids = res + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremities[room_id] = new_latest_event_ids + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + delta = None + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + delta = DeltaState([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, "persist_events.calculate_state_delta" + ): + delta = await self._calculate_state_delta( + room_id, current_state + ) + + if delta: + # If we have a change of state then lets check + # whether we're actually still a member of the room, + # or if our last user left. If we're no longer in + # the room then we delete the current state and + # extremities. + is_still_joined = await self._is_server_still_joined( + room_id, + ev_ctx_rm, + delta, + current_state, + potentially_left_users, + ) + if not is_still_joined: + logger.info("Server no longer in room %s", room_id) + latest_event_ids = set() + current_state = {} + delta.no_longer_in_room = True + + state_delta_for_room[room_id] = delta + + await self.persist_events_store._persist_events_and_state_updates( + chunk, + state_delta_for_room=state_delta_for_room, + new_forward_extremities=new_forward_extremities, + use_negative_stream_ordering=backfilled, + inhibit_local_membership_updates=backfilled, + ) + + await self._handle_potentially_left_users(potentially_left_users) + + return replaced_events + + async def _calculate_new_extremities( + self, + room_id: str, + event_contexts: List[Tuple[EventBase, EventContext]], + latest_event_ids: Collection[str], + ) -> Set[str]: + """Calculates the new forward extremities for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected + and not event.internal_metadata.is_soft_failed() + ] + + latest_event_ids = set(latest_event_ids) + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update(event.event_id for event in new_events) + + # Now remove all events which are prev_events of any of the new events + result.difference_update( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + + # Remove any events which are prev_events of any existing events. + existing_prevs: Collection[ + str + ] = await self.persist_events_store._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) + + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + # (e.g. we ignore backfill/outliers/etc) + if result != latest_event_ids: + forward_extremities_counter.observe(len(result)) + stale = latest_event_ids & result + stale_forward_extremities_counter.observe(len(stale)) + + return result + + async def _get_new_state_after_events( + self, + room_id: str, + events_context: List[Tuple[EventBase, EventContext]], + old_latest_event_ids: Set[str], + new_latest_event_ids: Set[str], + ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: + """Calculate the current state dict after adding some new events to + a room + + Args: + room_id: + room to which the events are being added. Used for logging etc + + events_context: + events and contexts which are being added to the room + + old_latest_event_ids: + the old forward extremities for the room. + + new_latest_event_ids : + the new forward extremities for the room. + + Returns: + Returns a tuple of two state maps and a set of new forward + extremities. + + The first state map is the full new current state and the second + is the delta to the existing current state. If both are None then + there has been no change. Either or neither can be None if there + has been a change. + + The function may prune some old entries from the set of new + forward extremities if it's safe to do so. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. + """ + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id,) + ) + continue + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding. + for ev, ctx in events_context: + if event_id == ev.event_id and ctx.state_group is not None: + event_id_to_state_group[event_id] = ctx.state_group + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + missing_event_ids.add(event_id) + + if missing_event_ids: + # Now pull out the state groups for any missing events from DB + event_to_groups = await self.main_store._get_state_group_for_events( + missing_event_ids + ) + event_id_to_state_group.update(event_to_groups) + + # State groups of old_latest_event_ids + old_state_groups = { + event_id_to_state_group[evid] for evid in old_latest_event_ids + } + + # State groups of new_latest_event_ids + new_state_groups = { + event_id_to_state_group[evid] for evid in new_latest_event_ids + } + + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return None, None, new_latest_event_ids + + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. + + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) + + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. + return None, delta_ids, new_latest_event_ids + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + state_groups_map = await self.state_store._get_state_for_groups( + new_state_groups + ) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids + + # Ok, we need to defer to the state handler to resolve our state sets. + + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} + + events_map = {ev.event_id: ev for ev, _ in events_context} + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = await self.main_store.get_room_version_id(room_id) + + logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + + res = await self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self.main_store), + ) + + state_resolutions_during_persistence.inc() + + # If the returned state matches the state group of one of the new + # forward extremities then we check if we are able to prune some state + # extremities. + if res.state_group and res.state_group in new_state_groups: + new_latest_event_ids = await self._prune_extremities( + room_id, + new_latest_event_ids, + res.state_group, + event_id_to_state_group, + events_context, + ) + + return res.state, None, new_latest_event_ids + + async def _prune_extremities( + self, + room_id: str, + new_latest_event_ids: Set[str], + resolved_state_group: int, + event_id_to_state_group: Dict[str, int], + events_context: List[Tuple[EventBase, EventContext]], + ) -> Set[str]: + """See if we can prune any of the extremities after calculating the + resolved state. + """ + potential_times_prune_extremities.inc() + + # We keep all the extremities that have the same state group, and + # see if we can drop the others. + new_new_extrems = { + e + for e in new_latest_event_ids + if event_id_to_state_group[e] == resolved_state_group + } + + dropped_extrems = set(new_latest_event_ids) - new_new_extrems + + logger.debug("Might drop extremities: %s", dropped_extrems) + + # We only drop events from the extremities list if: + # 1. we're not currently persisting them; + # 2. they're not our own events (or are dummy events); and + # 3. they're either: + # 1. over N hours old and more than N events ago (we use depth to + # calculate); or + # 2. we are persisting an event from the same domain and more than + # M events ago. + # + # The idea is that we don't want to drop events that are "legitimate" + # extremities (that we would want to include as prev events), only + # "stuck" extremities that are e.g. due to a gap in the graph. + # + # Note that we either drop all of them or none of them. If we only drop + # some of the events we don't know if state res would come to the same + # conclusion. + + for ev, _ in events_context: + if ev.event_id in dropped_extrems: + logger.debug( + "Not dropping extremities: %s is being persisted", ev.event_id + ) + return new_latest_event_ids + + dropped_events = await self.main_store.get_events( + dropped_extrems, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + + new_senders = {get_domain_from_id(e.sender) for e, _ in events_context} + + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + current_depth = max(e.depth for e, _ in events_context) + for event in dropped_events.values(): + # If the event is a local dummy event then we should check it + # doesn't reference any local events, as we want to reference those + # if we send any new events. + # + # Note we do this recursively to handle the case where a dummy event + # references a dummy event that only references remote events. + # + # Ideally we'd figure out a way of still being able to drop old + # dummy events that reference local events, but this is good enough + # as a first cut. + events_to_check: Collection[EventBase] = [event] + while events_to_check: + new_events: Set[str] = set() + for event_to_check in events_to_check: + if self.is_mine_id(event_to_check.sender): + if event_to_check.type != EventTypes.Dummy: + logger.debug("Not dropping own event") + return new_latest_event_ids + new_events.update(event_to_check.prev_event_ids()) + + prev_events = await self.main_store.get_events( + new_events, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + events_to_check = prev_events.values() + + if ( + event.origin_server_ts < one_day_ago + and event.depth < current_depth - 100 + ): + continue + + # We can be less conservative about dropping extremities from the + # same domain, though we do want to wait a little bit (otherwise + # we'll immediately remove all extremities from a given server). + if ( + get_domain_from_id(event.sender) in new_senders + and event.depth < current_depth - 20 + ): + continue + + logger.debug( + "Not dropping as too new and not in new_senders: %s", + new_senders, + ) + + return new_latest_event_ids + + times_pruned_extremities.inc() + + logger.info( + "Pruning forward extremities in room %s: from %s -> %s", + room_id, + new_latest_event_ids, + new_new_extrems, + ) + return new_new_extrems + + async def _calculate_state_delta( + self, room_id: str, current_state: StateMap[str] + ) -> DeltaState: + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + """ + existing_state = await self.main_store.get_current_state_ids(room_id) + + to_delete = [key for key in existing_state if key not in current_state] + + to_insert = { + key: ev_id + for key, ev_id in current_state.items() + if ev_id != existing_state.get(key) + } + + return DeltaState(to_delete=to_delete, to_insert=to_insert) + + async def _is_server_still_joined( + self, + room_id: str, + ev_ctx_rm: List[Tuple[EventBase, EventContext]], + delta: DeltaState, + current_state: Optional[StateMap[str]], + potentially_left_users: Set[str], + ) -> bool: + """Check if the server will still be joined after the given events have + been persised. + + Args: + room_id + ev_ctx_rm + delta: The delta of current state between what is in the database + and what the new current state will be. + current_state: The new current state if it already been calculated, + otherwise None. + potentially_left_users: If the server has left the room, then joined + remote users will be added to this set to indicate that the + server may no longer be sharing a room with them. + """ + + if not any( + self.is_mine_id(state_key) + for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert) + if typ == EventTypes.Member + ): + # There have been no changes to membership of our users, so nothing + # has changed and we assume we're still in the room. + return True + + # Check if any of the given events are a local join that appear in the + # current state + events_to_check = [] # Event IDs that aren't an event we're persisting + for (typ, state_key), event_id in delta.to_insert.items(): + if typ != EventTypes.Member or not self.is_mine_id(state_key): + continue + + for event, _ in ev_ctx_rm: + if event_id == event.event_id: + if event.membership == Membership.JOIN: + return True + + # The event is not in `ev_ctx_rm`, so we need to pull it out of + # the DB. + events_to_check.append(event_id) + + # Check if any of the changes that we don't have events for are joins. + if events_to_check: + members = await self.main_store.get_membership_from_event_ids( + events_to_check + ) + is_still_joined = any( + member and member.membership == Membership.JOIN + for member in members.values() + ) + if is_still_joined: + return True + + # None of the new state events are local joins, so we check the database + # to see if there are any other local users in the room. We ignore users + # whose state has changed as we've already their new state above. + users_to_ignore = [ + state_key + for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete) + if typ == EventTypes.Member and self.is_mine_id(state_key) + ] + + if await self.main_store.is_local_host_in_room_ignoring_users( + room_id, users_to_ignore + ): + return True + + # The server will leave the room, so we go and find out which remote + # users will still be joined when we leave. + if current_state is None: + current_state = await self.main_store.get_current_state_ids(room_id) + current_state = dict(current_state) + for key in delta.to_delete: + current_state.pop(key, None) + + current_state.update(delta.to_insert) + + remote_event_ids = [ + event_id + for ( + typ, + state_key, + ), event_id in current_state.items() + if typ == EventTypes.Member and not self.is_mine_id(state_key) + ] + members = await self.main_store.get_membership_from_event_ids(remote_event_ids) + potentially_left_users.update( + member.user_id + for member in members.values() + if member and member.membership == Membership.JOIN + ) + + return False + + async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None: + """Given a set of remote users check if the server still shares a room with + them. If not then mark those users' device cache as stale. + """ + + if not user_ids: + return + + joined_users = await self.main_store.get_users_server_still_shares_room_with( + user_ids + ) + left_users = user_ids - joined_users + + for user_id in left_users: + await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py new file mode 100644 index 0000000000..9ca50d6a09 --- /dev/null +++ b/synapse/storage/controllers/purge_events.py @@ -0,0 +1,112 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from typing import TYPE_CHECKING, Set + +from synapse.storage.databases import Databases + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PurgeEventsStorageController: + """High level interface for purging rooms and event history.""" + + def __init__(self, hs: "HomeServer", stores: Databases): + self.stores = stores + + async def purge_room(self, room_id: str) -> None: + """Deletes all record of a room""" + + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) + + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: + """Deletes room history before a certain point + + Args: + room_id: The room ID + + token: A topological token to delete events before + + delete_local_events: + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + """ + state_groups = await self.stores.main.purge_history( + room_id, token, delete_local_events + ) + + logger.info("[purge] finding state groups that can be deleted") + + sg_to_delete = await self._find_unreferenced_groups(state_groups) + + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: + """Used when purging history to figure out which state groups can be + deleted. + + Args: + state_groups: Set of state groups referenced by events + that are going to be deleted. + + Returns: + The set of state groups that can be deleted. + """ + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + referenced = await self.stores.main.get_referenced_state_groups( + current_search + ) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + edges = await self.stores.state.get_previous_state_groups(current_search) + + prevs = set(edges.values()) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + to_delete = state_groups_seen - referenced_groups + + return to_delete diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py new file mode 100644 index 0000000000..0f09953086 --- /dev/null +++ b/synapse/storage/controllers/state.py @@ -0,0 +1,351 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Awaitable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, +) + +from synapse.events import EventBase +from synapse.storage.state import StateFilter +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.types import MutableStateMap, StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases import Databases + +logger = logging.getLogger(__name__) + + +class StateGroupStorageController: + """High level interface to fetching state for event.""" + + def __init__(self, hs: "HomeServer", stores: "Databases"): + self._is_mine_id = hs.is_mine_id + self.stores = stores + self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + + def notify_event_un_partial_stated(self, event_id: str) -> None: + self._partial_state_events_tracker.notify_un_partial_stated(event_id) + + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: + """Given a state group try to return a previous group and a delta between + the old and the new. + + Args: + state_group: The state group used to retrieve state deltas. + + Returns: + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old and new state groups. + """ + + state_group_delta = await self.stores.state.get_state_group_delta(state_group) + return state_group_delta.prev_group, state_group_delta.delta_ids + + async def get_state_groups_ids( + self, _room_id: str, event_ids: Collection[str] + ) -> Dict[int, MutableStateMap[str]]: + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id: id of the room for these events + event_ids: ids of the events + + Returns: + dict of state_group_id -> (dict of (type, state_key) -> event id) + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + if not event_ids: + return {} + + event_to_groups = await self.get_state_group_for_events(event_ids) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups(groups) + + return group_to_state + + async def get_state_ids_for_group( + self, state_group: int, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """Get the event IDs of all the state in the given state group + + Args: + state_group: A state group for which we want to get the state IDs. + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + + Returns: + Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = await self.get_state_for_groups((state_group,), state_filter) + + return group_to_state[state_group] + + async def get_state_groups( + self, room_id: str, event_ids: Collection[str] + ) -> Dict[int, List[EventBase]]: + """Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + + Returns: + dict of state_group_id -> list of state events. + """ + if not event_ids: + return {} + + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) + + state_event_map = await self.stores.main.get_events( + [ + ev_id + for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in event_id_map.values() + if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() + } + + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ) -> Awaitable[Dict[int, StateMap[str]]]: + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + + Returns: + Dict of state group to state map. + """ + + return self.stores.state._get_state_groups_from_groups(groups, state_filter) + + async def get_state_for_events( + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + ) -> Dict[str, StateMap[EventBase]]: + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + + Args: + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + + Returns: + A dict of (event_id) -> (type, state_key) -> [state_events] + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_ids_for_events( + self, + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, + ) -> Dict[str, StateMap[str]]: + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from event_id -> (type, state_key) -> event_id + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[EventBase]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + async def get_state_ids_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event_id + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_ids_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + def get_state_for_groups( + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. + from the database. + + Returns: + Dict of state group to state map. + """ + return self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + async def get_state_group_for_events( + self, + event_ids: Collection[str], + await_full_state: bool = True, + ) -> Mapping[str, int]: + """Returns mapping event_id -> state_group + + Args: + event_ids: events to get state groups for + await_full_state: if true, will block if we do not yet have complete + state at these events. + """ + if await_full_state: + await self._partial_state_events_tracker.await_full_state(event_ids) + + return await self.stores.main._get_state_group_for_events(event_ids) + + async def store_state_group( + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[StateMap[str]], + current_state_ids: StateMap[str], + ) -> int: + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids: The state to store. Map of (type, state_key) + to event_id. + + Returns: + The state group ID + """ + return await self.stores.state.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py deleted file mode 100644 index a21dea91c8..0000000000 --- a/synapse/storage/persist_events.py +++ /dev/null @@ -1,1124 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging -from collections import deque -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Collection, - Deque, - Dict, - Generator, - Generic, - Iterable, - List, - Optional, - Set, - Tuple, - TypeVar, -) - -import attr -from prometheus_client import Counter, Histogram - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase -from synapse.events.snapshot import EventContext -from synapse.logging import opentracing -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.databases import Databases -from synapse.storage.databases.main.events import DeltaState -from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import ( - PersistedEventPosition, - RoomStreamToken, - StateMap, - get_domain_from_id, -) -from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results -from synapse.util.metrics import Measure - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") - -# The number of times we are recalculating state when there is only a -# single forward extremity -state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" -) - -# The number of times we are reculating state when we could have resonably -# calculated the delta when we calculated the state for an event we were -# persisting. -state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" -) - -# The number of forward extremities for each new event. -forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", - "Number of forward extremities for each new event", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -# The number of stale forward extremities for each new event. Stale extremities -# are those that were in the previous set of extremities as well as the new. -stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", - "Number of unchanged forward extremities for each new event", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -state_resolutions_during_persistence = Counter( - "synapse_storage_events_state_resolutions_during_persistence", - "Number of times we had to do state res to calculate new current state", -) - -potential_times_prune_extremities = Counter( - "synapse_storage_events_potential_times_prune_extremities", - "Number of times we might be able to prune extremities", -) - -times_pruned_extremities = Counter( - "synapse_storage_events_times_pruned_extremities", - "Number of times we were actually be able to prune extremities", -) - - -@attr.s(auto_attribs=True, slots=True) -class _EventPersistQueueItem: - events_and_contexts: List[Tuple[EventBase, EventContext]] - backfilled: bool - deferred: ObservableDeferred - - parent_opentracing_span_contexts: List = attr.ib(factory=list) - """A list of opentracing spans waiting for this batch""" - - opentracing_span_context: Any = None - """The opentracing span under which the persistence actually happened""" - - -_PersistResult = TypeVar("_PersistResult") - - -class _EventPeristenceQueue(Generic[_PersistResult]): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room. - """ - - def __init__( - self, - per_item_callback: Callable[ - [List[Tuple[EventBase, EventContext]], bool], - Awaitable[_PersistResult], - ], - ): - """Create a new event persistence queue - - The per_item_callback will be called for each item added via add_to_queue, - and its result will be returned via the Deferreds returned from add_to_queue. - """ - self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {} - self._currently_persisting_rooms: Set[str] = set() - self._per_item_callback = per_item_callback - - async def add_to_queue( - self, - room_id: str, - events_and_contexts: Iterable[Tuple[EventBase, EventContext]], - backfilled: bool, - ) -> _PersistResult: - """Add events to the queue, with the given persist_event options. - - If we are not already processing events in this room, starts off a background - process to to so, calling the per_item_callback for each item. - - Args: - room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - - Returns: - the result returned by the `_per_item_callback` passed to - `__init__`. - """ - queue = self._event_persist_queues.setdefault(room_id, deque()) - - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - if queue and queue[-1].backfilled == backfilled: - end_item = queue[-1] - else: - # need to make a new queue item - deferred: ObservableDeferred[_PersistResult] = ObservableDeferred( - defer.Deferred(), consumeErrors=True - ) - - end_item = _EventPersistQueueItem( - events_and_contexts=[], - backfilled=backfilled, - deferred=deferred, - ) - queue.append(end_item) - - # add our events to the queue item - end_item.events_and_contexts.extend(events_and_contexts) - - # also add our active opentracing span to the item so that we get a link back - span = opentracing.active_span() - if span: - end_item.parent_opentracing_span_contexts.append(span.context) - - # start a processor for the queue, if there isn't one already - self._handle_queue(room_id) - - # wait for the queue item to complete - res = await make_deferred_yieldable(end_item.deferred.observe()) - - # add another opentracing span which links to the persist trace. - with opentracing.start_active_span_follows_from( - "persist_event_batch_complete", (end_item.opentracing_span_context,) - ): - pass - - return res - - def _handle_queue(self, room_id: str) -> None: - """Attempts to handle the queue for a room if not already being handled. - - The queue's callback will be invoked with for each item in the queue, - of type _EventPersistQueueItem. The per_item_callback will continuously - be called with new items, unless the queue becomes empty. The return - value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferreds as well. - - This function should therefore be called whenever anything is added - to the queue. - - If another callback is currently handling the queue then it will not be - invoked. - """ - if room_id in self._currently_persisting_rooms: - return - - self._currently_persisting_rooms.add(room_id) - - async def handle_queue_loop() -> None: - try: - queue = self._get_drainining_queue(room_id) - for item in queue: - try: - with opentracing.start_active_span_follows_from( - "persist_event_batch", - item.parent_opentracing_span_contexts, - inherit_force_tracing=True, - ) as scope: - if scope: - item.opentracing_span_context = scope.span.context - - ret = await self._per_item_callback( - item.events_and_contexts, item.backfilled - ) - except Exception: - with PreserveLoggingContext(): - item.deferred.errback() - else: - with PreserveLoggingContext(): - item.deferred.callback(ret) - finally: - remaining_queue = self._event_persist_queues.pop(room_id, None) - if remaining_queue: - self._event_persist_queues[room_id] = remaining_queue - self._currently_persisting_rooms.discard(room_id) - - # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) - - def _get_drainining_queue( - self, room_id: str - ) -> Generator[_EventPersistQueueItem, None, None]: - queue = self._event_persist_queues.setdefault(room_id, deque()) - - try: - while True: - yield queue.popleft() - except IndexError: - # Queue has been drained. - pass - - -class EventsPersistenceStorage: - """High level interface for handling persisting newly received events. - - Takes care of batching up events by room, and calculating the necessary - current state and forward extremity changes. - """ - - def __init__(self, hs: "HomeServer", stores: Databases): - # We ultimately want to split out the state store from the main store, - # so we use separate variables here even though they point to the same - # store for now. - self.main_store = stores.main - self.state_store = stores.state - - assert stores.persist_events - self.persist_events_store = stores.persist_events - - self._clock = hs.get_clock() - self._instance_name = hs.get_instance_name() - self.is_mine_id = hs.is_mine_id - self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch) - self._state_resolution_handler = hs.get_state_resolution_handler() - - @opentracing.trace - async def persist_events( - self, - events_and_contexts: Iterable[Tuple[EventBase, EventContext]], - backfilled: bool = False, - ) -> Tuple[List[EventBase], RoomStreamToken]: - """ - Write events to the database - Args: - events_and_contexts: list of tuples of (event, context) - backfilled: Whether the results are retrieved from federation - via backfill or not. Used to determine if they're "new" events - which might update the current state etc. - - Returns: - List of events persisted, the current position room stream position. - The list of events persisted may not be the same as those passed in - if they were deduplicated due to an event already existing that - matched the transaction ID; the existing event is returned in such - a case. - """ - partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} - for event, ctx in events_and_contexts: - partitioned.setdefault(event.room_id, []).append((event, ctx)) - - async def enqueue( - item: Tuple[str, List[Tuple[EventBase, EventContext]]] - ) -> Dict[str, str]: - room_id, evs_ctxs = item - return await self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled - ) - - ret_vals = await yieldable_gather_results(enqueue, partitioned.items()) - - # Each call to add_to_queue returns a map from event ID to existing event ID if - # the event was deduplicated. (The dict may also include other entries if - # the event was persisted in a batch with other events). - # - # Since we use `yieldable_gather_results` we need to merge the returned list - # of dicts into one. - replaced_events: Dict[str, str] = {} - for d in ret_vals: - replaced_events.update(d) - - events = [] - for event, _ in events_and_contexts: - existing_event_id = replaced_events.get(event.event_id) - if existing_event_id: - events.append(await self.main_store.get_event(existing_event_id)) - else: - events.append(event) - - return ( - events, - self.main_store.get_room_max_token(), - ) - - @opentracing.trace - async def persist_event( - self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: - """ - Returns: - The event, stream ordering of `event`, and the stream ordering of the - latest persisted event. The returned event may not match the given - event if it was deduplicated due to an existing event matching the - transaction ID. - """ - # add_to_queue returns a map from event ID to existing event ID if the - # event was deduplicated. (The dict may also include other entries if - # the event was persisted in a batch with other events.) - replaced_events = await self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - replaced_event = replaced_events.get(event.event_id) - if replaced_event: - event = await self.main_store.get_event(replaced_event) - - event_stream_id = event.internal_metadata.stream_ordering - # stream ordering should have been assigned by now - assert event_stream_id - - pos = PersistedEventPosition(self._instance_name, event_stream_id) - return event, pos, self.main_store.get_room_max_token() - - async def update_current_state(self, room_id: str) -> None: - """Recalculate the current state for a room, and persist it""" - state = await self._calculate_current_state(room_id) - delta = await self._calculate_state_delta(room_id, state) - - # TODO(faster_joins): get a real stream ordering, to make this work correctly - # across workers. - # - # TODO(faster_joins): this can race against event persistence, in which case we - # will end up with incorrect state. Perhaps we should make this a job we - # farm out to the event persister, somehow. - stream_id = self.main_store.get_room_max_stream_ordering() - await self.persist_events_store.update_current_state(room_id, delta, stream_id) - - async def _calculate_current_state(self, room_id: str) -> StateMap[str]: - """Calculate the current state of a room, based on the forward extremities - - Args: - room_id: room for which to calculate current state - - Returns: - map from (type, state_key) to event id for the current state in the room - """ - latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) - state_groups = set( - ( - await self.main_store._get_state_group_for_events(latest_event_ids) - ).values() - ) - - state_maps_by_state_group = await self.state_store._get_state_for_groups( - state_groups - ) - - if len(state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_maps_by_state_group[state_groups.pop()] - - # Ok, we need to defer to the state handler to resolve our state sets. - logger.debug("calling resolve_state_groups from preserve_events") - - # Avoid a circular import. - from synapse.state import StateResolutionStore - - room_version = await self.main_store.get_room_version_id(room_id) - res = await self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_maps_by_state_group, - event_map=None, - state_res_store=StateResolutionStore(self.main_store), - ) - - return res.state - - async def _persist_event_batch( - self, - events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool = False, - ) -> Dict[str, str]: - """Callback for the _event_persist_queue - - Calculates the change to current state and forward extremities, and - persists the given events and with those updates. - - Returns: - A dictionary of event ID to event ID we didn't persist as we already - had another event persisted with the same TXN ID. - """ - replaced_events: Dict[str, str] = {} - if not events_and_contexts: - return replaced_events - - # Check if any of the events have a transaction ID that has already been - # persisted, and if so we don't persist it again. - # - # We should have checked this a long time before we get here, but it's - # possible that different send event requests race in such a way that - # they both pass the earlier checks. Checking here isn't racey as we can - # have only one `_persist_events` per room being called at a time. - replaced_events = await self.main_store.get_already_persisted_events( - (event for event, _ in events_and_contexts) - ) - - if replaced_events: - events_and_contexts = [ - (e, ctx) - for e, ctx in events_and_contexts - if e.event_id not in replaced_events - ] - - if not events_and_contexts: - return replaced_events - - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->set[event_ids] giving the new forward - # extremities in each room - new_forward_extremities: Dict[str, Set[str]] = {} - - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room: Dict[str, DeltaState] = {} - - # Set of remote users which were in rooms the server has left. We - # should check if we still share any rooms and if not we mark their - # device lists as stale. - potentially_left_users: Set[str] = set() - - if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = set( - await self.main_store.get_latest_event_ids_in_room(room_id) - ) - new_latest_event_ids = await self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremities[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.debug("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = await self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids, new_latest_event_ids = res - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremities[room_id] = new_latest_event_ids - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - delta = None - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - delta = DeltaState([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = await self._calculate_state_delta( - room_id, current_state - ) - - if delta: - # If we have a change of state then lets check - # whether we're actually still a member of the room, - # or if our last user left. If we're no longer in - # the room then we delete the current state and - # extremities. - is_still_joined = await self._is_server_still_joined( - room_id, - ev_ctx_rm, - delta, - current_state, - potentially_left_users, - ) - if not is_still_joined: - logger.info("Server no longer in room %s", room_id) - latest_event_ids = set() - current_state = {} - delta.no_longer_in_room = True - - state_delta_for_room[room_id] = delta - - await self.persist_events_store._persist_events_and_state_updates( - chunk, - state_delta_for_room=state_delta_for_room, - new_forward_extremities=new_forward_extremities, - use_negative_stream_ordering=backfilled, - inhibit_local_membership_updates=backfilled, - ) - - await self._handle_potentially_left_users(potentially_left_users) - - return replaced_events - - async def _calculate_new_extremities( - self, - room_id: str, - event_contexts: List[Tuple[EventBase, EventContext]], - latest_event_ids: Collection[str], - ) -> Set[str]: - """Calculates the new forward extremities for a room given events to - persist. - - Assumes that we are only persisting events for one room at a time. - """ - - # we're only interested in new events which aren't outliers and which aren't - # being rejected. - new_events = [ - event - for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() - and not ctx.rejected - and not event.internal_metadata.is_soft_failed() - ] - - latest_event_ids = set(latest_event_ids) - - # start with the existing forward extremities - result = set(latest_event_ids) - - # add all the new events to the list - result.update(event.event_id for event in new_events) - - # Now remove all events which are prev_events of any of the new events - result.difference_update( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - - # Remove any events which are prev_events of any existing events. - existing_prevs: Collection[ - str - ] = await self.persist_events_store._get_events_which_are_prevs(result) - result.difference_update(existing_prevs) - - # Finally handle the case where the new events have soft-failed prev - # events. If they do we need to remove them and their prev events, - # otherwise we end up with dangling extremities. - existing_prevs = await self.persist_events_store._get_prevs_before_rejected( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - result.difference_update(existing_prevs) - - # We only update metrics for events that change forward extremities - # (e.g. we ignore backfill/outliers/etc) - if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) - stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) - - return result - - async def _get_new_state_after_events( - self, - room_id: str, - events_context: List[Tuple[EventBase, EventContext]], - old_latest_event_ids: Set[str], - new_latest_event_ids: Set[str], - ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: - """Calculate the current state dict after adding some new events to - a room - - Args: - room_id: - room to which the events are being added. Used for logging etc - - events_context: - events and contexts which are being added to the room - - old_latest_event_ids: - the old forward extremities for the room. - - new_latest_event_ids : - the new forward extremities for the room. - - Returns: - Returns a tuple of two state maps and a set of new forward - extremities. - - The first state map is the full new current state and the second - is the delta to the existing current state. If both are None then - there has been no change. Either or neither can be None if there - has been a change. - - The function may prune some old entries from the set of new - forward extremities if it's safe to do so. - - If there has been a change then we only return the delta if its - already been calculated. Conversely if we do know the delta then - the new current state is only returned if we've already calculated - it. - """ - # Map from (prev state group, new state group) -> delta state dict - state_group_deltas = {} - - for ev, ctx in events_context: - if ctx.state_group is None: - # This should only happen for outlier events. - if not ev.internal_metadata.is_outlier(): - raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) - ) - continue - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids - - # We need to map the event_ids to their state groups. First, let's - # check if the event is one we're persisting, in which case we can - # pull the state group from its context. - # Otherwise we need to pull the state group from the database. - - # Set of events we need to fetch groups for. (We know none of the old - # extremities are going to be in events_context). - missing_event_ids = set(old_latest_event_ids) - - event_id_to_state_group = {} - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding. - for ev, ctx in events_context: - if event_id == ev.event_id and ctx.state_group is not None: - event_id_to_state_group[event_id] = ctx.state_group - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - missing_event_ids.add(event_id) - - if missing_event_ids: - # Now pull out the state groups for any missing events from DB - event_to_groups = await self.main_store._get_state_group_for_events( - missing_event_ids - ) - event_id_to_state_group.update(event_to_groups) - - # State groups of old_latest_event_ids - old_state_groups = { - event_id_to_state_group[evid] for evid in old_latest_event_ids - } - - # State groups of new_latest_event_ids - new_state_groups = { - event_id_to_state_group[evid] for evid in new_latest_event_ids - } - - # If they old and new groups are the same then we don't need to do - # anything. - if old_state_groups == new_state_groups: - return None, None, new_latest_event_ids - - if len(new_state_groups) == 1 and len(old_state_groups) == 1: - # If we're going from one state group to another, lets check if - # we have a delta for that transition. If we do then we can just - # return that. - - new_state_group = next(iter(new_state_groups)) - old_state_group = next(iter(old_state_groups)) - - delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) - if delta_ids is not None: - # We have a delta from the existing to new current state, - # so lets just return that. - return None, delta_ids, new_latest_event_ids - - # Now that we have calculated new_state_groups we need to get - # their state IDs so we can resolve to a single state set. - state_groups_map = await self.state_store._get_state_for_groups( - new_state_groups - ) - - if len(new_state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids - - # Ok, we need to defer to the state handler to resolve our state sets. - - state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} - - events_map = {ev.event_id: ev for ev, _ in events_context} - - # We need to get the room version, which is in the create event. - # Normally that'd be in the database, but its also possible that we're - # currently trying to persist it. - room_version = None - for ev, _ in events_context: - if ev.type == EventTypes.Create and ev.state_key == "": - room_version = ev.content.get("room_version", "1") - break - - if not room_version: - room_version = await self.main_store.get_room_version_id(room_id) - - logger.debug("calling resolve_state_groups from preserve_events") - - # Avoid a circular import. - from synapse.state import StateResolutionStore - - res = await self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_groups, - events_map, - state_res_store=StateResolutionStore(self.main_store), - ) - - state_resolutions_during_persistence.inc() - - # If the returned state matches the state group of one of the new - # forward extremities then we check if we are able to prune some state - # extremities. - if res.state_group and res.state_group in new_state_groups: - new_latest_event_ids = await self._prune_extremities( - room_id, - new_latest_event_ids, - res.state_group, - event_id_to_state_group, - events_context, - ) - - return res.state, None, new_latest_event_ids - - async def _prune_extremities( - self, - room_id: str, - new_latest_event_ids: Set[str], - resolved_state_group: int, - event_id_to_state_group: Dict[str, int], - events_context: List[Tuple[EventBase, EventContext]], - ) -> Set[str]: - """See if we can prune any of the extremities after calculating the - resolved state. - """ - potential_times_prune_extremities.inc() - - # We keep all the extremities that have the same state group, and - # see if we can drop the others. - new_new_extrems = { - e - for e in new_latest_event_ids - if event_id_to_state_group[e] == resolved_state_group - } - - dropped_extrems = set(new_latest_event_ids) - new_new_extrems - - logger.debug("Might drop extremities: %s", dropped_extrems) - - # We only drop events from the extremities list if: - # 1. we're not currently persisting them; - # 2. they're not our own events (or are dummy events); and - # 3. they're either: - # 1. over N hours old and more than N events ago (we use depth to - # calculate); or - # 2. we are persisting an event from the same domain and more than - # M events ago. - # - # The idea is that we don't want to drop events that are "legitimate" - # extremities (that we would want to include as prev events), only - # "stuck" extremities that are e.g. due to a gap in the graph. - # - # Note that we either drop all of them or none of them. If we only drop - # some of the events we don't know if state res would come to the same - # conclusion. - - for ev, _ in events_context: - if ev.event_id in dropped_extrems: - logger.debug( - "Not dropping extremities: %s is being persisted", ev.event_id - ) - return new_latest_event_ids - - dropped_events = await self.main_store.get_events( - dropped_extrems, - allow_rejected=True, - redact_behaviour=EventRedactBehaviour.as_is, - ) - - new_senders = {get_domain_from_id(e.sender) for e, _ in events_context} - - one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 - current_depth = max(e.depth for e, _ in events_context) - for event in dropped_events.values(): - # If the event is a local dummy event then we should check it - # doesn't reference any local events, as we want to reference those - # if we send any new events. - # - # Note we do this recursively to handle the case where a dummy event - # references a dummy event that only references remote events. - # - # Ideally we'd figure out a way of still being able to drop old - # dummy events that reference local events, but this is good enough - # as a first cut. - events_to_check: Collection[EventBase] = [event] - while events_to_check: - new_events: Set[str] = set() - for event_to_check in events_to_check: - if self.is_mine_id(event_to_check.sender): - if event_to_check.type != EventTypes.Dummy: - logger.debug("Not dropping own event") - return new_latest_event_ids - new_events.update(event_to_check.prev_event_ids()) - - prev_events = await self.main_store.get_events( - new_events, - allow_rejected=True, - redact_behaviour=EventRedactBehaviour.as_is, - ) - events_to_check = prev_events.values() - - if ( - event.origin_server_ts < one_day_ago - and event.depth < current_depth - 100 - ): - continue - - # We can be less conservative about dropping extremities from the - # same domain, though we do want to wait a little bit (otherwise - # we'll immediately remove all extremities from a given server). - if ( - get_domain_from_id(event.sender) in new_senders - and event.depth < current_depth - 20 - ): - continue - - logger.debug( - "Not dropping as too new and not in new_senders: %s", - new_senders, - ) - - return new_latest_event_ids - - times_pruned_extremities.inc() - - logger.info( - "Pruning forward extremities in room %s: from %s -> %s", - room_id, - new_latest_event_ids, - new_new_extrems, - ) - return new_new_extrems - - async def _calculate_state_delta( - self, room_id: str, current_state: StateMap[str] - ) -> DeltaState: - """Calculate the new state deltas for a room. - - Assumes that we are only persisting events for one room at a time. - """ - existing_state = await self.main_store.get_current_state_ids(room_id) - - to_delete = [key for key in existing_state if key not in current_state] - - to_insert = { - key: ev_id - for key, ev_id in current_state.items() - if ev_id != existing_state.get(key) - } - - return DeltaState(to_delete=to_delete, to_insert=to_insert) - - async def _is_server_still_joined( - self, - room_id: str, - ev_ctx_rm: List[Tuple[EventBase, EventContext]], - delta: DeltaState, - current_state: Optional[StateMap[str]], - potentially_left_users: Set[str], - ) -> bool: - """Check if the server will still be joined after the given events have - been persised. - - Args: - room_id - ev_ctx_rm - delta: The delta of current state between what is in the database - and what the new current state will be. - current_state: The new current state if it already been calculated, - otherwise None. - potentially_left_users: If the server has left the room, then joined - remote users will be added to this set to indicate that the - server may no longer be sharing a room with them. - """ - - if not any( - self.is_mine_id(state_key) - for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert) - if typ == EventTypes.Member - ): - # There have been no changes to membership of our users, so nothing - # has changed and we assume we're still in the room. - return True - - # Check if any of the given events are a local join that appear in the - # current state - events_to_check = [] # Event IDs that aren't an event we're persisting - for (typ, state_key), event_id in delta.to_insert.items(): - if typ != EventTypes.Member or not self.is_mine_id(state_key): - continue - - for event, _ in ev_ctx_rm: - if event_id == event.event_id: - if event.membership == Membership.JOIN: - return True - - # The event is not in `ev_ctx_rm`, so we need to pull it out of - # the DB. - events_to_check.append(event_id) - - # Check if any of the changes that we don't have events for are joins. - if events_to_check: - members = await self.main_store.get_membership_from_event_ids( - events_to_check - ) - is_still_joined = any( - member and member.membership == Membership.JOIN - for member in members.values() - ) - if is_still_joined: - return True - - # None of the new state events are local joins, so we check the database - # to see if there are any other local users in the room. We ignore users - # whose state has changed as we've already their new state above. - users_to_ignore = [ - state_key - for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete) - if typ == EventTypes.Member and self.is_mine_id(state_key) - ] - - if await self.main_store.is_local_host_in_room_ignoring_users( - room_id, users_to_ignore - ): - return True - - # The server will leave the room, so we go and find out which remote - # users will still be joined when we leave. - if current_state is None: - current_state = await self.main_store.get_current_state_ids(room_id) - current_state = dict(current_state) - for key in delta.to_delete: - current_state.pop(key, None) - - current_state.update(delta.to_insert) - - remote_event_ids = [ - event_id - for ( - typ, - state_key, - ), event_id in current_state.items() - if typ == EventTypes.Member and not self.is_mine_id(state_key) - ] - members = await self.main_store.get_membership_from_event_ids(remote_event_ids) - potentially_left_users.update( - member.user_id - for member in members.values() - if member and member.membership == Membership.JOIN - ) - - return False - - async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None: - """Given a set of remote users check if the server still shares a room with - them. If not then mark those users' device cache as stale. - """ - - if not user_ids: - return - - joined_users = await self.main_store.get_users_server_still_shares_room_with( - user_ids - ) - left_users = user_ids - joined_users - - for user_id in left_users: - await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py deleted file mode 100644 index 30669beb7c..0000000000 --- a/synapse/storage/purge_events.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging -from typing import TYPE_CHECKING, Set - -from synapse.storage.databases import Databases - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class PurgeEventsStorage: - """High level interface for purging rooms and event history.""" - - def __init__(self, hs: "HomeServer", stores: Databases): - self.stores = stores - - async def purge_room(self, room_id: str) -> None: - """Deletes all record of a room""" - - state_groups_to_delete = await self.stores.main.purge_room(room_id) - await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - - async def purge_history( - self, room_id: str, token: str, delete_local_events: bool - ) -> None: - """Deletes room history before a certain point - - Args: - room_id: The room ID - - token: A topological token to delete events before - - delete_local_events: - if True, we will delete local events as well as remote ones - (instead of just marking them as outliers and deleting their - state groups). - """ - state_groups = await self.stores.main.purge_history( - room_id, token, delete_local_events - ) - - logger.info("[purge] finding state groups that can be deleted") - - sg_to_delete = await self._find_unreferenced_groups(state_groups) - - await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - - async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: - """Used when purging history to figure out which state groups can be - deleted. - - Args: - state_groups: Set of state groups referenced by events - that are going to be deleted. - - Returns: - The set of state groups that can be deleted. - """ - # Set of events that we have found to be referenced by events - referenced_groups = set() - - # Set of state groups we've already seen - state_groups_seen = set(state_groups) - - # Set of state groups to handle next. - next_to_search = set(state_groups) - while next_to_search: - # We bound size of groups we're looking up at once, to stop the - # SQL query getting too big - if len(next_to_search) < 100: - current_search = next_to_search - next_to_search = set() - else: - current_search = set(itertools.islice(next_to_search, 100)) - next_to_search -= current_search - - referenced = await self.stores.main.get_referenced_state_groups( - current_search - ) - referenced_groups |= referenced - - # We don't continue iterating up the state group graphs for state - # groups that are referenced. - current_search -= referenced - - edges = await self.stores.state.get_previous_state_groups(current_search) - - prevs = set(edges.values()) - # We don't bother re-handling groups we've already seen - prevs -= state_groups_seen - next_to_search |= prevs - state_groups_seen |= prevs - - to_delete = state_groups_seen - referenced_groups - - return to_delete diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ab630953ac..96aaffb53c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Awaitable, Callable, Collection, Dict, @@ -32,15 +31,11 @@ import attr from frozendict import frozendict from synapse.api.constants import EventTypes -from synapse.events import EventBase -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - from synapse.server import HomeServer - from synapse.storage.databases import Databases logger = logging.getLogger(__name__) @@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter( types=frozendict({EventTypes.Member: frozenset()}), include_others=True ) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) - - -class StateGroupStorage: - """High level interface to fetching state for event.""" - - def __init__(self, hs: "HomeServer", stores: "Databases"): - self._is_mine_id = hs.is_mine_id - self.stores = stores - self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) - - def notify_event_un_partial_stated(self, event_id: str) -> None: - self._partial_state_events_tracker.notify_un_partial_stated(event_id) - - async def get_state_group_delta( - self, state_group: int - ) -> Tuple[Optional[int], Optional[StateMap[str]]]: - """Given a state group try to return a previous group and a delta between - the old and the new. - - Args: - state_group: The state group used to retrieve state deltas. - - Returns: - A tuple of the previous group and a state map of the event IDs which - make up the delta between the old and new state groups. - """ - - state_group_delta = await self.stores.state.get_state_group_delta(state_group) - return state_group_delta.prev_group, state_group_delta.delta_ids - - async def get_state_groups_ids( - self, _room_id: str, event_ids: Collection[str] - ) -> Dict[int, MutableStateMap[str]]: - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id: id of the room for these events - event_ids: ids of the events - - Returns: - dict of state_group_id -> (dict of (type, state_key) -> event id) - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - if not event_ids: - return {} - - event_to_groups = await self.get_state_group_for_events(event_ids) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups(groups) - - return group_to_state - - async def get_state_ids_for_group( - self, state_group: int, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """Get the event IDs of all the state in the given state group - - Args: - state_group: A state group for which we want to get the state IDs. - state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules - - Returns: - Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = await self.get_state_for_groups((state_group,), state_filter) - - return group_to_state[state_group] - - async def get_state_groups( - self, room_id: str, event_ids: Collection[str] - ) -> Dict[int, List[EventBase]]: - """Get the state groups for the given list of event_ids - - Args: - room_id: ID of the room for these events. - event_ids: The event IDs to retrieve state for. - - Returns: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - - state_event_map = await self.stores.main.get_events( - [ - ev_id - for group_ids in group_to_ids.values() - for ev_id in group_ids.values() - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in event_id_map.values() - if v in state_event_map - ] - for group, event_id_map in group_to_ids.items() - } - - def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter - ) -> Awaitable[Dict[int, StateMap[str]]]: - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups: list of state group IDs to query - state_filter: The state filter used to fetch state - from the database. - - Returns: - Dict of state group to state map. - """ - - return self.stores.state._get_state_groups_from_groups(groups, state_filter) - - async def get_state_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None - ) -> Dict[str, StateMap[EventBase]]: - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids: The events to fetch the state of. - state_filter: The state filter used to fetch state. - - Returns: - A dict of (event_id) -> (type, state_key) -> [state_events] - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - state_event_map = await self.stores.main.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in group_to_state[group].items() - if v in state_event_map - } - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_ids_for_events( - self, - event_ids: Collection[str], - state_filter: Optional[StateFilter] = None, - ) -> Dict[str, StateMap[str]]: - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids: events whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from event_id -> (type, state_key) -> event_id - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[EventBase]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event_id - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - def get_state_for_groups( - self, groups: Iterable[int], state_filter: Optional[StateFilter] = None - ) -> Awaitable[Dict[int, MutableStateMap[str]]]: - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups: list of state groups for which we want to get the state. - state_filter: The state filter used to fetch state. - from the database. - - Returns: - Dict of state group to state map. - """ - return self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - async def get_state_group_for_events( - self, - event_ids: Collection[str], - await_full_state: bool = True, - ) -> Mapping[str, int]: - """Returns mapping event_id -> state_group - - Args: - event_ids: events to get state groups for - await_full_state: if true, will block if we do not yet have complete - state at these events. - """ - if await_full_state: - await self._partial_state_events_tracker.await_full_state(event_ids) - - return await self.stores.main._get_state_group_for_events(event_ids) - - async def store_state_group( - self, - event_id: str, - room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], - ) -> int: - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id: The event ID for which the state was calculated. - room_id: ID of the room for which the state was calculated. - prev_group: A previous state group for the room, optional. - delta_ids: The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids: The state to store. Map of (type, state_key) - to event_id. - - Returns: - The state group ID - """ - return await self.stores.state.store_state_group( - event_id, room_id, prev_group, delta_ids, current_state_ids - ) diff --git a/synapse/visibility.py b/synapse/visibility.py index da4af02796..97548c14e3 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -20,7 +20,7 @@ from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event -from synapse.storage import Storage +from synapse.storage.controllers import StorageControllers from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id @@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "" async def filter_events_for_client( - storage: Storage, + storage: StorageControllers, user_id: str, events: List[EventBase], is_peeking: bool = False, @@ -268,7 +268,7 @@ async def filter_events_for_client( async def filter_events_for_server( - storage: Storage, + storage: StorageControllers, server_name: str, events: List[EventBase], redact: bool = True, @@ -360,7 +360,7 @@ async def filter_events_for_server( async def _event_to_history_vis( - storage: Storage, events: Collection[EventBase] + storage: StorageControllers, events: Collection[EventBase] ) -> Dict[str, str]: """Get the history visibility at each of the given events @@ -407,7 +407,7 @@ async def _event_to_history_vis( async def _event_to_memberships( - storage: Storage, events: Collection[EventBase], server_name: str + storage: StorageControllers, events: Collection[EventBase], server_name: str ) -> Dict[str, StateMap[EventBase]]: """Get the remote membership list at each of the given events diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index defbc68c18..8ddce83b83 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase): def _check_serialize_deserialize(self, event, context): serialized = self.get_success(context.serialize(event, self.store)) - d_context = EventContext.deserialize(self.storage, serialized) + d_context = EventContext.deserialize(self._storage_controllers, serialized) self.assertEqual(context.state_group, d_context.state_group) self.assertEqual(context.rejected, d_context.rejected) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index ec00900621..500c9ccfbc 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state self._event_auth_handler = hs.get_event_auth_handler() return hs @@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # mapping from (type, state_key) -> state_event_id assert most_recent_prev_event_id is not None prev_state_map = self.get_success( - self.state_storage.get_state_ids_for_event(most_recent_prev_event_id) + self.state_storage_controller.get_state_ids_for_event( + most_recent_prev_event_id + ) ) # List of state event ID's prev_state_ids = list(prev_state_map.values()) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index e64b28f28b..1d5b2492c0 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) -> None: OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main - state_storage = self.hs.get_storage().state + state_storage_controller = self.hs.get_storage_controllers().state # create the room user_id = self.register_user("kermit", "test") @@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True - persistence = self.hs.get_storage().persistence + persistence = self.hs.get_storage_controllers().persistence self.get_success( persistence.persist_event( - prev_event, EventContext.for_outlier(self.hs.get_storage()) + prev_event, + EventContext.for_outlier(self.hs.get_storage_controllers()), ) ) else: @@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # check that the state at that event is as expected state = self.get_success( - state_storage.get_state_ids_for_event(pulled_event.event_id) + state_storage_controller.get_state_ids_for_event(pulled_event.event_id) ) expected_state = { (e.type, e.state_key): e.event_id for e in state_at_prev_event diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index f4f7ab4845..44da96c792 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_event_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self._persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") @@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self._persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) return memberEvent, memberEventContext @@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( - self.persist_event_storage.persist_event(event3, context) + self._persist_event_storage_controller.persist_event(event3, context) ) # Assert that the returned values match those from the initial event @@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events([(event3, context)]) + self._persist_event_storage_controller.persist_events([(event3, context)]) ) ret_event4 = events[0] @@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events( + self._persist_event_storage_controller.persist_events( [(event1, context1), (event2, context2)] ) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 4d658d29ca..a68c2ffd45 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) + self.hs.get_storage_controllers().persistence.persist_event(event, context) ) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 85be79d19d..c5705256e6 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 297a9e77f8..6d3d4afe52 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) msg, msgctx = self.build_event() self.get_success( - self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + self._storage_controllers.persistence.persist_events( + [(j2, j2ctx), (msg, msgctx)] + ) ) self.replicate() @@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.storage.persistence.persist_events( + self._storage_controllers.persistence.persist_events( [(event, context)], backfilled=True ) ) else: - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 5bbbd5fbcb..19f57115a1 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): def prepare(self, reactor, clock, homeserver): super().prepare(reactor, clock, homeserver) self.room_creator = homeserver.get_room_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) # Create a test user self.ourUser = UserID.from_string(OUR_USER_ID) @@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) # Join the second user to the second room @@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) def test_return_empty_with_no_data(self): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 0cdf1dec40..0d44102237 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage.persistence.persist_event(event, context)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 2cd7a9e6c5..ac9c113354 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Send a first event, which should be filtered out at the end of the test. @@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage, self.user_id, events) + filter_events_for_client(storage_controllers, self.user_id, events) ) # We should only get one event back. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 41a1bf6d89..1b7ee08ab2 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.clock = clock - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token @@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): # Fetch the state_groups state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + self._storage_controllers.state.get_state_groups_ids( + room_id, historical_event_ids + ) ) # We expect all of the historical events to be using the same state_group diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c7661e7186..a0ce077a99 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext(self.hs.get_storage())) for e in events] + txn, + [(e, EventContext(self.hs.get_storage_controllers())) for e in events], ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index aaa3189b16..a76718e8f9 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase): context = self.get_success( self.state.compute_event_context(event, state_ids_before_event=state) ) - self.get_success(self.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) def assert_extremities(self, expected_extremities): """Assert the current extremities for the room""" @@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) ) - self.get_success(self.persistence.persist_event(remote_event_2, context)) + self.get_success(self._persistence.persist_event(remote_event_2, context)) # Check that we haven't dropped the old extremity. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): @@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_rooms_for_user` to add the remote user to the cache rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) @@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_users_in_room` to add the remote user to the cache users = self.get_success(self.store.get_users_in_room(room_id)) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 08cc60237e..92cd0dfc05 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self): """ @@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token self.get_success( - self.storage.purge_events.purge_history(self.room_id, token_str, True) + self._storage_controllers.purge_events.purge_history( + self.room_id, token_str, True + ) ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted @@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token f = self.get_failure( - self.storage.purge_events.purge_history(self.room_id, event, True), + self._storage_controllers.purge_events.purge_history( + self.room_id, event, True + ), SynapseError, ) self.assertIn("greater than forward", f.value.args[0]) @@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase): self.assertIsNotNone(create_event) # Purge everything before this topological token - self.get_success(self.storage.purge_events.purge_room(self.room_id)) + self.get_success( + self._storage_controllers.purge_events.purge_room(self.room_id) + ) # The events aren't found. self.store._invalidate_get_event_cache(create_event.event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d8d17ef379..6c4e63b77c 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.storage.persistence.persist_event(redaction_event, context) + self._storage.persistence.persist_event(redaction_event, context) ) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 5b011e18cd..d497a19f63 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def inject_room_event(self, **kwargs): self.get_success( - self.storage.persistence.persist_event( + self._storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8dfc1e1db9..e747c6b50e 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase): prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) prev_event = self.get_success(store.get_event(prev_event_ids[0])) prev_state_map = self.get_success( - self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + self.hs.get_storage_controllers().state.get_state_ids_for_event( + prev_event_ids[0] + ) ) event_dict = { diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f88f1c55fc..8043bdbde2 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/test_state.py b/tests/test_state.py index 84694d368d..95f81bebae 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -179,12 +179,12 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): self.dummy_store = _DummyStore() - storage = Mock(main=self.dummy_store, state=self.dummy_store) + storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", "get_datastores", - "get_storage", + "get_storage_controllers", "get_auth", "get_state_handler", "get_clock", @@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) - hs.get_storage.return_value = storage + hs.get_storage_controllers.return_value = storage_controllers self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index c654e36ee4..8027c7a856 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -70,7 +70,7 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - persistence = hs.get_storage().persistence + persistence = hs.get_storage_controllers().persistence assert persistence is not None await persistence.persist_event(event, context) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 7a9b01ef9d..f338af6c36 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): outlier = self._inject_outlier() self.assertEqual( self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier] + ) ), [outlier], ) @@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): evt = self._inject_message("@unerased:local_hs") filtered = self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier, evt] + ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(filtered[0], outlier) @@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... but other servers should only be able to see the outlier (the other should # be redacted) filtered = self.get_success( - filter_events_for_server(self.storage, "other_server", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "other_server", [outlier, evt] + ) ) self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[1].event_id, evt.event_id) @@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) for i in range(0, len(events_to_filter)): @@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_room_member( @@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_message( @@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_outlier(self) -> EventBase: @@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event( - event, EventContext.for_outlier(self.storage) + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) ) ) return event @@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@user:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], ) ), [invite_event, reject_event], @@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@other:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@other:test", + [invite_event, reject_event], ) ), [], diff --git a/tests/utils.py b/tests/utils.py index d4ba3a9b99..3059c453d5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -264,7 +264,7 @@ class MockClock: async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" - persistence_store = hs.get_storage().persistence + persistence_store = hs.get_storage_controllers().persistence store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() -- cgit 1.4.1