summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-06-01 16:02:53 +0100
committerGitHub <noreply@github.com>2022-06-01 16:02:53 +0100
commit888a29f4127723a8d048ce47cff37ee8a7a6f1b9 (patch)
tree891ae4c95632801bb097aaed8aea5d8b541c5cf7 /synapse
parentFix complement tests using the wrong path (#12933) (diff)
downloadsynapse-888a29f4127723a8d048ce47cff37ee8a7a6f1b9.tar.xz
Wait for lazy join to complete when getting current state (#12872)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/third_party_rules.py3
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/handlers/device.py2
-rw-r--r--synapse/handlers/directory.py7
-rw-r--r--synapse/handlers/federation.py7
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/presence.py6
-rw-r--r--synapse/handlers/register.py3
-rw-r--r--synapse/handlers/room.py13
-rw-r--r--synapse/handlers/room_list.py3
-rw-r--r--synapse/handlers/room_member.py5
-rw-r--r--synapse/handlers/room_summary.py11
-rw-r--r--synapse/handlers/stats.py6
-rw-r--r--synapse/handlers/sync.py13
-rw-r--r--synapse/handlers/user_directory.py6
-rw-r--r--synapse/module_api/__init__.py19
-rw-r--r--synapse/push/mailer.py4
-rw-r--r--synapse/rest/admin/rooms.py3
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/controllers/__init__.py4
-rw-r--r--synapse/storage/controllers/persist_events.py4
-rw-r--r--synapse/storage/controllers/state.py112
-rw-r--r--synapse/storage/databases/main/room.py18
-rw-r--r--synapse/storage/databases/main/state.py38
-rw-r--r--synapse/storage/databases/main/state_deltas.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py4
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py60
27 files changed, 288 insertions, 75 deletions
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9f4ff9799c..35f3f3690f 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -152,6 +152,7 @@ class ThirdPartyEventRules:
         self.third_party_rules = None
 
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
 
         self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
         self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
@@ -463,7 +464,7 @@ class ThirdPartyEventRules:
         Returns:
             A dict mapping (event type, state key) to state event.
         """
-        state_ids = await self.store.get_filtered_current_state_ids(room_id)
+        state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
         room_state_events = await self.store.get_events(state_ids.values())
 
         state_events = {}
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 12591dc8db..f4af121c4d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -118,6 +118,8 @@ class FederationServer(FederationBase):
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
 
+        self._state_storage_controller = hs.get_storage_controllers().state
+
         self.device_handler = hs.get_device_handler()
 
         # Ensure the following handlers are loaded since they register callbacks
@@ -1221,7 +1223,7 @@ class FederationServer(FederationBase):
         Raises:
             AuthError if the server does not match the ACL
         """
-        state_ids = await self.store.get_current_state_ids(room_id)
+        state_ids = await self._state_storage_controller.get_current_state_ids(room_id)
         acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
 
         if not acl_event_id:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 72faf2ee38..a0cbeedc30 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -166,7 +166,7 @@ class DeviceWorkerHandler:
         possibly_changed = set(changed)
         possibly_left = set()
         for room_id in rooms_changed:
-            current_state_ids = await self.store.get_current_state_ids(room_id)
+            current_state_ids = await self._state_storage.get_current_state_ids(room_id)
 
             # The user may have left the room
             # TODO: Check if they actually did or if we were just invited.
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4aa33df884..44e84698c4 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -45,6 +45,7 @@ class DirectoryHandler:
         self.appservice_handler = hs.get_application_service_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.config = hs.config
         self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
         self.require_membership = hs.config.server.require_membership_for_aliases
@@ -463,7 +464,11 @@ class DirectoryHandler:
         making_public = visibility == "public"
         if making_public:
             room_aliases = await self.store.get_aliases_for_room(room_id)
-            canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
+            canonical_alias = (
+                await self._storage_controllers.state.get_canonical_alias_for_room(
+                    room_id
+                )
+            )
             if canonical_alias:
                 room_aliases.append(canonical_alias)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 659f279441..b212ee2172 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -750,7 +750,9 @@ class FederationHandler:
         # Note that this requires the /send_join request to come back to the
         # same server.
         if room_version.msc3083_join_rules:
-            state_ids = await self.store.get_current_state_ids(room_id)
+            state_ids = await self._state_storage_controller.get_current_state_ids(
+                room_id
+            )
             if await self._event_auth_handler.has_restricted_join_rules(
                 state_ids, room_version
             ):
@@ -1552,6 +1554,9 @@ class FederationHandler:
                 success = await self.store.clear_partial_state_room(room_id)
                 if success:
                     logger.info("State resync complete for %s", room_id)
+                    self._storage_controllers.state.notify_room_un_partial_stated(
+                        room_id
+                    )
 
                     # TODO(faster_joins) update room stats and user directory?
                     return
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ac911a2ddc..081625f0bd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -217,7 +217,7 @@ class MessageHandler:
             )
 
             if membership == Membership.JOIN:
-                state_ids = await self.store.get_filtered_current_state_ids(
+                state_ids = await self._state_storage_controller.get_current_state_ids(
                     room_id, state_filter=state_filter
                 )
                 room_state = await self.store.get_events(state_ids.values())
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index bf112b9e1e..895ea63ed3 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC):
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.presence_router = hs.get_presence_router()
         self.state = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler):
                     self._event_pos,
                     room_max_stream_ordering,
                 )
-                max_pos, deltas = await self.store.get_current_state_deltas(
+                (
+                    max_pos,
+                    deltas,
+                ) = await self._storage_controllers.state.get_current_state_deltas(
                     self._event_pos, room_max_stream_ordering
                 )
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 05bb1e0225..338204287f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -87,6 +87,7 @@ class LoginDict(TypedDict):
 class RegistrationHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.clock = hs.get_clock()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -528,7 +529,7 @@ class RegistrationHandler:
 
                 if requires_invite:
                     # If the server is in the room, check if the room is public.
-                    state = await self.store.get_filtered_current_state_ids(
+                    state = await self._storage_controllers.state.get_current_state_ids(
                         room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
                     )
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e1341dd9bb..e2b0e519d4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -107,6 +107,7 @@ class EventContext:
 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.hs = hs
@@ -480,8 +481,10 @@ class RoomCreationHandler:
             if room_type == RoomTypes.SPACE:
                 types_to_copy.append((EventTypes.SpaceChild, None))
 
-        old_room_state_ids = await self.store.get_filtered_current_state_ids(
-            old_room_id, StateFilter.from_types(types_to_copy)
+        old_room_state_ids = (
+            await self._storage_controllers.state.get_current_state_ids(
+                old_room_id, StateFilter.from_types(types_to_copy)
+            )
         )
         # map from event_id to BaseEvent
         old_room_state_events = await self.store.get_events(old_room_state_ids.values())
@@ -558,8 +561,10 @@ class RoomCreationHandler:
         )
 
         # Transfer membership events
-        old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
-            old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
+        old_room_member_state_ids = (
+            await self._storage_controllers.state.get_current_state_ids(
+                old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
+            )
         )
 
         # map from event_id to BaseEvent
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index f3577b5d5a..183d4ae3c4 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 class RoomListHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.hs = hs
         self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
         self.response_cache: ResponseCache[
@@ -274,7 +275,7 @@ class RoomListHandler:
             if aliases:
                 result["aliases"] = aliases
 
-        current_state_ids = await self.store.get_current_state_ids(
+        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
             room_id, on_invalidate=cache_context.invalidate
         )
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 00662dc961..70c674ff8e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.auth = hs.get_auth()
         self.state_handler = hs.get_state_handler()
         self.config = hs.config
@@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
-        current_state_ids = await self.store.get_current_state_ids(room_id)
+        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+            room_id
+        )
 
         # If restricted join rules are not being used, a local join can always
         # be used.
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 75aee6a111..13098f56ed 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -90,6 +90,7 @@ class RoomSummaryHandler:
     def __init__(self, hs: "HomeServer"):
         self._event_auth_handler = hs.get_event_auth_handler()
         self._store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
@@ -537,7 +538,7 @@ class RoomSummaryHandler:
         Returns:
              True if the room is accessible to the requesting user or server.
         """
-        state_ids = await self._store.get_current_state_ids(room_id)
+        state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
 
         # If there's no state for the room, it isn't known.
         if not state_ids:
@@ -702,7 +703,9 @@ class RoomSummaryHandler:
         # there should always be an entry
         assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
 
-        current_state_ids = await self._store.get_current_state_ids(room_id)
+        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+            room_id
+        )
         create_event = await self._store.get_event(
             current_state_ids[(EventTypes.Create, "")]
         )
@@ -760,7 +763,9 @@ class RoomSummaryHandler:
         """
 
         # look for child rooms/spaces.
-        current_state_ids = await self._store.get_current_state_ids(room_id)
+        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
+            room_id
+        )
 
         events = await self._store.get_events_as_list(
             [
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 436cd971ce..f45e06eb0e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -40,6 +40,7 @@ class StatsHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.state = hs.get_state_handler()
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
@@ -105,7 +106,10 @@ class StatsHandler:
             logger.debug(
                 "Processing room stats %s->%s", self.pos, room_max_stream_ordering
             )
-            max_pos, deltas = await self.store.get_current_state_deltas(
+            (
+                max_pos,
+                deltas,
+            ) = await self._storage_controllers.state.get_current_state_deltas(
                 self.pos, room_max_stream_ordering
             )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a1d41358d9..b4ead79f97 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -506,8 +506,10 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
-                    current_state_ids_map = await self.store.get_current_state_ids(
-                        room_id
+                    current_state_ids_map = (
+                        await self._state_storage_controller.get_current_state_ids(
+                            room_id
+                        )
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
@@ -574,8 +576,11 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids = frozenset()
                 if any(e.is_state() for e in loaded_recents):
-                    current_state_ids_map = await self.store.get_current_state_ids(
-                        room_id
+                    # FIXME(faster_joins): We use the partial state here as
+                    # we don't want to block `/sync` on finishing a lazy join.
+                    # Is this the correct way of doing it?
+                    current_state_ids_map = (
+                        await self.store.get_partial_current_state_ids(room_id)
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 74f7fdfe6c..8c3c52e1ca 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         super().__init__(hs)
 
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
@@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler):
                 logger.debug(
                     "Processing user stats %s->%s", self.pos, room_max_stream_ordering
                 )
-                max_pos, deltas = await self.store.get_current_state_deltas(
+                (
+                    max_pos,
+                    deltas,
+                ) = await self._storage_controllers.state.get_current_state_deltas(
                     self.pos, room_max_stream_ordering
                 )
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index b7451fc870..a8ad575fcd 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -194,6 +194,7 @@ class ModuleApi:
         self._store: Union[
             DataStore, "GenericWorkerSlavedStore"
         ] = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
@@ -911,7 +912,7 @@ class ModuleApi:
                 The filtered state events in the room.
         """
         state_ids = yield defer.ensureDeferred(
-            self._store.get_filtered_current_state_ids(
+            self._storage_controllers.state.get_current_state_ids(
                 room_id=room_id, state_filter=StateFilter.from_types(types)
             )
         )
@@ -1289,20 +1290,16 @@ class ModuleApi:
                                                                 # regardless of their state key
                     ]
         """
+        state_filter = None
         if event_filter:
             # If a filter was provided, turn it into a StateFilter and retrieve a filtered
             # view of the state.
             state_filter = StateFilter.from_types(event_filter)
-            state_ids = await self._store.get_filtered_current_state_ids(
-                room_id,
-                state_filter,
-            )
-        else:
-            # If no filter was provided, get the whole state. We could also reuse the call
-            # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`,
-            # but get_filtered_current_state_ids isn't cached and `get_current_state_ids`
-            # is, so using the latter when we can is better for perf.
-            state_ids = await self._store.get_current_state_ids(room_id)
+
+        state_ids = await self._storage_controllers.state.get_current_state_ids(
+            room_id,
+            state_filter,
+        )
 
         state_events = await self._store.get_events(state_ids.values())
 
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 63aefd07f5..015c19b2d9 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -255,7 +255,9 @@ class Mailer:
             user_display_name = user_id
 
         async def _fetch_room_state(room_id: str) -> None:
-            room_state = await self.store.get_current_state_ids(room_id)
+            room_state = await self._state_storage_controller.get_current_state_ids(
+                room_id
+            )
             state_by_room[room_id] = room_state
 
         # Run at most 3 of these at once: sync does 10 at a time but email
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 356d6f74d7..1cacd1a4f0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -418,6 +418,7 @@ class RoomStateRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.clock = hs.get_clock()
         self._event_serializer = hs.get_event_client_serializer()
 
@@ -430,7 +431,7 @@ class RoomStateRestServlet(RestServlet):
         if not ret:
             raise NotFoundError("Room not found")
 
-        event_ids = await self.store.get_current_state_ids(room_id)
+        event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
         events = await self.store.get_events(event_ids.values())
         now = self.clock.time_msec()
         room_state = self._event_serializer.serialize_events(events.values(), now)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8df80664a2..57bd74700e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -77,7 +77,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
-        self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
+        self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 992261d07b..55649719f6 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import (
     EventsPersistenceStorageController,
 )
 from synapse.storage.controllers.purge_events import PurgeEventsStorageController
-from synapse.storage.controllers.state import StateGroupStorageController
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main import DataStore
 
@@ -39,7 +39,7 @@ class StorageControllers:
         self.main = stores.main
 
         self.purge_events = PurgeEventsStorageController(hs, stores)
-        self.state = StateGroupStorageController(hs, stores)
+        self.state = StateStorageController(hs, stores)
 
         self.persistence = None
         if stores.persist_events:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ef8c135b12..4caaa81808 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -994,7 +994,7 @@ class EventsPersistenceStorageController:
 
         Assumes that we are only persisting events for one room at a time.
         """
-        existing_state = await self.main_store.get_current_state_ids(room_id)
+        existing_state = await self.main_store.get_partial_current_state_ids(room_id)
 
         to_delete = [key for key in existing_state if key not in current_state]
 
@@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController:
         # The server will leave the room, so we go and find out which remote
         # users will still be joined when we leave.
         if current_state is None:
-            current_state = await self.main_store.get_current_state_ids(room_id)
+            current_state = await self.main_store.get_partial_current_state_ids(room_id)
             current_state = dict(current_state)
             for key in delta.to_delete:
                 current_state.pop(key, None)
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0f09953086..9952b00493 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,7 +14,9 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
+    Callable,
     Collection,
     Dict,
     Iterable,
@@ -24,9 +26,13 @@ from typing import (
     Tuple,
 )
 
+from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.storage.state import StateFilter
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+    PartialCurrentStateTracker,
+    PartialStateEventsTracker,
+)
 from synapse.types import MutableStateMap, StateMap
 
 if TYPE_CHECKING:
@@ -36,17 +42,27 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class StateGroupStorageController:
-    """High level interface to fetching state for event."""
+class StateStorageController:
+    """High level interface to fetching state for an event, or the current state
+    in a room.
+    """
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
         self._is_mine_id = hs.is_mine_id
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+        self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
 
     def notify_event_un_partial_stated(self, event_id: str) -> None:
         self._partial_state_events_tracker.notify_un_partial_stated(event_id)
 
+    def notify_room_un_partial_stated(self, room_id: str) -> None:
+        """Notify that the room no longer has any partial state.
+
+        Must be called after `DataStore.clear_partial_state_room`
+        """
+        self._partial_state_room_tracker.notify_un_partial_stated(room_id)
+
     async def get_state_group_delta(
         self, state_group: int
     ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
@@ -349,3 +365,93 @@ class StateGroupStorageController:
         return await self.stores.state.store_state_group(
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
+
+    async def get_current_state_ids(
+        self,
+        room_id: str,
+        state_filter: Optional[StateFilter] = None,
+        on_invalidate: Optional[Callable[[], None]] = None,
+    ) -> StateMap[str]:
+        """Get the current state event ids for a room based on the
+        current_state_events table.
+
+        If a state filter is given (that is not `StateFilter.all()`) the query
+        result is *not* cached.
+
+        Args:
+            room_id: The room to get the state IDs of. state_filter: The state
+            filter used to fetch state from the
+                database.
+            on_invalidate: Callback for when the `get_current_state_ids` cache
+                for the room gets invalidated.
+
+        Returns:
+            The current state of the room.
+        """
+        if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
+            await self._partial_state_room_tracker.await_full_state(room_id)
+
+        if state_filter and not state_filter.is_full():
+            return await self.stores.main.get_partial_filtered_current_state_ids(
+                room_id, state_filter
+            )
+        else:
+            return await self.stores.main.get_partial_current_state_ids(
+                room_id, on_invalidate=on_invalidate
+            )
+
+    async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
+        """Get canonical alias for room, if any
+
+        Args:
+            room_id: The room ID
+
+        Returns:
+            The canonical alias, if any
+        """
+
+        state = await self.get_current_state_ids(
+            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+        )
+
+        event_id = state.get((EventTypes.CanonicalAlias, ""))
+        if not event_id:
+            return None
+
+        event = await self.stores.main.get_event(event_id, allow_none=True)
+        if not event:
+            return None
+
+        return event.content.get("canonical_alias")
+
+    async def get_current_state_deltas(
+        self, prev_stream_id: int, max_stream_id: int
+    ) -> Tuple[int, List[Dict[str, Any]]]:
+        """Fetch a list of room state changes since the given stream id
+
+        Each entry in the result contains the following fields:
+            - stream_id (int)
+            - room_id (str)
+            - type (str): event type
+            - state_key (str):
+            - event_id (str|None): new event_id for this state key. None if the
+                state has been deleted.
+            - prev_event_id (str|None): previous event_id for this state key. None
+                if it's new state.
+
+        Args:
+            prev_stream_id: point to get changes since (exclusive)
+            max_stream_id: the point that we know has been correctly persisted
+               - ie, an upper limit to return changes from.
+
+        Returns:
+            A tuple consisting of:
+               - the stream id which these results go up to
+               - list of current_state_delta_stream rows. If it is empty, we are
+                 up to date.
+        """
+        # FIXME(faster_joins): what do we do here?
+
+        return await self.stores.main.get_partial_current_state_deltas(
+            prev_stream_id, max_stream_id
+        )
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index cfd8ce1624..68d4fc2e64 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             keyvalues={"room_id": room_id},
         )
 
+    async def is_partial_state_room(self, room_id: str) -> bool:
+        """Checks if this room has partial state.
+
+        Returns true if this is a "partial-state" room, which means that the state
+        at events in the room, and `current_state_events`, may not yet be
+        complete.
+        """
+
+        entry = await self.db_pool.simple_select_one_onecol(
+            table="partial_state_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="room_id",
+            allow_none=True,
+            desc="is_partial_state_room",
+        )
+
+        return entry is not None
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3f2be3854b..bdd00273cd 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Raises:
             NotFoundError if the room is unknown
         """
-        state_ids = await self.get_current_state_ids(room_id)
+        state_ids = await self.get_partial_current_state_ids(room_id)
 
         if not state_ids:
             raise NotFoundError(f"Current state for room {room_id} is empty")
@@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return create_event
 
     @cached(max_entries=100000, iterable=True)
-    async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
+    async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
         current_state_events table.
 
+        This may be the partial state if we're lazy joining the room.
+
         Args:
             room_id: The room to get the state IDs of.
 
@@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
 
         return await self.db_pool.runInteraction(
-            "get_current_state_ids", _get_current_state_ids_txn
+            "get_partial_current_state_ids", _get_current_state_ids_txn
         )
 
     # FIXME: how should this be cached?
-    async def get_filtered_current_state_ids(
+    async def get_partial_filtered_current_state_ids(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
 
+        This may be the partial state if we're lazy joining the room.
+
         Args:
             room_id
             state_filter: The state filter used to fetch state
@@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         if not where_clause:
             # We delegate to the cached version
-            return await self.get_current_state_ids(room_id)
+            return await self.get_partial_current_state_ids(room_id)
 
         def _get_filtered_current_state_ids_txn(
             txn: LoggingTransaction,
@@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
-    async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
-        """Get canonical alias for room, if any
-
-        Args:
-            room_id: The room ID
-
-        Returns:
-            The canonical alias, if any
-        """
-
-        state = await self.get_filtered_current_state_ids(
-            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
-        )
-
-        event_id = state.get((EventTypes.CanonicalAlias, ""))
-        if not event_id:
-            return None
-
-        event = await self.get_event(event_id, allow_none=True)
-        if not event:
-            return None
-
-        return event.content.get("canonical_alias")
-
     @cached(max_entries=50000)
     async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
         return await self.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 188afec332..445213e12a 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
     # attribute. TODO: can we get static analysis to enforce this?
     _curr_state_delta_stream_cache: StreamChangeCache
 
-    async def get_current_state_deltas(
+    async def get_partial_current_state_deltas(
         self, prev_stream_id: int, max_stream_id: int
     ) -> Tuple[int, List[Dict[str, Any]]]:
         """Fetch a list of room state changes since the given stream id
@@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
             - prev_event_id (str|None): previous event_id for this state key. None
                 if it's new state.
 
+        This may be the partial state if we're lazy joining the room.
+
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 2282242e9d..ddb25b5cea 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             (EventTypes.RoomHistoryVisibility, ""),
         )
 
-        current_state_ids = await self.get_filtered_current_state_ids(  # type: ignore[attr-defined]
+        # Getting the partial state is fine, as we're not looking at membership
+        # events.
+        current_state_ids = await self.get_partial_filtered_current_state_ids(  # type: ignore[attr-defined]
             room_id, StateFilter.from_types(types_to_filter)
         )
 
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index a61a951ef0..211437cfaa 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
 
 logger = logging.getLogger(__name__)
@@ -118,3 +119,62 @@ class PartialStateEventsTracker:
                     observer_set.discard(observer)
                     if not observer_set:
                         del self._observers[event_id]
+
+
+class PartialCurrentStateTracker:
+    """Keeps track of which rooms have partial state, after partial-state joins"""
+
+    def __init__(self, store: RoomWorkerStore):
+        self._store = store
+
+        # a map from room id to a set of Deferreds which are waiting for that room to be
+        # un-partial-stated.
+        self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
+
+    def notify_un_partial_stated(self, room_id: str) -> None:
+        """Notify that we now have full current state for a given room
+
+        Unblocks any callers to await_full_state() for that room.
+
+        Args:
+            room_id: the room that now has full current state.
+        """
+        observers = self._observers.pop(room_id, None)
+        if not observers:
+            return
+        logger.info(
+            "Notifying %i things waiting for un-partial-stating of room %s",
+            len(observers),
+            room_id,
+        )
+        with PreserveLoggingContext():
+            for o in observers:
+                o.callback(None)
+
+    async def await_full_state(self, room_id: str) -> None:
+        # We add the deferred immediately so that the DB call to check for
+        # partial state doesn't race when we unpartial the room.
+        d: Deferred[None] = Deferred()
+        self._observers.setdefault(room_id, set()).add(d)
+
+        try:
+            # Check if the room has partial current state or not.
+            has_partial_state = await self._store.is_partial_state_room(room_id)
+            if not has_partial_state:
+                return
+
+            logger.info(
+                "Awaiting un-partial-stating of room %s",
+                room_id,
+            )
+
+            await make_deferred_yieldable(d)
+
+            logger.info("Room has un-partial-stated")
+        finally:
+            # Remove the added observer, and remove the room entry if its empty.
+            ds = self._observers.get(room_id)
+            if ds is not None:
+                ds.discard(d)
+                if not ds:
+                    self._observers.pop(room_id, None)