summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state/__init__.py73
1 files changed, 67 insertions, 6 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py

index 72b291889b..9e48e09270 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -59,11 +59,13 @@ from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore + from synapse.storage.databases.state.deletion import StateDeletionDataStore logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -194,6 +196,8 @@ class StateHandler: self._storage_controllers = hs.get_storage_controllers() self._events_shard_config = hs.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._state_store = hs.get_datastores().state + self._state_deletion_store = hs.get_datastores().state_deletion self._update_current_state_client = ( ReplicationUpdateCurrentStateRestServlet.make_client(hs) @@ -355,6 +359,28 @@ class StateHandler: await_full_state=False, ) + # Ensure we still have the state groups we're relying on, and bump + # their usage time to avoid them being deleted from under us. + if entry.state_group: + missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion( + {entry.state_group} + ) + if missing_state_group: + raise Exception(f"Missing state group: {entry.state_group}") + elif entry.prev_group: + # We only rely on the prev group when persisting the event if we + # don't have an `entry.state_group`. + missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion( + {entry.prev_group} + ) + + if missing_state_group: + # If we're missing the prev group then we can just clear the + # entries, and rely on `entry._state` (which must exist if + # `entry.state_group` is None) + entry.prev_group = None + entry.delta_ids = None + state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids state_ids_before_event = None @@ -475,7 +501,10 @@ class StateHandler: @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: StrCollection, await_full_state: bool = True + self, + room_id: str, + event_ids: StrCollection, + await_full_state: bool = True, ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -511,6 +540,7 @@ class StateHandler: ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) + return _StateCacheEntry( state=None, state_group=state_group_id, @@ -531,7 +561,9 @@ class StateHandler: room_version, state_to_resolve, None, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore( + self.store, self._state_deletion_store + ), ) return result @@ -663,7 +695,25 @@ class StateResolutionHandler: async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: - return cache + # Check that the returned cache entry doesn't point to deleted + # state groups. + state_groups_to_check = set() + if cache.state_group is not None: + state_groups_to_check.add(cache.state_group) + + if cache.prev_group is not None: + state_groups_to_check.add(cache.prev_group) + + missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion( + state_groups_to_check + ) + + if not missing_state_groups: + return cache + else: + # There are missing state groups, so let's remove the stale + # entry and continue as if it was a cache miss. + self._state_cache.pop(group_names, None) logger.info( "Resolving state for %s with groups %s", @@ -671,6 +721,16 @@ class StateResolutionHandler: list(group_names), ) + # We double check that none of the state groups have been deleted. + # They shouldn't be as all these state groups should be referenced. + missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion( + group_names + ) + if missing_state_groups: + raise Exception( + f"State groups have been deleted: {shortstr(missing_state_groups)}" + ) + state_groups_histogram.observe(len(state_groups_ids)) new_state = await self.resolve_events_with_store( @@ -884,7 +944,8 @@ class StateResolutionStore: in well defined way. """ - store: "DataStore" + main_store: "DataStore" + state_deletion_store: "StateDeletionDataStore" def get_events( self, event_ids: StrCollection, allow_rejected: bool = False @@ -899,7 +960,7 @@ class StateResolutionStore: An awaitable which resolves to a dict from event_id to event. """ - return self.store.get_events( + return self.main_store.get_events( event_ids, redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, @@ -920,4 +981,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(room_id, state_sets) + return self.main_store.get_auth_chain_difference(room_id, state_sets)