summary refs log tree commit diff
path: root/synapse/storage/controllers/state.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-06-01 16:02:53 +0100
committerGitHub <noreply@github.com>2022-06-01 16:02:53 +0100
commit888a29f4127723a8d048ce47cff37ee8a7a6f1b9 (patch)
tree891ae4c95632801bb097aaed8aea5d8b541c5cf7 /synapse/storage/controllers/state.py
parentFix complement tests using the wrong path (#12933) (diff)
downloadsynapse-888a29f4127723a8d048ce47cff37ee8a7a6f1b9.tar.xz
Wait for lazy join to complete when getting current state (#12872)
Diffstat (limited to 'synapse/storage/controllers/state.py')
-rw-r--r--synapse/storage/controllers/state.py112
1 files changed, 109 insertions, 3 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0f09953086..9952b00493 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,7 +14,9 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
+    Callable,
     Collection,
     Dict,
     Iterable,
@@ -24,9 +26,13 @@ from typing import (
     Tuple,
 )
 
+from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.storage.state import StateFilter
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+    PartialCurrentStateTracker,
+    PartialStateEventsTracker,
+)
 from synapse.types import MutableStateMap, StateMap
 
 if TYPE_CHECKING:
@@ -36,17 +42,27 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class StateGroupStorageController:
-    """High level interface to fetching state for event."""
+class StateStorageController:
+    """High level interface to fetching state for an event, or the current state
+    in a room.
+    """
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self._is_mine_id = hs.is_mine_id
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+        self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
 
     def notify_event_un_partial_stated(self, event_id: str) -> None:
         self._partial_state_events_tracker.notify_un_partial_stated(event_id)
 
+    def notify_room_un_partial_stated(self, room_id: str) -> None:
+        """Notify that the room no longer has any partial state.
+
+        Must be called after `DataStore.clear_partial_state_room`
+        """
+        self._partial_state_room_tracker.notify_un_partial_stated(room_id)
+
     async def get_state_group_delta(
         self, state_group: int
     ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
@@ -349,3 +365,93 @@ class StateGroupStorageController:
         return await self.stores.state.store_state_group(
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
+
+    async def get_current_state_ids(
+        self,
+        room_id: str,
+        state_filter: Optional[StateFilter] = None,
+        on_invalidate: Optional[Callable[[], None]] = None,
+    ) -> StateMap[str]:
+        """Get the current state event ids for a room based on the
+        current_state_events table.
+
+        If a state filter is given (that is not `StateFilter.all()`) the query
+        result is *not* cached.
+
+        Args:
+            room_id: The room to get the state IDs of. state_filter: The state
+            filter used to fetch state from the
+                database.
+            on_invalidate: Callback for when the `get_current_state_ids` cache
+                for the room gets invalidated.
+
+        Returns:
+            The current state of the room.
+        """
+        if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
+            await self._partial_state_room_tracker.await_full_state(room_id)
+
+        if state_filter and not state_filter.is_full():
+            return await self.stores.main.get_partial_filtered_current_state_ids(
+                room_id, state_filter
+            )
+        else:
+            return await self.stores.main.get_partial_current_state_ids(
+                room_id, on_invalidate=on_invalidate
+            )
+
+    async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
+        """Get canonical alias for room, if any
+
+        Args:
+            room_id: The room ID
+
+        Returns:
+            The canonical alias, if any
+        """
+
+        state = await self.get_current_state_ids(
+            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+        )
+
+        event_id = state.get((EventTypes.CanonicalAlias, ""))
+        if not event_id:
+            return None
+
+        event = await self.stores.main.get_event(event_id, allow_none=True)
+        if not event:
+            return None
+
+        return event.content.get("canonical_alias")
+
+    async def get_current_state_deltas(
+        self, prev_stream_id: int, max_stream_id: int
+    ) -> Tuple[int, List[Dict[str, Any]]]:
+        """Fetch a list of room state changes since the given stream id
+
+        Each entry in the result contains the following fields:
+            - stream_id (int)
+            - room_id (str)
+            - type (str): event type
+            - state_key (str):
+            - event_id (str|None): new event_id for this state key. None if the
+                state has been deleted.
+            - prev_event_id (str|None): previous event_id for this state key. None
+                if it's new state.
+
+        Args:
+            prev_stream_id: point to get changes since (exclusive)
+            max_stream_id: the point that we know has been correctly persisted
+               - ie, an upper limit to return changes from.
+
+        Returns:
+            A tuple consisting of:
+               - the stream id which these results go up to
+               - list of current_state_delta_stream rows. If it is empty, we are
+                 up to date.
+        """
+        # FIXME(faster_joins): what do we do here?
+
+        return await self.stores.main.get_partial_current_state_deltas(
+            prev_stream_id, max_stream_id
+        )