summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py82
1 files changed, 71 insertions, 11 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d5859214..ab630953ac 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@ import logging
 from typing import (
     TYPE_CHECKING,
     Awaitable,
+    Callable,
     Collection,
     Dict,
     Iterable,
@@ -62,7 +64,7 @@ class StateFilter:
     types: "frozendict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
@@ -138,7 +140,9 @@ class StateFilter:
         )
 
     @staticmethod
-    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+    def freeze(
+        types: Mapping[str, Optional[Collection[str]]], include_others: bool
+    ) -> "StateFilter":
         """
         Returns a (frozen) StateFilter with the same contents as the parameters
         specified here, which can be made of mutable types.
@@ -530,6 +534,44 @@ class StateFilter:
             new_all, new_excludes, new_wildcards, new_concrete_keys
         )
 
+    def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
+        """Check if we need to wait for full state to complete to calculate this state
+
+        If we have a state filter which is completely satisfied even with partial
+        state, then we don't need to await_full_state before we can return it.
+
+        Args:
+            is_mine_id: a callable which confirms if a given state_key matches a mxid
+               of a local user
+        """
+
+        # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
+        #  there may be circumstances in which we return a piece of state that, once we
+        #  resync the state, we discover is invalid. For example: if it turns out that
+        #  the sender of a piece of state wasn't actually in the room, then clearly that
+        #  state shouldn't have been returned.
+        #  We should at least add some tests around this to see what happens.
+
+        # if we haven't requested membership events, then it depends on the value of
+        # 'include_others'
+        if EventTypes.Member not in self.types:
+            return self.include_others
+
+        # if we're looking for *all* membership events, then we have to wait
+        member_state_keys = self.types[EventTypes.Member]
+        if member_state_keys is None:
+            return True
+
+        # otherwise, consider whose membership we are looking for. If it's entirely
+        # local users, then we don't need to wait.
+        for state_key in member_state_keys:
+            if not is_mine_id(state_key):
+                # remote user
+                return True
+
+        # local users only
+        return False
+
 
 _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
 _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
@@ -542,6 +584,7 @@ class StateGroupStorage:
     """High level interface to fetching state for event."""
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
+        self._is_mine_id = hs.is_mine_id
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
 
@@ -584,23 +627,26 @@ class StateGroupStorage:
         if not event_ids:
             return {}
 
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
 
         return group_to_state
 
-    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
+    async def get_state_ids_for_group(
+        self, state_group: int, state_filter: Optional[StateFilter] = None
+    ) -> StateMap[str]:
         """Get the event IDs of all the state in the given state group
 
         Args:
             state_group: A state group for which we want to get the state IDs.
+            state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
 
         Returns:
             Resolves to a map of (type, state_key) -> event_id
         """
-        group_to_state = await self._get_state_for_groups((state_group,))
+        group_to_state = await self.get_state_for_groups((state_group,), state_filter)
 
         return group_to_state[state_group]
 
@@ -673,7 +719,13 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                (ie they are outliers or unknown)
         """
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        await_full_state = True
+        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+            await_full_state = False
+
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -697,7 +749,9 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+        self,
+        event_ids: Collection[str],
+        state_filter: Optional[StateFilter] = None,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -714,7 +768,13 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        await_full_state = True
+        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+            await_full_state = False
+
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -772,7 +832,7 @@ class StateGroupStorage:
         )
         return state_map[event_id]
 
-    def _get_state_for_groups(
+    def get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
@@ -790,7 +850,7 @@ class StateGroupStorage:
             groups, state_filter or StateFilter.all()
         )
 
-    async def _get_state_group_for_events(
+    async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
         await_full_state: bool = True,
@@ -800,7 +860,7 @@ class StateGroupStorage:
         Args:
             event_ids: events to get state groups for
             await_full_state: if true, will block if we do not yet have complete
-               state at this event.
+               state at these events.
         """
         if await_full_state:
             await self._partial_state_events_tracker.await_full_state(event_ids)