summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/admin.py4
-rw-r--r--synapse/handlers/device.py10
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/handlers/federation_event.py186
-rw-r--r--synapse/handlers/groups_local.py503
-rw-r--r--synapse/handlers/initial_sync.py6
-rw-r--r--synapse/handlers/message.py50
-rw-r--r--synapse/handlers/pagination.py4
-rw-r--r--synapse/handlers/room.py4
-rw-r--r--synapse/handlers/room_batch.py4
-rw-r--r--synapse/handlers/room_member.py11
-rw-r--r--synapse/handlers/room_summary.py6
-rw-r--r--synapse/handlers/search.py4
-rw-r--r--synapse/handlers/sync.py89
14 files changed, 171 insertions, 716 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 96376963f2..50e34743b7 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -31,7 +31,7 @@ class AdminHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
 
     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
@@ -233,7 +233,7 @@ class AdminHandler:
             for event_id in extremities:
                 if not event_to_unseen_prevs[event_id]:
                     continue
-                state = await self.state_store.get_state_for_event(event_id)
+                state = await self.state_storage.get_state_for_event(event_id)
                 writer.write_state(room_id, event_id, state)
 
         return writer.finished()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 1d6d1f8a92..b21e469865 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -70,7 +70,7 @@ class DeviceWorkerHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
-        self.state_store = hs.get_storage().state
+        self.state_storage = hs.get_storage().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
 
@@ -203,7 +203,9 @@ class DeviceWorkerHandler:
                 continue
 
             # mapping from event_id -> state_dict
-            prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
+            prev_state_ids = await self.state_storage.get_state_ids_for_events(
+                event_ids
+            )
 
             # Check if we've joined the room? If so we just blindly add all the users to
             # the "possibly changed" users.
@@ -763,6 +765,10 @@ class DeviceListUpdater:
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
         prev_ids = edu_content.pop("prev_id", [])
+        if not isinstance(prev_ids, list):
+            raise SynapseError(
+                400, "Device list update had an invalid 'prev_ids' field"
+            )
         prev_ids = [str(p) for p in prev_ids]  # They may come as ints
 
         if get_domain_from_id(user_id) != origin:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0386d0a07b..c8233270d7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -126,7 +126,7 @@ class FederationHandler:
 
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -1027,7 +1027,9 @@ class FederationHandler:
         if event.internal_metadata.outlier:
             raise NotFoundError("State not known at event %s" % (event_id,))
 
-        state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
+        state_groups = await self.state_storage.get_state_groups_ids(
+            room_id, [event_id]
+        )
 
         # get_state_groups_ids should return exactly one result
         assert len(state_groups) == 1
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index ca82df8a6d..a1361af272 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -99,7 +99,7 @@ class FederationEventHandler:
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
         self._storage = hs.get_storage()
-        self._state_store = self._storage.state
+        self._state_storage = self._storage.state
 
         self._state_handler = hs.get_state_handler()
         self._event_creation_handler = hs.get_event_creation_handler()
@@ -274,7 +274,7 @@ class FederationEventHandler:
                     affected=pdu.event_id,
                 )
 
-        await self._process_received_pdu(origin, pdu, state=None)
+        await self._process_received_pdu(origin, pdu, state_ids=None)
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -463,7 +463,9 @@ class FederationEventHandler:
         with nested_logging_context(suffix=event.event_id):
             context = await self._state_handler.compute_event_context(
                 event,
-                old_state=state,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in state
+                },
                 partial_state=partial_state,
             )
 
@@ -512,12 +514,12 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state = await self._resolve_state_at_missing_prevs(destination, event)
+            state_ids = await self._resolve_state_at_missing_prevs(destination, event)
 
             # build a new state group for it if need be
             context = await self._state_handler.compute_event_context(
                 event,
-                old_state=state,
+                state_ids_before_event=state_ids,
             )
             if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
@@ -533,7 +535,7 @@ class FederationEventHandler:
                 )
                 return
             await self._store.update_state_for_partial_state_event(event, context)
-            self._state_store.notify_event_un_partial_stated(event.event_id)
+            self._state_storage.notify_event_un_partial_stated(event.event_id)
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -767,11 +769,12 @@ class FederationEventHandler:
             return
 
         try:
-            state = await self._resolve_state_at_missing_prevs(origin, event)
+            state_ids = await self._resolve_state_at_missing_prevs(origin, event)
             # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
             #   not return partial state
+
             await self._process_received_pdu(
-                origin, event, state=state, backfilled=backfilled
+                origin, event, state_ids=state_ids, backfilled=backfilled
             )
         except FederationError as e:
             if e.code == 403:
@@ -781,7 +784,7 @@ class FederationEventHandler:
 
     async def _resolve_state_at_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[Iterable[EventBase]]:
+    ) -> Optional[StateMap[str]]:
         """Calculate the state at an event with missing prev_events.
 
         This is used when we have pulled a batch of events from a remote server, and
@@ -808,8 +811,8 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns a list of
-            the events in the state at `event`.
+            if we already had all the prev events, `None`. Otherwise, returns
+            the event ids of the state at `event`.
         """
         room_id = event.room_id
         event_id = event.event_id
@@ -829,10 +832,10 @@ class FederationEventHandler:
         )
         # Calculate the state after each of the previous events, and
         # resolve them to find the correct state at the current event.
-        event_map = {event_id: event}
+
         try:
             # Get the state of the events we know about
-            ours = await self._state_store.get_state_groups_ids(room_id, seen)
+            ours = await self._state_storage.get_state_groups_ids(room_id, seen)
 
             # state_maps is a list of mappings from (type, state_key) to event_id
             state_maps: List[StateMap[str]] = list(ours.values())
@@ -849,40 +852,23 @@ class FederationEventHandler:
                     # note that if any of the missing prevs share missing state or
                     # auth events, the requests to fetch those events are deduped
                     # by the get_pdu_cache in federation_client.
-                    remote_state = await self._get_state_after_missing_prev_event(
-                        dest, room_id, p
+                    remote_state_map = (
+                        await self._get_state_ids_after_missing_prev_event(
+                            dest, room_id, p
+                        )
                     )
 
-                    remote_state_map = {
-                        (x.type, x.state_key): x.event_id for x in remote_state
-                    }
                     state_maps.append(remote_state_map)
 
-                    for x in remote_state:
-                        event_map[x.event_id] = x
-
             room_version = await self._store.get_room_version_id(room_id)
             state_map = await self._state_resolution_handler.resolve_events_with_store(
                 room_id,
                 room_version,
                 state_maps,
-                event_map,
+                event_map={event_id: event},
                 state_res_store=StateResolutionStore(self._store),
             )
 
-            # We need to give _process_received_pdu the actual state events
-            # rather than event ids, so generate that now.
-
-            # First though we need to fetch all the events that are in
-            # state_map, so we can build up the state below.
-            evs = await self._store.get_events(
-                list(state_map.values()),
-                get_prev_content=False,
-                redact_behaviour=EventRedactBehaviour.as_is,
-            )
-            event_map.update(evs)
-
-            state = [event_map[e] for e in state_map.values()]
         except Exception:
             logger.warning(
                 "Error attempting to resolve state at missing prev_events",
@@ -894,14 +880,14 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state
+        return state_map
 
-    async def _get_state_after_missing_prev_event(
+    async def _get_state_ids_after_missing_prev_event(
         self,
         destination: str,
         room_id: str,
         event_id: str,
-    ) -> List[EventBase]:
+    ) -> StateMap[str]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
@@ -910,7 +896,7 @@ class FederationEventHandler:
             event_id: The id of the event we want the state at.
 
         Returns:
-            A list of events in the state, including the event itself
+            The event ids of the state *after* the given event.
         """
         (
             state_event_ids,
@@ -925,19 +911,17 @@ class FederationEventHandler:
             len(auth_event_ids),
         )
 
-        # start by just trying to fetch the events from the store
+        # Start by checking events we already have in the DB
         desired_events = set(state_event_ids)
         desired_events.add(event_id)
         logger.debug("Fetching %i events from cache/store", len(desired_events))
-        fetched_events = await self._store.get_events(
-            desired_events, allow_rejected=True
-        )
+        have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - fetched_events.keys()
+        missing_desired_events = desired_events - have_events
         logger.debug(
             "We are missing %i events (got %i)",
             len(missing_desired_events),
-            len(fetched_events),
+            len(have_events),
         )
 
         # We probably won't need most of the auth events, so let's just check which
@@ -948,7 +932,7 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+        missing_auth_events = set(auth_event_ids) - have_events
         missing_auth_events.difference_update(
             await self._store.have_seen_events(room_id, missing_auth_events)
         )
@@ -974,47 +958,51 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=missing_events
             )
 
-        # we need to make sure we re-load from the database to get the rejected
-        # state correct.
-        fetched_events.update(
-            await self._store.get_events(missing_desired_events, allow_rejected=True)
-        )
+        # We now need to fill out the state map, which involves fetching the
+        # type and state key for each event ID in the state.
+        state_map = {}
 
-        # check for events which were in the wrong room.
-        #
-        # this can happen if a remote server claims that the state or
-        # auth_events at an event in room A are actually events in room B
-
-        bad_events = [
-            (event_id, event.room_id)
-            for event_id, event in fetched_events.items()
-            if event.room_id != room_id
-        ]
+        event_metadata = await self._store.get_metadata_for_events(state_event_ids)
+        for state_event_id, metadata in event_metadata.items():
+            if metadata.room_id != room_id:
+                # This is a bogus situation, but since we may only discover it a long time
+                # after it happened, we try our best to carry on, by just omitting the
+                # bad events from the returned state set.
+                #
+                # This can happen if a remote server claims that the state or
+                # auth_events at an event in room A are actually events in room B
+                logger.warning(
+                    "Remote server %s claims event %s in room %s is an auth/state "
+                    "event in room %s",
+                    destination,
+                    state_event_id,
+                    metadata.room_id,
+                    room_id,
+                )
+                continue
 
-        for bad_event_id, bad_room_id in bad_events:
-            # This is a bogus situation, but since we may only discover it a long time
-            # after it happened, we try our best to carry on, by just omitting the
-            # bad events from the returned state set.
-            logger.warning(
-                "Remote server %s claims event %s in room %s is an auth/state "
-                "event in room %s",
-                destination,
-                bad_event_id,
-                bad_room_id,
-                room_id,
-            )
+            if metadata.state_key is None:
+                logger.warning(
+                    "Remote server gave us non-state event in state: %s", state_event_id
+                )
+                continue
 
-            del fetched_events[bad_event_id]
+            state_map[(metadata.event_type, metadata.state_key)] = state_event_id
 
         # if we couldn't get the prev event in question, that's a problem.
-        remote_event = fetched_events.get(event_id)
+        remote_event = await self._store.get_event(
+            event_id,
+            allow_none=True,
+            allow_rejected=True,
+            redact_behaviour=EventRedactBehaviour.as_is,
+        )
         if not remote_event:
             raise Exception("Unable to get missing prev_event %s" % (event_id,))
 
         # missing state at that event is a warning, not a blocker
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
-        failed_to_fetch = desired_events - fetched_events.keys()
+        failed_to_fetch = desired_events - event_metadata.keys()
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state events for %s %s",
@@ -1022,14 +1010,12 @@ class FederationEventHandler:
                 failed_to_fetch,
             )
 
-        remote_state = [
-            fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
-        ]
-
         if remote_event.is_state() and remote_event.rejected_reason is None:
-            remote_state.append(remote_event)
+            state_map[
+                (remote_event.type, remote_event.state_key)
+            ] = remote_event.event_id
 
-        return remote_state
+        return state_map
 
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
@@ -1056,7 +1042,7 @@ class FederationEventHandler:
         self,
         origin: str,
         event: EventBase,
-        state: Optional[Iterable[EventBase]],
+        state_ids: Optional[StateMap[str]],
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1078,7 +1064,7 @@ class FederationEventHandler:
 
             event: event to be persisted
 
-            state: Normally None, but if we are handling a gap in the graph
+            state_ids: Normally None, but if we are handling a gap in the graph
                 (ie, we are missing one or more prev_events), the resolved state at the
                 event
 
@@ -1090,7 +1076,8 @@ class FederationEventHandler:
 
         try:
             context = await self._state_handler.compute_event_context(
-                event, old_state=state
+                event,
+                state_ids_before_event=state_ids,
             )
             context = await self._check_event_auth(
                 origin,
@@ -1107,7 +1094,7 @@ class FederationEventHandler:
             # For new (non-backfilled and non-outlier) events we check if the event
             # passes auth based on the current state. If it doesn't then we
             # "soft-fail" the event.
-            await self._check_for_soft_fail(event, state, origin=origin)
+            await self._check_for_soft_fail(event, state_ids, origin=origin)
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
@@ -1589,7 +1576,7 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state: Optional[Iterable[EventBase]],
+        state_ids: Optional[StateMap[str]],
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
@@ -1597,7 +1584,7 @@ class FederationEventHandler:
 
         Args:
             event
-            state: The state at the event if we don't have all the event's prev events
+            state_ids: The state at the event if we don't have all the event's prev events
             origin: The host the event originates from.
         """
         extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@@ -1613,7 +1600,7 @@ class FederationEventHandler:
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
         # Calculate the "current state".
-        if state is not None:
+        if state_ids is not None:
             # If we're explicitly given the state then we won't have all the
             # prev events, and so we have a gap in the graph. In this case
             # we want to be a little careful as we might have been down for
@@ -1626,17 +1613,20 @@ class FederationEventHandler:
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets_d = await self._state_store.get_state_groups(
+            state_sets_d = await self._state_storage.get_state_groups_ids(
                 event.room_id, extrem_ids
             )
-            state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
-            state_sets.append(state)
-            current_states = await self._state_handler.resolve_events(
-                room_version, state_sets, event
+            state_sets: List[StateMap[str]] = list(state_sets_d.values())
+            state_sets.append(state_ids)
+            current_state_ids = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    state_sets,
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
             )
-            current_state_ids: StateMap[str] = {
-                k: e.event_id for k, e in current_states.items()
-            }
         else:
             current_state_ids = await self._state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -1895,7 +1885,7 @@ class FederationEventHandler:
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = await self._state_store.store_state_group(
+        state_group = await self._state_storage.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
deleted file mode 100644
index e7a399787b..0000000000
--- a/synapse/handlers/groups_local.py
+++ /dev/null
@@ -1,503 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
-
-from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, JsonDict, get_domain_from_id
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
-    """Returns an async function that looks at the group id and calls the function
-    on federation or the local group server if the group is local
-    """
-
-    async def f(
-        self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
-    ) -> JsonDict:
-        if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
-
-        if self.is_mine_id(group_id):
-            return await getattr(self.groups_server_handler, func_name)(
-                group_id, *args, **kwargs
-            )
-        else:
-            destination = get_domain_from_id(group_id)
-
-            try:
-                return await getattr(self.transport_client, func_name)(
-                    destination, group_id, *args, **kwargs
-                )
-            except HttpResponseException as e:
-                # Capture errors returned by the remote homeserver and
-                # re-throw specific errors as SynapseErrors. This is so
-                # when the remote end responds with things like 403 Not
-                # In Group, we can communicate that to the client instead
-                # of a 500.
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-    return f
-
-
-class GroupsLocalWorkerHandler:
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.store = hs.get_datastores().main
-        self.room_list_handler = hs.get_room_list_handler()
-        self.groups_server_handler = hs.get_groups_server_handler()
-        self.transport_client = hs.get_federation_transport_client()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.keyring = hs.get_keyring()
-        self.is_mine_id = hs.is_mine_id
-        self.signing_key = hs.signing_key
-        self.server_name = hs.hostname
-        self.notifier = hs.get_notifier()
-        self.attestations = hs.get_groups_attestation_signing()
-
-        self.profile_handler = hs.get_profile_handler()
-
-    # The following functions merely route the query to the local groups server
-    # or federation depending on if the group is local or remote
-
-    get_group_profile = _create_rerouter("get_group_profile")
-    get_rooms_in_group = _create_rerouter("get_rooms_in_group")
-    get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
-    get_group_category = _create_rerouter("get_group_category")
-    get_group_categories = _create_rerouter("get_group_categories")
-    get_group_role = _create_rerouter("get_group_role")
-    get_group_roles = _create_rerouter("get_group_roles")
-
-    async def get_group_summary(
-        self, group_id: str, requester_user_id: str
-    ) -> JsonDict:
-        """Get the group summary for a group.
-
-        If the group is remote we check that the users have valid attestations.
-        """
-        if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_group_summary(
-                group_id, requester_user_id
-            )
-        else:
-            try:
-                res = await self.transport_client.get_group_summary(
-                    get_domain_from_id(group_id), group_id, requester_user_id
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            group_server_name = get_domain_from_id(group_id)
-
-            # Loop through the users and validate the attestations.
-            chunk = res["users_section"]["users"]
-            valid_users = []
-            for entry in chunk:
-                g_user_id = entry["user_id"]
-                attestation = entry.pop("attestation", {})
-                try:
-                    if get_domain_from_id(g_user_id) != group_server_name:
-                        await self.attestations.verify_attestation(
-                            attestation,
-                            group_id=group_id,
-                            user_id=g_user_id,
-                            server_name=get_domain_from_id(g_user_id),
-                        )
-                    valid_users.append(entry)
-                except Exception as e:
-                    logger.info("Failed to verify user is in group: %s", e)
-
-            res["users_section"]["users"] = valid_users
-
-            res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
-            res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
-
-        # Add `is_publicised` flag to indicate whether the user has publicised their
-        # membership of the group on their profile
-        result = await self.store.get_publicised_groups_for_user(requester_user_id)
-        is_publicised = group_id in result
-
-        res.setdefault("user", {})["is_publicised"] = is_publicised
-
-        return res
-
-    async def get_users_in_group(
-        self, group_id: str, requester_user_id: str
-    ) -> JsonDict:
-        """Get users in a group"""
-        if self.is_mine_id(group_id):
-            return await self.groups_server_handler.get_users_in_group(
-                group_id, requester_user_id
-            )
-
-        group_server_name = get_domain_from_id(group_id)
-
-        try:
-            res = await self.transport_client.get_users_in_group(
-                get_domain_from_id(group_id), group_id, requester_user_id
-            )
-        except HttpResponseException as e:
-            raise e.to_synapse_error()
-        except RequestSendFailed:
-            raise SynapseError(502, "Failed to contact group server")
-
-        chunk = res["chunk"]
-        valid_entries = []
-        for entry in chunk:
-            g_user_id = entry["user_id"]
-            attestation = entry.pop("attestation", {})
-            try:
-                if get_domain_from_id(g_user_id) != group_server_name:
-                    await self.attestations.verify_attestation(
-                        attestation,
-                        group_id=group_id,
-                        user_id=g_user_id,
-                        server_name=get_domain_from_id(g_user_id),
-                    )
-                valid_entries.append(entry)
-            except Exception as e:
-                logger.info("Failed to verify user is in group: %s", e)
-
-        res["chunk"] = valid_entries
-
-        return res
-
-    async def get_joined_groups(self, user_id: str) -> JsonDict:
-        group_ids = await self.store.get_joined_groups(user_id)
-        return {"groups": group_ids}
-
-    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
-        if self.hs.is_mine_id(user_id):
-            result = await self.store.get_publicised_groups_for_user(user_id)
-
-            # Check AS associated groups for this user - this depends on the
-            # RegExps in the AS registration file (under `users`)
-            for app_service in self.store.get_app_services():
-                result.extend(app_service.get_groups_for_user(user_id))
-
-            return {"groups": result}
-        else:
-            try:
-                bulk_result = await self.transport_client.bulk_get_publicised_groups(
-                    get_domain_from_id(user_id), [user_id]
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            result = bulk_result.get("users", {}).get(user_id)
-            # TODO: Verify attestations
-            return {"groups": result}
-
-    async def bulk_get_publicised_groups(
-        self, user_ids: Iterable[str], proxy: bool = True
-    ) -> JsonDict:
-        destinations: Dict[str, Set[str]] = {}
-        local_users = set()
-
-        for user_id in user_ids:
-            if self.hs.is_mine_id(user_id):
-                local_users.add(user_id)
-            else:
-                destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
-
-        if not proxy and destinations:
-            raise SynapseError(400, "Some user_ids are not local")
-
-        results = {}
-        failed_results: List[str] = []
-        for destination, dest_user_ids in destinations.items():
-            try:
-                r = await self.transport_client.bulk_get_publicised_groups(
-                    destination, list(dest_user_ids)
-                )
-                results.update(r["users"])
-            except Exception:
-                failed_results.extend(dest_user_ids)
-
-        for uid in local_users:
-            results[uid] = await self.store.get_publicised_groups_for_user(uid)
-
-            # Check AS associated groups for this user - this depends on the
-            # RegExps in the AS registration file (under `users`)
-            for app_service in self.store.get_app_services():
-                results[uid].extend(app_service.get_groups_for_user(uid))
-
-        return {"users": results}
-
-
-class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
-
-        # Ensure attestations get renewed
-        hs.get_groups_attestation_renewer()
-
-    # The following functions merely route the query to the local groups server
-    # or federation depending on if the group is local or remote
-
-    update_group_profile = _create_rerouter("update_group_profile")
-
-    add_room_to_group = _create_rerouter("add_room_to_group")
-    update_room_in_group = _create_rerouter("update_room_in_group")
-    remove_room_from_group = _create_rerouter("remove_room_from_group")
-
-    update_group_summary_room = _create_rerouter("update_group_summary_room")
-    delete_group_summary_room = _create_rerouter("delete_group_summary_room")
-
-    update_group_category = _create_rerouter("update_group_category")
-    delete_group_category = _create_rerouter("delete_group_category")
-
-    update_group_summary_user = _create_rerouter("update_group_summary_user")
-    delete_group_summary_user = _create_rerouter("delete_group_summary_user")
-
-    update_group_role = _create_rerouter("update_group_role")
-    delete_group_role = _create_rerouter("delete_group_role")
-
-    set_group_join_policy = _create_rerouter("set_group_join_policy")
-
-    async def create_group(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
-        """Create a group"""
-
-        logger.info("Asking to create group with ID: %r", group_id)
-
-        if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.create_group(
-                group_id, user_id, content
-            )
-            local_attestation = None
-            remote_attestation = None
-        else:
-            raise SynapseError(400, "Unable to create remote groups")
-
-        is_publicised = content.get("publicise", False)
-        token = await self.store.register_user_group_membership(
-            group_id,
-            user_id,
-            membership="join",
-            is_admin=True,
-            local_attestation=local_attestation,
-            remote_attestation=remote_attestation,
-            is_publicised=is_publicised,
-        )
-        self.notifier.on_new_event("groups_key", token, users=[user_id])
-
-        return res
-
-    async def join_group(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
-        """Request to join a group"""
-        if self.is_mine_id(group_id):
-            await self.groups_server_handler.join_group(group_id, user_id, content)
-            local_attestation = None
-            remote_attestation = None
-        else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            try:
-                res = await self.transport_client.join_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
-
-        # TODO: Check that the group is public and we're being added publicly
-        is_publicised = content.get("publicise", False)
-
-        token = await self.store.register_user_group_membership(
-            group_id,
-            user_id,
-            membership="join",
-            is_admin=False,
-            local_attestation=local_attestation,
-            remote_attestation=remote_attestation,
-            is_publicised=is_publicised,
-        )
-        self.notifier.on_new_event("groups_key", token, users=[user_id])
-
-        return {}
-
-    async def accept_invite(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
-        """Accept an invite to a group"""
-        if self.is_mine_id(group_id):
-            await self.groups_server_handler.accept_invite(group_id, user_id, content)
-            local_attestation = None
-            remote_attestation = None
-        else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            try:
-                res = await self.transport_client.accept_group_invite(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
-
-        # TODO: Check that the group is public and we're being added publicly
-        is_publicised = content.get("publicise", False)
-
-        token = await self.store.register_user_group_membership(
-            group_id,
-            user_id,
-            membership="join",
-            is_admin=False,
-            local_attestation=local_attestation,
-            remote_attestation=remote_attestation,
-            is_publicised=is_publicised,
-        )
-        self.notifier.on_new_event("groups_key", token, users=[user_id])
-
-        return {}
-
-    async def invite(
-        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
-    ) -> JsonDict:
-        """Invite a user to a group"""
-        content = {"requester_user_id": requester_user_id, "config": config}
-        if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.invite_to_group(
-                group_id, user_id, requester_user_id, content
-            )
-        else:
-            try:
-                res = await self.transport_client.invite_to_group(
-                    get_domain_from_id(group_id),
-                    group_id,
-                    user_id,
-                    requester_user_id,
-                    content,
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-        return res
-
-    async def on_invite(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> JsonDict:
-        """One of our users were invited to a group"""
-        # TODO: Support auto join and rejection
-
-        if not self.is_mine_id(user_id):
-            raise SynapseError(400, "User not on this server")
-
-        local_profile = {}
-        if "profile" in content:
-            if "name" in content["profile"]:
-                local_profile["name"] = content["profile"]["name"]
-            if "avatar_url" in content["profile"]:
-                local_profile["avatar_url"] = content["profile"]["avatar_url"]
-
-        token = await self.store.register_user_group_membership(
-            group_id,
-            user_id,
-            membership="invite",
-            content={"profile": local_profile, "inviter": content["inviter"]},
-        )
-        self.notifier.on_new_event("groups_key", token, users=[user_id])
-        try:
-            user_profile = await self.profile_handler.get_profile(user_id)
-        except Exception as e:
-            logger.warning("No profile for user %s: %s", user_id, e)
-            user_profile = {}
-
-        return {"state": "invite", "user_profile": user_profile}
-
-    async def remove_user_from_group(
-        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
-    ) -> JsonDict:
-        """Remove a user from a group"""
-        if user_id == requester_user_id:
-            token = await self.store.register_user_group_membership(
-                group_id, user_id, membership="leave"
-            )
-            self.notifier.on_new_event("groups_key", token, users=[user_id])
-
-            # TODO: Should probably remember that we tried to leave so that we can
-            # retry if the group server is currently down.
-
-        if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.remove_user_from_group(
-                group_id, user_id, requester_user_id, content
-            )
-        else:
-            content["requester_user_id"] = requester_user_id
-            try:
-                res = await self.transport_client.remove_user_from_group(
-                    get_domain_from_id(group_id),
-                    group_id,
-                    requester_user_id,
-                    user_id,
-                    content,
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-        return res
-
-    async def user_removed_from_group(
-        self, group_id: str, user_id: str, content: JsonDict
-    ) -> None:
-        """One of our users was removed/kicked from a group"""
-        # TODO: Check if user in group
-        token = await self.store.register_user_group_membership(
-            group_id, user_id, membership="leave"
-        )
-        self.notifier.on_new_event("groups_key", token, users=[user_id])
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d79248ad90..c06932a41a 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -68,7 +68,7 @@ class InitialSyncHandler:
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
 
     async def snapshot_all_rooms(
         self,
@@ -198,7 +198,7 @@ class InitialSyncHandler:
                         event.stream_ordering,
                     )
                     deferred_room_state = run_in_background(
-                        self.state_store.get_state_for_events, [event.event_id]
+                        self.state_storage.get_state_for_events, [event.event_id]
                     ).addCallback(
                         lambda states: cast(StateMap[EventBase], states[event.event_id])
                     )
@@ -355,7 +355,7 @@ class InitialSyncHandler:
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
-        room_state = await self.state_store.get_state_for_event(member_event_id)
+        room_state = await self.state_storage.get_state_for_event(member_event_id)
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index cb1bc4c06f..7ca126dbd1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.types import (
+    MutableStateMap,
+    Requester,
+    RoomAlias,
+    StreamToken,
+    UserID,
+    create_requester,
+)
 from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -78,7 +85,7 @@ class MessageHandler:
         self.state = hs.get_state_handler()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
         self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
 
@@ -125,7 +132,7 @@ class MessageHandler:
             assert (
                 membership_event_id is not None
             ), "check_user_in_room_or_world_readable returned invalid data"
-            room_state = await self.state_store.get_state_for_events(
+            room_state = await self.state_storage.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
@@ -186,7 +193,7 @@ class MessageHandler:
 
             # check whether the user is in the room at that time to determine
             # whether they should be treated as peeking.
-            state_map = await self.state_store.get_state_for_event(
+            state_map = await self.state_storage.get_state_for_event(
                 last_event.event_id,
                 StateFilter.from_types([(EventTypes.Member, user_id)]),
             )
@@ -207,7 +214,7 @@ class MessageHandler:
             )
 
             if visible_events:
-                room_state_events = await self.state_store.get_state_for_events(
+                room_state_events = await self.state_storage.get_state_for_events(
                     [last_event.event_id], state_filter=state_filter
                 )
                 room_state: Mapping[Any, EventBase] = room_state_events[
@@ -237,7 +244,7 @@ class MessageHandler:
                 assert (
                     membership_event_id is not None
                 ), "check_user_in_room_or_world_readable returned invalid data"
-                room_state_events = await self.state_store.get_state_for_events(
+                room_state_events = await self.state_storage.get_state_for_events(
                     [membership_event_id], state_filter=state_filter
                 )
                 room_state = room_state_events[membership_event_id]
@@ -1022,8 +1029,35 @@ class EventCreationHandler:
             #
             # TODO(faster_joins): figure out how this works, and make sure that the
             #   old state is complete.
-            old_state = await self.store.get_events_as_list(state_event_ids)
-            context = await self.state.compute_event_context(event, old_state=old_state)
+            metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+            state_map_for_event: MutableStateMap[str] = {}
+            for state_id in state_event_ids:
+                data = metadata.get(state_id)
+                if data is None:
+                    # We're trying to persist a new historical batch of events
+                    # with the given state, e.g. via
+                    # `RoomBatchSendEventRestServlet`. The state can be inferred
+                    # by Synapse or set directly by the client.
+                    #
+                    # Either way, we should have persisted all the state before
+                    # getting here.
+                    raise Exception(
+                        f"State event {state_id} not found in DB,"
+                        " Synapse should have persisted it before using it."
+                    )
+
+                if data.state_key is None:
+                    raise Exception(
+                        f"Trying to set non-state event {state_id} as state"
+                    )
+
+                state_map_for_event[(data.event_type, data.state_key)] = state_id
+
+            context = await self.state.compute_event_context(
+                event,
+                state_ids_before_event=state_map_for_event,
+            )
         else:
             context = await self.state.compute_event_context(event)
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 19a4407050..6f4820c240 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -130,7 +130,7 @@ class PaginationHandler:
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
         self.clock = hs.get_clock()
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -539,7 +539,7 @@ class PaginationHandler:
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = await self.state_store.get_state_ids_for_event(
+            state_ids = await self.state_storage.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 92e1de0500..e2775b34f1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1193,7 +1193,7 @@ class RoomContextHandler:
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
         self._relations_handler = hs.get_relations_handler()
 
     async def get_event_context(
@@ -1293,7 +1293,7 @@ class RoomContextHandler:
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = await self.state_store.get_state_for_events(
+        state = await self.state_storage.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
 
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index fbfd748406..7ce32f2e9c 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -17,7 +17,7 @@ class RoomBatchHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
-        self.state_store = hs.get_storage().state
+        self.state_storage = hs.get_storage().state
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
@@ -141,7 +141,7 @@ class RoomBatchHandler:
         ) = await self.store.get_max_depth_of(event_ids)
         # mapping from (type, state_key) -> state_event_id
         assert most_recent_event_id is not None
-        prev_state_map = await self.state_store.get_state_ids_for_event(
+        prev_state_map = await self.state_storage.get_state_ids_for_event(
             most_recent_event_id
         )
         # List of state event ID's
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ea876c168d..00662dc961 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1081,17 +1081,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Transfer alias mappings in the room directory
         await self.store.update_aliases_for_room(old_room_id, room_id)
 
-        # Check if any groups we own contain the predecessor room
-        local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
-        for group_id in local_group_ids:
-            # Add new the new room to those groups
-            await self.store.add_room_to_group(
-                group_id, room_id, old_room is not None and old_room["is_public"]
-            )
-
-            # Remove the old room from those groups
-            await self.store.remove_room_from_group(group_id, old_room_id)
-
     async def copy_user_state_on_room_upgrade(
         self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
     ) -> None:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index af83de3193..75aee6a111 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -662,7 +662,8 @@ class RoomSummaryHandler:
         # The API doesn't return the room version so assume that a
         # join rule of knock is valid.
         if (
-            room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+            room.get("join_rule")
+            in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED)
             or room.get("world_readable") is True
         ):
             return True
@@ -713,9 +714,6 @@ class RoomSummaryHandler:
             "canonical_alias": stats["canonical_alias"],
             "num_joined_members": stats["joined_members"],
             "avatar_url": stats["avatar"],
-            # plural join_rules is a documentation error but kept for historical
-            # purposes. Should match /publicRooms.
-            "join_rules": stats["join_rules"],
             "join_rule": stats["join_rules"],
             "world_readable": (
                 stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd1c47dae8..e02c915248 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -56,7 +56,7 @@ class SearchHandler:
         self._event_serializer = hs.get_event_client_serializer()
         self._relations_handler = hs.get_relations_handler()
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
         self.auth = hs.get_auth()
 
     async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@@ -677,7 +677,7 @@ class SearchHandler:
                     [(EventTypes.Member, sender) for sender in senders]
                 )
 
-                state = await self.state_store.get_state_for_event(
+                state = await self.state_storage.get_state_for_event(
                     last_event_id, state_filter
                 )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 59b5d497be..c5c538e0c3 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -166,16 +166,6 @@ class KnockedSyncResult:
         return True
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class GroupsSyncResult:
-    join: JsonDict
-    invite: JsonDict
-    leave: JsonDict
-
-    def __bool__(self) -> bool:
-        return bool(self.join or self.invite or self.leave)
-
-
 @attr.s(slots=True, auto_attribs=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
@@ -206,7 +196,6 @@ class SyncResult:
             for this device
         device_unused_fallback_key_types: List of key types that have an unused fallback
             key
-        groups: Group updates, if any
     """
 
     next_batch: StreamToken
@@ -220,7 +209,6 @@ class SyncResult:
     device_lists: DeviceListUpdates
     device_one_time_keys_count: JsonDict
     device_unused_fallback_key_types: List[str]
-    groups: Optional[GroupsSyncResult]
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -236,7 +224,6 @@ class SyncResult:
             or self.account_data
             or self.to_device
             or self.device_lists
-            or self.groups
         )
 
 
@@ -252,7 +239,7 @@ class SyncHandler:
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
         self.storage = hs.get_storage()
-        self.state_store = self.storage.state
+        self.state_storage = self.storage.state
 
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
@@ -643,7 +630,7 @@ class SyncHandler:
             event: event of interest
             state_filter: The state filter used to fetch state from the database.
         """
-        state_ids = await self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_storage.get_state_ids_for_event(
             event.event_id, state_filter=state_filter or StateFilter.all()
         )
         if event.is_state():
@@ -723,7 +710,7 @@ class SyncHandler:
             return None
 
         last_event = last_events[-1]
-        state_ids = await self.state_store.get_state_ids_for_event(
+        state_ids = await self.state_storage.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -901,11 +888,13 @@ class SyncHandler:
 
             if full_state:
                 if batch:
-                    current_state_ids = await self.state_store.get_state_ids_for_event(
-                        batch.events[-1].event_id, state_filter=state_filter
+                    current_state_ids = (
+                        await self.state_storage.get_state_ids_for_event(
+                            batch.events[-1].event_id, state_filter=state_filter
+                        )
                     )
 
-                    state_ids = await self.state_store.get_state_ids_for_event(
+                    state_ids = await self.state_storage.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
 
@@ -926,7 +915,7 @@ class SyncHandler:
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
-                        await self.state_store.get_state_ids_for_event(
+                        await self.state_storage.get_state_ids_for_event(
                             batch.events[0].event_id, state_filter=state_filter
                         )
                     )
@@ -960,8 +949,10 @@ class SyncHandler:
                 )
 
                 if batch:
-                    current_state_ids = await self.state_store.get_state_ids_for_event(
-                        batch.events[-1].event_id, state_filter=state_filter
+                    current_state_ids = (
+                        await self.state_storage.get_state_ids_for_event(
+                            batch.events[-1].event_id, state_filter=state_filter
+                        )
                     )
                 else:
                     # Its not clear how we get here, but empirically we do
@@ -991,7 +982,7 @@ class SyncHandler:
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = await self.state_store.get_state_ids_for_event(
+                        state_ids = await self.state_storage.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(
@@ -1157,10 +1148,6 @@ class SyncHandler:
                 await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
             )
 
-        if self.hs_config.experimental.groups_enabled:
-            logger.debug("Fetching group data")
-            await self._generate_sync_entry_for_groups(sync_result_builder)
-
         num_events = 0
 
         # debug for https://github.com/matrix-org/synapse/issues/9424
@@ -1184,57 +1171,11 @@ class SyncHandler:
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
-            groups=sync_result_builder.groups,
             device_one_time_keys_count=one_time_key_counts,
             device_unused_fallback_key_types=unused_fallback_key_types,
             next_batch=sync_result_builder.now_token,
         )
 
-    @measure_func("_generate_sync_entry_for_groups")
-    async def _generate_sync_entry_for_groups(
-        self, sync_result_builder: "SyncResultBuilder"
-    ) -> None:
-        user_id = sync_result_builder.sync_config.user.to_string()
-        since_token = sync_result_builder.since_token
-        now_token = sync_result_builder.now_token
-
-        if since_token and since_token.groups_key:
-            results = await self.store.get_groups_changes_for_user(
-                user_id, since_token.groups_key, now_token.groups_key
-            )
-        else:
-            results = await self.store.get_all_groups_for_user(
-                user_id, now_token.groups_key
-            )
-
-        invited = {}
-        joined = {}
-        left = {}
-        for result in results:
-            membership = result["membership"]
-            group_id = result["group_id"]
-            gtype = result["type"]
-            content = result["content"]
-
-            if membership == "join":
-                if gtype == "membership":
-                    # TODO: Add profile
-                    content.pop("membership", None)
-                    joined[group_id] = content["content"]
-                else:
-                    joined.setdefault(group_id, {})[gtype] = content
-            elif membership == "invite":
-                if gtype == "membership":
-                    content.pop("membership", None)
-                    invited[group_id] = content["content"]
-            else:
-                if gtype == "membership":
-                    left[group_id] = content["content"]
-
-        sync_result_builder.groups = GroupsSyncResult(
-            join=joined, invite=invited, leave=left
-        )
-
     @measure_func("_generate_sync_entry_for_device_list")
     async def _generate_sync_entry_for_device_list(
         self,
@@ -2333,7 +2274,6 @@ class SyncResultBuilder:
         invited
         knocked
         archived
-        groups
         to_device
     """
 
@@ -2349,7 +2289,6 @@ class SyncResultBuilder:
     invited: List[InvitedSyncResult] = attr.Factory(list)
     knocked: List[KnockedSyncResult] = attr.Factory(list)
     archived: List[ArchivedSyncResult] = attr.Factory(list)
-    groups: Optional[GroupsSyncResult] = None
     to_device: List[JsonDict] = attr.Factory(list)
 
     def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: