summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-06-06 11:24:12 +0300
committerGitHub <noreply@github.com>2022-06-06 09:24:12 +0100
commite3163e2e11cf8bffa4cb3e58ac0b86a83eca314c (patch)
tree639271e4b4d4157b95047f801f6b22bf39a24f08
parentRemove groups code from synapse_port_db. (#12899) (diff)
downloadsynapse-e3163e2e11cf8bffa4cb3e58ac0b86a83eca314c.tar.xz
Reduce the amount of state we pull from the DB (#12811)
-rw-r--r--changelog.d/12811.misc1
-rw-r--r--synapse/api/auth.py45
-rw-r--r--synapse/federation/federation_base.py1
-rw-r--r--synapse/federation/federation_server.py12
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/federation_event.py18
-rw-r--r--synapse/handlers/initial_sync.py6
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/handlers/room.py5
-rw-r--r--synapse/handlers/room_member.py16
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/notifier.py2
-rw-r--r--synapse/rest/admin/rooms.py34
-rw-r--r--synapse/rest/client/room.py7
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py7
-rw-r--r--synapse/state/__init__.py73
-rw-r--r--synapse/storage/controllers/state.py27
-rw-r--r--tests/federation/test_federation_server.py6
-rw-r--r--tests/handlers/test_directory.py3
-rw-r--r--tests/storage/test_events.py17
-rw-r--r--tests/storage/test_purge.py5
-rw-r--r--tests/storage/test_room.py12
23 files changed, 161 insertions, 146 deletions
diff --git a/changelog.d/12811.misc b/changelog.d/12811.misc
new file mode 100644
index 0000000000..d57e1aca6b
--- /dev/null
+++ b/changelog.d/12811.misc
@@ -0,0 +1 @@
+Reduce the amount of state we pull from the DB.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 931750668e..5a410f805a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -29,12 +29,11 @@ from synapse.api.errors import (
     MissingClientTokenError,
 )
 from synapse.appservice import ApplicationService
-from synapse.events import EventBase
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import active_span, force_tracing, start_active_span
 from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import Requester, StateMap, UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
@@ -61,8 +60,8 @@ class Auth:
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        self.state = hs.get_state_handler()
         self._account_validity_handler = hs.get_account_validity_handler()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
             10000, "token_cache"
@@ -79,9 +78,8 @@ class Auth:
         self,
         room_id: str,
         user_id: str,
-        current_state: Optional[StateMap[EventBase]] = None,
         allow_departed_users: bool = False,
-    ) -> EventBase:
+    ) -> Tuple[str, Optional[str]]:
         """Check if the user is in the room, or was at some point.
         Args:
             room_id: The room to check.
@@ -99,29 +97,28 @@ class Auth:
         Raises:
             AuthError if the user is/was not in the room.
         Returns:
-            Membership event for the user if the user was in the
-            room. This will be the join event if they are currently joined to
-            the room. This will be the leave event if they have left the room.
+            The current membership of the user in the room and the
+            membership event ID of the user.
         """
-        if current_state:
-            member = current_state.get((EventTypes.Member, user_id), None)
-        else:
-            member = await self.state.get_current_state(
-                room_id=room_id, event_type=EventTypes.Member, state_key=user_id
-            )
 
-        if member:
-            membership = member.membership
+        (
+            membership,
+            member_event_id,
+        ) = await self.store.get_local_current_membership_for_user_in_room(
+            user_id=user_id,
+            room_id=room_id,
+        )
 
+        if membership:
             if membership == Membership.JOIN:
-                return member
+                return membership, member_event_id
 
             # XXX this looks totally bogus. Why do we not allow users who have been banned,
             # or those who were members previously and have been re-invited?
             if allow_departed_users and membership == Membership.LEAVE:
                 forgot = await self.store.did_forget(user_id, room_id)
                 if not forgot:
-                    return member
+                    return membership, member_event_id
 
         raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
@@ -602,8 +599,11 @@ class Auth:
         # We currently require the user is a "moderator" in the room. We do this
         # by checking if they would (theoretically) be able to change the
         # m.room.canonical_alias events
-        power_level_event = await self.state.get_current_state(
-            room_id, EventTypes.PowerLevels, ""
+
+        power_level_event = (
+            await self._storage_controllers.state.get_current_state_event(
+                room_id, EventTypes.PowerLevels, ""
+            )
         )
 
         auth_events = {}
@@ -693,12 +693,11 @@ class Auth:
             #  * The user is a non-guest user, and was ever in the room
             #  * The user is a guest user, and has joined the room
             # else it will throw.
-            member_event = await self.check_user_in_room(
+            return await self.check_user_in_room(
                 room_id, user_id, allow_departed_users=allow_departed_users
             )
-            return member_event.membership, member_event.event_id
         except AuthError:
-            visibility = await self.state.get_current_state(
+            visibility = await self._storage_controllers.state.get_current_state_event(
                 room_id, EventTypes.RoomHistoryVisibility, ""
             )
             if (
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a6232e048b..2522bf78fc 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -53,6 +53,7 @@ class FederationBase:
         self.spam_checker = hs.get_spam_checker()
         self.store = hs.get_datastores().main
         self._clock = hs.get_clock()
+        self._storage_controllers = hs.get_storage_controllers()
 
     async def _check_sigs_and_hash(
         self, room_version: RoomVersion, pdu: EventBase
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f4af121c4d..3e1518f1f6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1223,14 +1223,10 @@ class FederationServer(FederationBase):
         Raises:
             AuthError if the server does not match the ACL
         """
-        state_ids = await self._state_storage_controller.get_current_state_ids(room_id)
-        acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
-
-        if not acl_event_id:
-            return
-
-        acl_event = await self.store.get_event(acl_event_id)
-        if server_matches_acl_event(server_name, acl_event):
+        acl_event = await self._storage_controllers.state.get_current_state_event(
+            room_id, EventTypes.ServerACL, ""
+        )
+        if not acl_event or server_matches_acl_event(server_name, acl_event):
             return
 
         raise AuthError(code=403, msg="Server is banned from room")
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 44e84698c4..1459a046de 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -320,7 +320,7 @@ class DirectoryHandler:
         Raises:
             ShadowBanError if the requester has been shadow-banned.
         """
-        alias_event = await self.state.get_current_state(
+        alias_event = await self._storage_controllers.state.get_current_state_event(
             room_id, EventTypes.CanonicalAlias, ""
         )
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b212ee2172..6a143440d3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -371,7 +371,7 @@ class FederationHandler:
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = await self.state_handler.get_current_state(room_id)
+        curr_state = await self._storage_controllers.state.get_current_state(room_id)
 
         curr_domains = get_domains_from_state(curr_state)
 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 549b066dd9..87a0608359 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1584,9 +1584,11 @@ class FederationEventHandler:
         if guest_access == GuestAccess.CAN_JOIN:
             return
 
-        current_state_map = await self._state_handler.get_current_state(event.room_id)
-        current_state = list(current_state_map.values())
-        await self._get_room_member_handler().kick_guest_users(current_state)
+        current_state = await self._storage_controllers.state.get_current_state(
+            event.room_id
+        )
+        current_state_list = list(current_state.values())
+        await self._get_room_member_handler().kick_guest_users(current_state_list)
 
     async def _check_for_soft_fail(
         self,
@@ -1614,6 +1616,9 @@ class FederationEventHandler:
         room_version = await self._store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
+        # The event types we want to pull from the "current" state.
+        auth_types = auth_types_for_event(room_version_obj, event)
+
         # Calculate the "current state".
         if state_ids is not None:
             # If we're explicitly given the state then we won't have all the
@@ -1643,8 +1648,10 @@ class FederationEventHandler:
                 )
             )
         else:
-            current_state_ids = await self._state_handler.get_current_state_ids(
-                event.room_id, latest_event_ids=extrem_ids
+            current_state_ids = (
+                await self._state_storage_controller.get_current_state_ids(
+                    event.room_id, StateFilter.from_types(auth_types)
+                )
             )
 
         logger.debug(
@@ -1654,7 +1661,6 @@ class FederationEventHandler:
         )
 
         # Now check if event pass auth against said current state
-        auth_types = auth_types_for_event(room_version_obj, event)
         current_state_ids_list = [
             e for k, e in current_state_ids.items() if k in auth_types
         ]
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d2b489e816..85b472f250 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -190,7 +190,7 @@ class InitialSyncHandler:
                 if event.membership == Membership.JOIN:
                     room_end_token = now_token.room_key
                     deferred_room_state = run_in_background(
-                        self.state_handler.get_current_state, event.room_id
+                        self._state_storage_controller.get_current_state, event.room_id
                     )
                 elif event.membership == Membership.LEAVE:
                     room_end_token = RoomStreamToken(
@@ -407,7 +407,9 @@ class InitialSyncHandler:
         membership: str,
         is_peeking: bool,
     ) -> JsonDict:
-        current_state = await self.state.get_current_state(room_id=room_id)
+        current_state = await self._storage_controllers.state.get_current_state(
+            room_id=room_id
+        )
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 081625f0bd..f455158a2c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -125,7 +125,9 @@ class MessageHandler:
         )
 
         if membership == Membership.JOIN:
-            data = await self.state.get_current_state(room_id, event_type, state_key)
+            data = await self._storage_controllers.state.get_current_state_event(
+                room_id, event_type, state_key
+            )
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
             # If the membership is not JOIN, then the event ID should exist.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e2b0e519d4..520663f172 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1333,6 +1333,7 @@ class TimestampLookupHandler:
         self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.federation_client = hs.get_federation_client()
+        self._storage_controllers = hs.get_storage_controllers()
 
     async def get_event_for_timestamp(
         self,
@@ -1406,7 +1407,9 @@ class TimestampLookupHandler:
             )
 
             # Find other homeservers from the given state in the room
-            curr_state = await self.state_handler.get_current_state(room_id)
+            curr_state = await self._storage_controllers.state.get_current_state(
+                room_id
+            )
             curr_domains = get_domains_from_state(curr_state)
             likely_domains = [
                 domain for domain, depth in curr_domains if domain != self.server_name
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 70c674ff8e..d1199a0644 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
     ) -> int:
-        room_state = await self.state_handler.get_current_state(room_id)
+        room_state = await self._storage_controllers.state.get_current_state(
+            room_id,
+            StateFilter.from_types(
+                [
+                    (EventTypes.Member, user.to_string()),
+                    (EventTypes.CanonicalAlias, ""),
+                    (EventTypes.Name, ""),
+                    (EventTypes.Create, ""),
+                    (EventTypes.JoinRules, ""),
+                    (EventTypes.RoomAvatar, ""),
+                ]
+            ),
+        )
 
         inviter_display_name = ""
         inviter_avatar_url = ""
@@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
     async def forget(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
 
-        member = await self.state_handler.get_current_state(
+        member = await self._storage_controllers.state.get_current_state_event(
             room_id=room_id, event_type=EventTypes.Member, state_key=user_id
         )
         membership = member.membership if member else None
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 659f99f7e2..bcab98c6d5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -348,7 +348,7 @@ class SearchHandler:
         state_results = {}
         if include_state:
             for room_id in {e.room_id for e in search_result.allowed_events}:
-                state = await self.state_handler.get_current_state(room_id)
+                state = await self._storage_controllers.state.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
         aggregations = await self._relations_handler.get_bundled_aggregations(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1100434b3f..54b0ec4b97 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -681,7 +681,7 @@ class Notifier:
         return joined_room_ids, True
 
     async def _is_world_readable(self, room_id: str) -> bool:
-        state = await self.state_handler.get_current_state(
+        state = await self._storage_controllers.state.get_current_state_event(
             room_id, EventTypes.RoomHistoryVisibility, ""
         )
         if state and "history_visibility" in state.content:
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 1cacd1a4f0..9d953d58de 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
+from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, RoomID, UserID, create_requester
 from synapse.util import json_decoder
 
@@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         super().__init__(hs)
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
-        self.state_handler = hs.get_state_handler()
+        self.store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self.is_mine = hs.is_mine
 
     async def on_POST(
@@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         )
 
         # send invite if room has "JoinRules.INVITE"
-        room_state = await self.state_handler.get_current_state(room_id)
-        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+        join_rules_event = (
+            await self._storage_controllers.state.get_current_state_event(
+                room_id, EventTypes.JoinRules, ""
+            )
+        )
         if join_rules_event:
             if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
                 # update_membership with an action of "invite" can raise a
@@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         super().__init__(hs)
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
+        self._state_storage_controller = hs.get_storage_controllers().state
         self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -553,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         user_to_add = content.get("user_id", requester.user.to_string())
 
         # Figure out which local users currently have power in the room, if any.
-        room_state = await self.state_handler.get_current_state(room_id)
-        if not room_state:
+        filtered_room_state = await self._state_storage_controller.get_current_state(
+            room_id,
+            StateFilter.from_types(
+                [
+                    (EventTypes.Create, ""),
+                    (EventTypes.PowerLevels, ""),
+                    (EventTypes.JoinRules, ""),
+                    (EventTypes.Member, user_to_add),
+                ]
+            ),
+        )
+        if not filtered_room_state:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
 
-        create_event = room_state[(EventTypes.Create, "")]
-        power_levels = room_state.get((EventTypes.PowerLevels, ""))
+        create_event = filtered_room_state[(EventTypes.Create, "")]
+        power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
 
         if power_levels is not None:
             # We pick the local user with the highest power.
@@ -634,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
 
         # Now we check if the user we're granting admin rights to is already in
         # the room. If not and it's not a public room we invite them.
-        member_event = room_state.get((EventTypes.Member, user_to_add))
+        member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
         is_joined = False
         if member_event:
             is_joined = member_event.content["membership"] in (
@@ -645,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         if is_joined:
             return HTTPStatus.OK, {}
 
-        join_rules = room_state.get((EventTypes.JoinRules, ""))
+        join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
         is_public = False
         if join_rules:
             is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 7a5ce8ad0e..a26e976492 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
         self.clock = hs.get_clock()
         self._store = hs.get_datastores().main
         self._state = hs.get_state_handler()
+        self._storage_controllers = hs.get_storage_controllers()
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self._relations_handler = hs.get_relations_handler()
@@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet):
         if include_unredacted_content and not await self.auth.is_server_admin(
             requester.user
         ):
-            power_level_event = await self._state.get_current_state(
-                room_id, EventTypes.PowerLevels, ""
+            power_level_event = (
+                await self._storage_controllers.state.get_current_state_event(
+                    room_id, EventTypes.PowerLevels, ""
+                )
             )
 
             auth_events = {}
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index b5f3a0c74e..6863020778 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
     def __init__(self, hs: "HomeServer"):
         self._server_notices_manager = hs.get_server_notices_manager()
         self._store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
         self._auth = hs.get_auth()
         self._config = hs.config
         self._resouce_limited = False
@@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
         currently_blocked = False
         pinned_state_event = None
         try:
-            pinned_state_event = await self._state.get_current_state(
-                room_id, event_type=EventTypes.Pinned
+            pinned_state_event = (
+                await self._storage_controllers.state.get_current_state_event(
+                    room_id, event_type=EventTypes.Pinned, state_key=""
+                )
             )
         except AuthError:
             # The user has yet to join the server notices room
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index bf09f5128a..ab68e2b6a4 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,13 +32,11 @@ from typing import (
     Set,
     Tuple,
     Union,
-    overload,
 )
 
 import attr
 from frozendict import frozendict
 from prometheus_client import Counter, Histogram
-from typing_extensions import Literal
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -132,85 +130,20 @@ class StateHandler:
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._storage_controllers = hs.get_storage_controllers()
 
-    @overload
-    async def get_current_state(
-        self,
-        room_id: str,
-        event_type: Literal[None] = None,
-        state_key: str = "",
-        latest_event_ids: Optional[List[str]] = None,
-    ) -> StateMap[EventBase]:
-        ...
-
-    @overload
-    async def get_current_state(
-        self,
-        room_id: str,
-        event_type: str,
-        state_key: str = "",
-        latest_event_ids: Optional[List[str]] = None,
-    ) -> Optional[EventBase]:
-        ...
-
-    async def get_current_state(
+    async def get_current_state_ids(
         self,
         room_id: str,
-        event_type: Optional[str] = None,
-        state_key: str = "",
-        latest_event_ids: Optional[List[str]] = None,
-    ) -> Union[Optional[EventBase], StateMap[EventBase]]:
-        """Retrieves the current state for the room. This is done by
-        calling `get_latest_events_in_room` to get the leading edges of the
-        event graph and then resolving any of the state conflicts.
-
-        This is equivalent to getting the state of an event that were to send
-        next before receiving any new events.
-
-        Returns:
-            If `event_type` is specified, then the method returns only the one
-            event (or None) with that `event_type` and `state_key`.
-
-            Otherwise, a map from (type, state_key) to event.
-        """
-        if not latest_event_ids:
-            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
-        assert latest_event_ids is not None
-
-        logger.debug("calling resolve_state_groups from get_current_state")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        state = ret.state
-
-        if event_type:
-            event_id = state.get((event_type, state_key))
-            event = None
-            if event_id:
-                event = await self.store.get_event(event_id, allow_none=True)
-            return event
-
-        state_map = await self.store.get_events(
-            list(state.values()), get_prev_content=False
-        )
-        return {
-            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
-        }
-
-    async def get_current_state_ids(
-        self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
+        latest_event_ids: Collection[str],
     ) -> StateMap[str]:
         """Get the current state, or the state at a set of events, for a room
 
         Args:
             room_id:
-            latest_event_ids: if given, the forward extremities to resolve. If
-                None, we look them up from the database (via a cache).
+            latest_event_ids: The forward extremities to resolve.
 
         Returns:
             the state dict, mapping from (event_type, state_key) -> event_id
         """
-        if not latest_event_ids:
-            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
-        assert latest_event_ids is not None
-
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         return ret.state
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 9952b00493..63a78ebc87 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -455,3 +455,30 @@ class StateStorageController:
         return await self.stores.main.get_partial_current_state_deltas(
             prev_stream_id, max_stream_id
         )
+
+    async def get_current_state(
+        self, room_id: str, state_filter: Optional[StateFilter] = None
+    ) -> StateMap[EventBase]:
+        """Same as `get_current_state_ids` but also fetches the events"""
+        state_map_ids = await self.get_current_state_ids(room_id, state_filter)
+
+        event_map = await self.stores.main.get_events(list(state_map_ids.values()))
+
+        state_map = {}
+        for key, event_id in state_map_ids.items():
+            event = event_map.get(event_id)
+            if event:
+                state_map[key] = event
+
+        return state_map
+
+    async def get_current_state_event(
+        self, room_id: str, event_type: str, state_key: str
+    ) -> Optional[EventBase]:
+        """Get the current state event for the given type/state_key."""
+
+        key = (event_type, state_key)
+        state_map = await self.get_current_state(
+            room_id, StateFilter.from_types((key,))
+        )
+        return state_map.get(key)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b19365b81a..413b3c9426 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
         super().prepare(reactor, clock, hs)
 
+        self._storage_controllers = hs.get_storage_controllers()
+
         # create the room
         creator_user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
@@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self._storage_controllers.state.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
@@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self._storage_controllers.state.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 11ad44223d..53d49ca896 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
+        self._storage_controllers = hs.get_storage_controllers()
 
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
     def _get_canonical_alias(self):
         """Get the canonical alias state of the room."""
         return self.get_success(
-            self.state_handler.get_current_state(
+            self._storage_controllers.state.get_current_state_event(
                 self.room_id, EventTypes.CanonicalAlias, ""
             )
         )
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index a76718e8f9..2ff88e64a5 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
         self._persistence = self.hs.get_storage_controllers().persistence
+        self._state_storage_controller = self.hs.get_storage_controllers().state
         self.store = self.hs.get_datastores().main
 
         self.register_user("user", "pass")
@@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.state.get_current_state_ids(self.room_id))
+            self.get_success(
+                self._state_storage_controller.get_current_state_ids(self.room_id)
+            )
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
@@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 92cd0dfc05..8dfaa0559b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
         first = self.helper.send(self.room_id, body="test1")
 
         # Get the current room state.
-        state_handler = self.hs.get_state_handler()
         create_event = self.get_success(
-            state_handler.get_current_state(self.room_id, "m.room.create", "")
+            self._storage_controllers.state.get_current_state_event(
+                self.room_id, "m.room.create", ""
+            )
         )
         self.assertIsNotNone(create_event)
 
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d497a19f63..3c79dabc9f 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastores().main
-        self._storage = hs.get_storage_controllers()
+        self._storage_controllers = hs.get_storage_controllers()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
 
     def inject_room_event(self, **kwargs):
         self.get_success(
-            self._storage.persistence.persist_event(
+            self._storage_controllers.persistence.persist_event(
                 self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
             )
         )
@@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         )
 
         state = self.get_success(
-            self.store.get_current_state(room_id=self.room.to_string())
+            self._storage_controllers.state.get_current_state(
+                room_id=self.room.to_string()
+            )
         )
 
         self.assertEqual(1, len(state))
@@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         )
 
         state = self.get_success(
-            self.store.get_current_state(room_id=self.room.to_string())
+            self._storage_controllers.state.get_current_state(
+                room_id=self.room.to_string()
+            )
         )
 
         self.assertEqual(1, len(state))