diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8df80664a2..57bd74700e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -77,7 +77,7 @@ class SQLBaseStore(metaclass=ABCMeta):
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
- self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
+ self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 992261d07b..55649719f6 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import (
EventsPersistenceStorageController,
)
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
-from synapse.storage.controllers.state import StateGroupStorageController
+from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -39,7 +39,7 @@ class StorageControllers:
self.main = stores.main
self.purge_events = PurgeEventsStorageController(hs, stores)
- self.state = StateGroupStorageController(hs, stores)
+ self.state = StateStorageController(hs, stores)
self.persistence = None
if stores.persist_events:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ef8c135b12..4caaa81808 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -994,7 +994,7 @@ class EventsPersistenceStorageController:
Assumes that we are only persisting events for one room at a time.
"""
- existing_state = await self.main_store.get_current_state_ids(room_id)
+ existing_state = await self.main_store.get_partial_current_state_ids(room_id)
to_delete = [key for key in existing_state if key not in current_state]
@@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController:
# The server will leave the room, so we go and find out which remote
# users will still be joined when we leave.
if current_state is None:
- current_state = await self.main_store.get_current_state_ids(room_id)
+ current_state = await self.main_store.get_partial_current_state_ids(room_id)
current_state = dict(current_state)
for key in delta.to_delete:
current_state.pop(key, None)
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0f09953086..9952b00493 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,7 +14,9 @@
import logging
from typing import (
TYPE_CHECKING,
+ Any,
Awaitable,
+ Callable,
Collection,
Dict,
Iterable,
@@ -24,9 +26,13 @@ from typing import (
Tuple,
)
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.state import StateFilter
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
@@ -36,17 +42,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class StateGroupStorageController:
- """High level interface to fetching state for event."""
+class StateStorageController:
+ """High level interface to fetching state for an event, or the current state
+ in a room.
+ """
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+ self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
+ def notify_room_un_partial_stated(self, room_id: str) -> None:
+ """Notify that the room no longer has any partial state.
+
+ Must be called after `DataStore.clear_partial_state_room`
+ """
+ self._partial_state_room_tracker.notify_un_partial_stated(room_id)
+
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
@@ -349,3 +365,93 @@ class StateGroupStorageController:
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
+
+ async def get_current_state_ids(
+ self,
+ room_id: str,
+ state_filter: Optional[StateFilter] = None,
+ on_invalidate: Optional[Callable[[], None]] = None,
+ ) -> StateMap[str]:
+ """Get the current state event ids for a room based on the
+ current_state_events table.
+
+ If a state filter is given (that is not `StateFilter.all()`) the query
+ result is *not* cached.
+
+ Args:
+ room_id: The room to get the state IDs of. state_filter: The state
+ filter used to fetch state from the
+ database.
+ on_invalidate: Callback for when the `get_current_state_ids` cache
+ for the room gets invalidated.
+
+ Returns:
+ The current state of the room.
+ """
+ if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ if state_filter and not state_filter.is_full():
+ return await self.stores.main.get_partial_filtered_current_state_ids(
+ room_id, state_filter
+ )
+ else:
+ return await self.stores.main.get_partial_current_state_ids(
+ room_id, on_invalidate=on_invalidate
+ )
+
+ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
+ """Get canonical alias for room, if any
+
+ Args:
+ room_id: The room ID
+
+ Returns:
+ The canonical alias, if any
+ """
+
+ state = await self.get_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
+
+ event_id = state.get((EventTypes.CanonicalAlias, ""))
+ if not event_id:
+ return None
+
+ event = await self.stores.main.get_event(event_id, allow_none=True)
+ if not event:
+ return None
+
+ return event.content.get("canonical_alias")
+
+ async def get_current_state_deltas(
+ self, prev_stream_id: int, max_stream_id: int
+ ) -> Tuple[int, List[Dict[str, Any]]]:
+ """Fetch a list of room state changes since the given stream id
+
+ Each entry in the result contains the following fields:
+ - stream_id (int)
+ - room_id (str)
+ - type (str): event type
+ - state_key (str):
+ - event_id (str|None): new event_id for this state key. None if the
+ state has been deleted.
+ - prev_event_id (str|None): previous event_id for this state key. None
+ if it's new state.
+
+ Args:
+ prev_stream_id: point to get changes since (exclusive)
+ max_stream_id: the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
+
+ Returns:
+ A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
+ """
+ # FIXME(faster_joins): what do we do here?
+
+ return await self.stores.main.get_partial_current_state_deltas(
+ prev_stream_id, max_stream_id
+ )
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index cfd8ce1624..68d4fc2e64 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"room_id": room_id},
)
+ async def is_partial_state_room(self, room_id: str) -> bool:
+ """Checks if this room has partial state.
+
+ Returns true if this is a "partial-state" room, which means that the state
+ at events in the room, and `current_state_events`, may not yet be
+ complete.
+ """
+
+ entry = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="room_id",
+ allow_none=True,
+ desc="is_partial_state_room",
+ )
+
+ return entry is not None
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3f2be3854b..bdd00273cd 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
NotFoundError if the room is unknown
"""
- state_ids = await self.get_current_state_ids(room_id)
+ state_ids = await self.get_partial_current_state_ids(room_id)
if not state_ids:
raise NotFoundError(f"Current state for room {room_id} is empty")
@@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
- async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
+ async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
+ This may be the partial state if we're lazy joining the room.
+
Args:
room_id: The room to get the state IDs of.
@@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return await self.db_pool.runInteraction(
- "get_current_state_ids", _get_current_state_ids_txn
+ "get_partial_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- async def get_filtered_current_state_ids(
+ async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
+ This may be the partial state if we're lazy joining the room.
+
Args:
room_id
state_filter: The state filter used to fetch state
@@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not where_clause:
# We delegate to the cached version
- return await self.get_current_state_ids(room_id)
+ return await self.get_partial_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
@@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
- async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
- """Get canonical alias for room, if any
-
- Args:
- room_id: The room ID
-
- Returns:
- The canonical alias, if any
- """
-
- state = await self.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
- )
-
- event_id = state.get((EventTypes.CanonicalAlias, ""))
- if not event_id:
- return None
-
- event = await self.get_event(event_id, allow_none=True)
- if not event:
- return None
-
- return event.content.get("canonical_alias")
-
@cached(max_entries=50000)
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 188afec332..445213e12a 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
# attribute. TODO: can we get static analysis to enforce this?
_curr_state_delta_stream_cache: StreamChangeCache
- async def get_current_state_deltas(
+ async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
@@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
+ This may be the partial state if we're lazy joining the room.
+
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 2282242e9d..ddb25b5cea 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
+ # Getting the partial state is fine, as we're not looking at membership
+ # events.
+ current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index a61a951ef0..211437cfaa 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@@ -118,3 +119,62 @@ class PartialStateEventsTracker:
observer_set.discard(observer)
if not observer_set:
del self._observers[event_id]
+
+
+class PartialCurrentStateTracker:
+ """Keeps track of which rooms have partial state, after partial-state joins"""
+
+ def __init__(self, store: RoomWorkerStore):
+ self._store = store
+
+ # a map from room id to a set of Deferreds which are waiting for that room to be
+ # un-partial-stated.
+ self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
+
+ def notify_un_partial_stated(self, room_id: str) -> None:
+ """Notify that we now have full current state for a given room
+
+ Unblocks any callers to await_full_state() for that room.
+
+ Args:
+ room_id: the room that now has full current state.
+ """
+ observers = self._observers.pop(room_id, None)
+ if not observers:
+ return
+ logger.info(
+ "Notifying %i things waiting for un-partial-stating of room %s",
+ len(observers),
+ room_id,
+ )
+ with PreserveLoggingContext():
+ for o in observers:
+ o.callback(None)
+
+ async def await_full_state(self, room_id: str) -> None:
+ # We add the deferred immediately so that the DB call to check for
+ # partial state doesn't race when we unpartial the room.
+ d: Deferred[None] = Deferred()
+ self._observers.setdefault(room_id, set()).add(d)
+
+ try:
+ # Check if the room has partial current state or not.
+ has_partial_state = await self._store.is_partial_state_room(room_id)
+ if not has_partial_state:
+ return
+
+ logger.info(
+ "Awaiting un-partial-stating of room %s",
+ room_id,
+ )
+
+ await make_deferred_yieldable(d)
+
+ logger.info("Room has un-partial-stated")
+ finally:
+ # Remove the added observer, and remove the room entry if its empty.
+ ds = self._observers.get(room_id)
+ if ds is not None:
+ ds.discard(d)
+ if not ds:
+ self._observers.pop(room_id, None)
|