summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16785.misc1
-rw-r--r--synapse/handlers/room_summary.py12
-rw-r--r--synapse/storage/databases/main/state.py48
-rw-r--r--synapse/types/state.py24
4 files changed, 82 insertions, 3 deletions
diff --git a/changelog.d/16785.misc b/changelog.d/16785.misc
new file mode 100644
index 0000000000..4de185c5dd
--- /dev/null
+++ b/changelog.d/16785.misc
@@ -0,0 +1 @@
+Reduce amount of state pulled out when querying federation hierachy.
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index a534f5f280..78bcac1429 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -44,6 +44,7 @@ from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.ratelimiting import RatelimitSettings
 from synapse.events import EventBase
 from synapse.types import JsonDict, Requester, StrCollection
+from synapse.types.state import StateFilter
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
@@ -546,7 +547,16 @@ class RoomSummaryHandler:
         Returns:
              True if the room is accessible to the requesting user or server.
         """
-        state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
+        event_types = [
+            (EventTypes.JoinRules, ""),
+            (EventTypes.RoomHistoryVisibility, ""),
+        ]
+        if requester:
+            event_types.append((EventTypes.Member, requester))
+
+        state_ids = await self._storage_controllers.state.get_current_state_ids(
+            room_id, state_filter=StateFilter.from_types(event_types)
+        )
 
         # If there's no state for the room, it isn't known.
         if not state_ids:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4700e74ad2..06c44bb563 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -30,7 +30,10 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    TypeVar,
+    Union,
     cast,
+    overload,
 )
 
 import attr
@@ -52,7 +55,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict, JsonMapping, StateMap
+from synapse.types import JsonDict, JsonMapping, StateKey, StateMap
 from synapse.types.state import StateFilter
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
@@ -64,6 +67,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+_T = TypeVar("_T")
+
 
 MAX_STATE_DELTA_HOPS = 100
 
@@ -349,7 +354,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         def _get_filtered_current_state_ids_txn(
             txn: LoggingTransaction,
         ) -> StateMap[str]:
-            results = {}
+            results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
+
             sql = """
                 SELECT type, state_key, event_id FROM current_state_events
                 WHERE room_id = ?
@@ -726,3 +732,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
+
+
+@attr.s(auto_attribs=True, slots=True)
+class StateMapWrapper(Dict[StateKey, str]):
+    """A wrapper around a StateMap[str] to ensure that we only query for items
+    that were not filtered out.
+
+    This is to help prevent bugs where we filter out state but other bits of the
+    code expect the state to be there.
+    """
+
+    state_filter: StateFilter
+
+    def __getitem__(self, key: StateKey) -> str:
+        if key not in self.state_filter:
+            raise Exception("State map was filtered and doesn't include: %s", key)
+        return super().__getitem__(key)
+
+    @overload
+    def get(self, key: Tuple[str, str]) -> Optional[str]:
+        ...
+
+    @overload
+    def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
+        ...
+
+    def get(
+        self, key: StateKey, default: Union[str, _T, None] = None
+    ) -> Union[str, _T, None]:
+        if key not in self.state_filter:
+            raise Exception("State map was filtered and doesn't include: %s", key)
+        return super().get(key, default)
+
+    def __contains__(self, key: Any) -> bool:
+        if key not in self.state_filter:
+            raise Exception("State map was filtered and doesn't include: %s", key)
+
+        return super().__contains__(key)
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 5ca3c94bce..53662372af 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -20,6 +20,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
     Collection,
     Dict,
@@ -584,6 +585,29 @@ class StateFilter:
         # local users only
         return False
 
+    def __contains__(self, key: Any) -> bool:
+        if not isinstance(key, tuple) or len(key) != 2:
+            raise TypeError(
+                f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}"
+            )
+
+        typ, state_key = key
+
+        if not isinstance(typ, str) or not isinstance(state_key, str):
+            raise TypeError(
+                f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})"
+            )
+
+        if typ in self.types:
+            state_keys = self.types[typ]
+            if state_keys is None or state_key in state_keys:
+                return True
+
+        elif self.include_others:
+            return True
+
+        return False
+
 
 _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
 _ALL_NON_MEMBER_STATE_FILTER = StateFilter(