diff --git a/changelog.d/12775.misc b/changelog.d/12775.misc
new file mode 100644
index 0000000000..eac326cde3
--- /dev/null
+++ b/changelog.d/12775.misc
@@ -0,0 +1 @@
+Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.
\ No newline at end of file
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 0219091c4e..4b4ed42cff 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -288,7 +288,6 @@ class StateHandler:
#
# first of all, figure out the state before the event
#
-
if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
@@ -419,33 +418,37 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
- # map from state group id to the state in that state group (where
- # 'state' is a map from state key to event id)
- # dict[int, dict[(str, str), str]]
- state_groups_ids = await self.state_store.get_state_groups_ids(
- room_id, event_ids
- )
-
- if len(state_groups_ids) == 0:
- return _StateCacheEntry(state={}, state_group=None)
- elif len(state_groups_ids) == 1:
- name, state_list = list(state_groups_ids.items()).pop()
+ state_groups = await self.state_store.get_state_group_for_events(event_ids)
- prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
+ state_group_ids = state_groups.values()
+ # check if each event has same state group id, if so there's no state to resolve
+ state_group_ids_set = set(state_group_ids)
+ if len(state_group_ids_set) == 1:
+ (state_group_id,) = state_group_ids_set
+ state = await self.state_store.get_state_for_groups(state_group_ids_set)
+ prev_group, delta_ids = await self.state_store.get_state_group_delta(
+ state_group_id
+ )
return _StateCacheEntry(
- state=state_list,
- state_group=name,
+ state=state[state_group_id],
+ state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
)
+ elif len(state_group_ids_set) == 0:
+ return _StateCacheEntry(state={}, state_group=None)
room_version = await self.store.get_room_version_id(room_id)
+ state_to_resolve = await self.state_store.get_state_for_groups(
+ state_group_ids_set
+ )
+
result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
- state_groups_ids,
+ state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76ac6..609a2b88bf 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
group: int,
state_filter: StateFilter,
) -> Tuple[MutableStateMap[str], bool]:
- """Checks if group is in cache. See `_get_state_for_groups`
+ """Checks if group is in cache. See `get_state_for_groups`
Args:
cache: the state group cache to use
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d4a1bd4f9d..a6c60de504 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -586,7 +586,7 @@ class StateGroupStorage:
if not event_ids:
return {}
- event_to_groups = await self._get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -602,7 +602,7 @@ class StateGroupStorage:
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
- group_to_state = await self._get_state_for_groups((state_group,))
+ group_to_state = await self.get_state_for_groups((state_group,))
return group_to_state[state_group]
@@ -675,7 +675,7 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- event_to_groups = await self._get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
@@ -716,7 +716,7 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- event_to_groups = await self._get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
@@ -774,7 +774,7 @@ class StateGroupStorage:
)
return state_map[event_id]
- def _get_state_for_groups(
+ def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
@@ -792,7 +792,7 @@ class StateGroupStorage:
groups, state_filter or StateFilter.all()
)
- async def _get_state_group_for_events(
+ async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
diff --git a/tests/test_state.py b/tests/test_state.py
index 651ec1c7d4..74a8ce6096 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -129,6 +129,19 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
+ async def get_state_group_for_events(self, event_ids):
+ res = {}
+ for event in event_ids:
+ res[event] = self._event_to_state_group[event]
+ return res
+
+ async def get_state_for_groups(self, groups):
+ res = {}
+ for group in groups:
+ state = self._group_to_state[group]
+ res[group] = state
+ return res
+
class DictObj(dict):
def __init__(self, **kwargs):
|