diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0f09953086..9952b00493 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,7 +14,9 @@
import logging
from typing import (
TYPE_CHECKING,
+ Any,
Awaitable,
+ Callable,
Collection,
Dict,
Iterable,
@@ -24,9 +26,13 @@ from typing import (
Tuple,
)
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.state import StateFilter
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
@@ -36,17 +42,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class StateGroupStorageController:
- """High level interface to fetching state for event."""
+class StateStorageController:
+ """High level interface to fetching state for an event, or the current state
+ in a room.
+ """
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+ self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
+ def notify_room_un_partial_stated(self, room_id: str) -> None:
+ """Notify that the room no longer has any partial state.
+
+ Must be called after `DataStore.clear_partial_state_room`
+ """
+ self._partial_state_room_tracker.notify_un_partial_stated(room_id)
+
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
@@ -349,3 +365,93 @@ class StateGroupStorageController:
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
+
+ async def get_current_state_ids(
+ self,
+ room_id: str,
+ state_filter: Optional[StateFilter] = None,
+ on_invalidate: Optional[Callable[[], None]] = None,
+ ) -> StateMap[str]:
+ """Get the current state event ids for a room based on the
+ current_state_events table.
+
+ If a state filter is given (that is not `StateFilter.all()`) the query
+ result is *not* cached.
+
+ Args:
+ room_id: The room to get the state IDs of. state_filter: The state
+ filter used to fetch state from the
+ database.
+ on_invalidate: Callback for when the `get_current_state_ids` cache
+ for the room gets invalidated.
+
+ Returns:
+ The current state of the room.
+ """
+ if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ if state_filter and not state_filter.is_full():
+ return await self.stores.main.get_partial_filtered_current_state_ids(
+ room_id, state_filter
+ )
+ else:
+ return await self.stores.main.get_partial_current_state_ids(
+ room_id, on_invalidate=on_invalidate
+ )
+
+ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
+ """Get canonical alias for room, if any
+
+ Args:
+ room_id: The room ID
+
+ Returns:
+ The canonical alias, if any
+ """
+
+ state = await self.get_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
+
+ event_id = state.get((EventTypes.CanonicalAlias, ""))
+ if not event_id:
+ return None
+
+ event = await self.stores.main.get_event(event_id, allow_none=True)
+ if not event:
+ return None
+
+ return event.content.get("canonical_alias")
+
+ async def get_current_state_deltas(
+ self, prev_stream_id: int, max_stream_id: int
+ ) -> Tuple[int, List[Dict[str, Any]]]:
+ """Fetch a list of room state changes since the given stream id
+
+ Each entry in the result contains the following fields:
+ - stream_id (int)
+ - room_id (str)
+ - type (str): event type
+ - state_key (str):
+ - event_id (str|None): new event_id for this state key. None if the
+ state has been deleted.
+ - prev_event_id (str|None): previous event_id for this state key. None
+ if it's new state.
+
+ Args:
+ prev_stream_id: point to get changes since (exclusive)
+ max_stream_id: the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
+
+ Returns:
+ A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
+ """
+ # FIXME(faster_joins): what do we do here?
+
+ return await self.stores.main.get_partial_current_state_deltas(
+ prev_stream_id, max_stream_id
+ )
|