diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 72b291889b..9e48e09270 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -59,11 +59,13 @@ from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
+from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
+ from synapse.storage.databases.state.deletion import StateDeletionDataStore
logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
@@ -194,6 +196,8 @@ class StateHandler:
self._storage_controllers = hs.get_storage_controllers()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
+ self._state_store = hs.get_datastores().state
+ self._state_deletion_store = hs.get_datastores().state_deletion
self._update_current_state_client = (
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
@@ -355,6 +359,28 @@ class StateHandler:
await_full_state=False,
)
+ # Ensure we still have the state groups we're relying on, and bump
+ # their usage time to avoid them being deleted from under us.
+ if entry.state_group:
+ missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion(
+ {entry.state_group}
+ )
+ if missing_state_group:
+ raise Exception(f"Missing state group: {entry.state_group}")
+ elif entry.prev_group:
+ # We only rely on the prev group when persisting the event if we
+ # don't have an `entry.state_group`.
+ missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion(
+ {entry.prev_group}
+ )
+
+ if missing_state_group:
+ # If we're missing the prev group then we can just clear the
+ # entries, and rely on `entry._state` (which must exist if
+ # `entry.state_group` is None)
+ entry.prev_group = None
+ entry.delta_ids = None
+
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
state_ids_before_event = None
@@ -475,7 +501,10 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
- self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
+ self,
+ room_id: str,
+ event_ids: StrCollection,
+ await_full_state: bool = True,
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -511,6 +540,7 @@ class StateHandler:
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
+
return _StateCacheEntry(
state=None,
state_group=state_group_id,
@@ -531,7 +561,9 @@ class StateHandler:
room_version,
state_to_resolve,
None,
- state_res_store=StateResolutionStore(self.store),
+ state_res_store=StateResolutionStore(
+ self.store, self._state_deletion_store
+ ),
)
return result
@@ -663,7 +695,25 @@ class StateResolutionHandler:
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
- return cache
+ # Check that the returned cache entry doesn't point to deleted
+ # state groups.
+ state_groups_to_check = set()
+ if cache.state_group is not None:
+ state_groups_to_check.add(cache.state_group)
+
+ if cache.prev_group is not None:
+ state_groups_to_check.add(cache.prev_group)
+
+ missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
+ state_groups_to_check
+ )
+
+ if not missing_state_groups:
+ return cache
+ else:
+ # There are missing state groups, so let's remove the stale
+ # entry and continue as if it was a cache miss.
+ self._state_cache.pop(group_names, None)
logger.info(
"Resolving state for %s with groups %s",
@@ -671,6 +721,16 @@ class StateResolutionHandler:
list(group_names),
)
+ # We double check that none of the state groups have been deleted.
+ # They shouldn't be as all these state groups should be referenced.
+ missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
+ group_names
+ )
+ if missing_state_groups:
+ raise Exception(
+ f"State groups have been deleted: {shortstr(missing_state_groups)}"
+ )
+
state_groups_histogram.observe(len(state_groups_ids))
new_state = await self.resolve_events_with_store(
@@ -884,7 +944,8 @@ class StateResolutionStore:
in well defined way.
"""
- store: "DataStore"
+ main_store: "DataStore"
+ state_deletion_store: "StateDeletionDataStore"
def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False
@@ -899,7 +960,7 @@ class StateResolutionStore:
An awaitable which resolves to a dict from event_id to event.
"""
- return self.store.get_events(
+ return self.main_store.get_events(
event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
@@ -920,4 +981,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
- return self.store.get_auth_chain_difference(room_id, state_sets)
+ return self.main_store.get_auth_chain_difference(room_id, state_sets)
|