Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/auth.py | 7
-rw-r--r--  synapse/api/errors.py | 22
-rw-r--r--  synapse/api/filtering.py | 30
-rw-r--r--  synapse/app/phone_stats_home.py | 4
-rw-r--r--  synapse/config/experimental.py | 10
-rw-r--r--  synapse/config/room_directory.py | 6
-rw-r--r--  synapse/event_auth.py | 23
-rw-r--r--  synapse/events/__init__.py | 6
-rw-r--r--  synapse/events/builder.py | 6
-rw-r--r--  synapse/events/snapshot.py | 174
-rw-r--r--  synapse/events/third_party_rules.py | 6
-rw-r--r--  synapse/federation/federation_server.py | 5
-rw-r--r--  synapse/handlers/account_data.py | 10
-rw-r--r--  synapse/handlers/auth.py | 5
-rw-r--r--  synapse/handlers/deactivate_account.py | 7
-rw-r--r--  synapse/handlers/directory.py | 7
-rw-r--r--  synapse/handlers/e2e_keys.py | 11
-rw-r--r--  synapse/handlers/event_auth.py | 16
-rw-r--r--  synapse/handlers/federation.py | 61
-rw-r--r--  synapse/handlers/federation_event.py | 65
-rw-r--r--  synapse/handlers/identity.py | 47
-rw-r--r--  synapse/handlers/initial_sync.py | 5
-rw-r--r--  synapse/handlers/message.py | 50
-rw-r--r--  synapse/handlers/receipts.py | 4
-rw-r--r--  synapse/handlers/room.py | 6
-rw-r--r--  synapse/handlers/room_member.py | 120
-rw-r--r--  synapse/handlers/room_member_worker.py | 5
-rw-r--r--  synapse/handlers/room_summary.py | 4
-rw-r--r--  synapse/handlers/sync.py | 105
-rw-r--r--  synapse/http/server.py | 40
-rw-r--r--  synapse/logging/opentracing.py | 10
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 52
-rw-r--r--  synapse/rest/admin/media.py | 18
-rw-r--r--  synapse/rest/admin/users.py | 3
-rw-r--r--  synapse/rest/client/account.py | 10
-rw-r--r--  synapse/rest/client/room_keys.py | 48
-rw-r--r--  synapse/rest/client/tags.py | 4
-rw-r--r--  synapse/rest/media/v1/media_storage.py | 7
-rw-r--r--  synapse/server.py | 12
-rw-r--r--  synapse/state/__init__.py | 178
-rw-r--r--  synapse/storage/_base.py | 2
-rw-r--r--  synapse/storage/controllers/state.py | 6
-rw-r--r--  synapse/storage/database.py | 1
-rw-r--r--  synapse/storage/databases/main/account_data.py | 129
-rw-r--r--  synapse/storage/databases/main/appservice.py | 2
-rw-r--r--  synapse/storage/databases/main/devices.py | 49
-rw-r--r--  synapse/storage/databases/main/directory.py | 4
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 27
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 11
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 7
-rw-r--r--  synapse/storage/databases/main/events.py | 30
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 6
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 4
-rw-r--r--  synapse/storage/databases/main/receipts.py | 10
-rw-r--r--  synapse/storage/databases/main/registration.py | 4
-rw-r--r--  synapse/storage/databases/main/relations.py | 7
-rw-r--r--  synapse/storage/databases/main/roommember.py | 19
-rw-r--r--  synapse/storage/databases/main/signatures.py | 6
-rw-r--r--  synapse/storage/databases/main/tags.py | 8
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 4
-rw-r--r--  synapse/storage/prepare_database.py | 13
-rw-r--r--  synapse/types/__init__.py | 2
62 files changed, 1004 insertions(+), 556 deletions(-)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3d7f986ac7..66e869bc2d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,6 @@ from synapse.appservice import ApplicationService
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import (
-    SynapseTags,
     active_span,
     force_tracing,
     start_active_span,
@@ -162,12 +161,6 @@ class Auth:
                 parent_span.set_tag(
                     "authenticated_entity", requester.authenticated_entity
                 )
-                # We tag the Synapse instance name so that it's an easy jumping
-                # off point into the logs. Can also be used to filter for an
-                # instance that is under load.
-                parent_span.set_tag(
-                    SynapseTags.INSTANCE_NAME, self.hs.get_instance_name()
-                )
                 parent_span.set_tag("user_id", requester.user.to_string())
                 if requester.device_id is not None:
                     parent_span.set_tag("device_id", requester.device_id)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c2c177fd71..9235ce6536 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -751,3 +751,25 @@ class ModuleFailedException(Exception):
     Raised when a module API callback fails, for example because it raised an
     exception.
     """
+
+
+class PartialStateConflictError(SynapseError):
+    """An internal error raised when attempting to persist an event with partial state
+    after the room containing the event has been un-partial stated.
+
+    This error should be handled by recomputing the event context and trying again.
+
+    This error has an HTTP status code so that it can be transported over replication.
+    It should not be exposed to clients.
+    """
+
+    @staticmethod
+    def message() -> str:
+        return "Cannot persist partial state event in un-partial stated room"
+
+    def __init__(self) -> None:
+        super().__init__(
+            HTTPStatus.CONFLICT,
+            msg=PartialStateConflictError.message(),
+            errcode=Codes.UNKNOWN,
+        )
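
Callers that persist events are expected to handle this error by recomputing the event context and retrying. A minimal sketch of that pattern (the persist_event and state_handler names are illustrative, not taken from this diff):

    try:
        await persist_event(event, context)
    except PartialStateConflictError:
        # The room was un-partial-stated while we were working, so the
        # context we computed references a now-invalid partial state group.
        # Recompute against the full state and retry once.
        context = await state_handler.compute_event_context(event)
        await persist_event(event, context)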
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 83c42fc25a..b9f432cc23 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -219,9 +219,13 @@ class FilterCollection:
         self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
         self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
         self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
-        self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
+        self._room_account_data_filter = Filter(
+            hs, room_filter_json.get("account_data", {})
+        )
         self._presence_filter = Filter(hs, filter_json.get("presence", {}))
-        self._account_data = Filter(hs, filter_json.get("account_data", {}))
+        self._global_account_data_filter = Filter(
+            hs, filter_json.get("account_data", {})
+        )
 
         self.include_leave = filter_json.get("room", {}).get("include_leave", False)
         self.event_fields = filter_json.get("event_fields", [])
@@ -256,8 +260,10 @@ class FilterCollection:
     ) -> List[UserPresenceState]:
         return await self._presence_filter.filter(presence_states)
 
-    async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
-        return await self._account_data.filter(events)
+    async def filter_global_account_data(
+        self, events: Iterable[JsonDict]
+    ) -> List[JsonDict]:
+        return await self._global_account_data_filter.filter(events)
 
     async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
         return await self._room_state_filter.filter(
@@ -279,7 +285,7 @@ class FilterCollection:
     async def filter_room_account_data(
         self, events: Iterable[JsonDict]
     ) -> List[JsonDict]:
-        return await self._room_account_data.filter(
+        return await self._room_account_data_filter.filter(
             await self._room_filter.filter(events)
         )
 
@@ -292,6 +298,13 @@ class FilterCollection:
             or self._presence_filter.filters_all_senders()
         )
 
+    def blocks_all_global_account_data(self) -> bool:
+        """True if all global acount data will be filtered out."""
+        return (
+            self._global_account_data_filter.filters_all_types()
+            or self._global_account_data_filter.filters_all_senders()
+        )
+
     def blocks_all_room_ephemeral(self) -> bool:
         return (
             self._room_ephemeral_filter.filters_all_types()
@@ -299,6 +312,13 @@ class FilterCollection:
             or self._room_ephemeral_filter.filters_all_rooms()
         )
 
+    def blocks_all_room_account_data(self) -> bool:
+        return (
+            self._room_account_data_filter.filters_all_types()
+            or self._room_account_data_filter.filters_all_senders()
+            or self._room_account_data_filter.filters_all_rooms()
+        )
+
     def blocks_all_room_timeline(self) -> bool:
         return (
             self._room_timeline_filter.filters_all_types()
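
The new blocks_all_global_account_data and blocks_all_room_account_data helpers let sync skip fetching account data entirely when the filter is guaranteed to drop it. A hedged sketch of the intended caller (variable names mirror sync.py, which this changeset also touches, but are assumptions here):

    if sync_config.filter_collection.blocks_all_global_account_data():
        global_account_data = {}  # nothing would survive the filter anyway
    else:
        global_account_data = (
            await self.store.get_updated_global_account_data_for_user(
                user_id, since_token.account_data_key
            )
        )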
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 53db1e85b3..897dd3edac 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -15,7 +15,7 @@ import logging
 import math
 import resource
 import sys
-from typing import TYPE_CHECKING, List, Sized, Tuple
+from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
 
 from prometheus_client import Gauge
 
@@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
     @wrap_as_background_process("generate_monthly_active_users")
     async def generate_monthly_active_users() -> None:
         current_mau_count = 0
-        current_mau_count_by_service = {}
+        current_mau_count_by_service: Mapping[str, int] = {}
         reserved_users: Sized = ()
         store = hs.get_datastores().main
         if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53c0682dfd..6ac2f0c10d 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -169,6 +169,16 @@ class ExperimentalConfig(Config):
         # MSC3925: do not replace events with their edits
         self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False)
 
+        # MSC3758: exact_event_match push rule condition
+        self.msc3758_exact_event_match = experimental.get(
+            "msc3758_exact_event_match", False
+        )
+
+        # MSC3873: Disambiguate event_match keys.
+        self.msc3783_escape_event_match_key = experimental.get(
+            "msc3783_escape_event_match_key", False
+        )
+
         # MSC3952: Intentional mentions
         self.msc3952_intentional_mentions = experimental.get(
             "msc3952_intentional_mentions", False
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 3ed236217f..8666c22f01 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List
+from typing import Any, Collection
 
 from matrix_common.regex import glob_to_regex
 
@@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config):
         return False
 
     def is_publishing_room_allowed(
-        self, user_id: str, room_id: str, aliases: List[str]
+        self, user_id: str, room_id: str, aliases: Collection[str]
     ) -> bool:
         """Checks if the given user is allowed to publish the room
 
@@ -122,7 +122,7 @@ class _RoomDirectoryRule:
         except Exception as e:
             raise ConfigError("Failed to parse glob into regex") from e
 
-    def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
+    def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
         """Tests if this rule matches the given user_id, room_id and aliases.
 
         Args:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e0be9f88cc..4d6d1b8ebd 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -16,18 +16,7 @@
 import collections.abc
 import logging
 import typing
-from typing import (
-    Any,
-    Collection,
-    Dict,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -56,7 +45,13 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+    MutableStateMap,
+    StateMap,
+    StrCollection,
+    UserID,
+    get_domain_from_id,
+)
 
 if typing.TYPE_CHECKING:
     # conditional imports to avoid import cycle
@@ -69,7 +64,7 @@ logger = logging.getLogger(__name__)
 class _EventSourceStore(Protocol):
     async def get_events(
         self,
-        event_ids: Collection[str],
+        event_ids: StrCollection,
         redact_behaviour: EventRedactBehaviour,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
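
For reference, StrCollection is the alias from synapse.types being adopted here; its definition is roughly as follows (paraphrased, not part of this diff):

    # A Collection[str] that is not itself a str, since str is a Sequence[str]
    # and would otherwise satisfy Collection[str] one character at a time.
    StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]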
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8aca9a3ab9..91118a8d84 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -39,7 +39,7 @@ from unpaddedbase64 import encode_base64
 
 from synapse.api.constants import RelationTypes
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StrCollection
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
 from synapse.util.stringutils import strtobool
@@ -413,7 +413,7 @@ class EventBase(metaclass=abc.ABCMeta):
         """
         return [e for e, _ in self._dict["prev_events"]]
 
-    def auth_event_ids(self) -> Sequence[str]:
+    def auth_event_ids(self) -> StrCollection:
         """Returns the list of auth event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
@@ -558,7 +558,7 @@ class FrozenEventV2(EventBase):
         """
         return self._dict["prev_events"]
 
-    def auth_event_ids(self) -> Sequence[str]:
+    def auth_event_ids(self) -> StrCollection:
         """Returns the list of auth event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 94dd1298e1..c82745275f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
 
 import attr
 from signedjson.types import SigningKey
@@ -103,7 +103,7 @@ class EventBuilder:
 
     async def build(
         self,
-        prev_event_ids: List[str],
+        prev_event_ids: Collection[str],
         auth_event_ids: Optional[List[str]],
         depth: Optional[int] = None,
     ) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
 
         format_version = self.room_version.event_format
         # The types of auth/prev events changes between event versions.
-        prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+        prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
         auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         if format_version == EventFormatVersions.ROOM_V1_V2:
             auth_events = await self._store.add_event_hashes(auth_event_ids)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6eaef8b57a..e0d82ad81c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
@@ -26,8 +27,51 @@ if TYPE_CHECKING:
     from synapse.types.state import StateFilter
 
 
+class UnpersistedEventContextBase(ABC):
+    """
+    This is a base class for EventContext and UnpersistedEventContext, objects which
+    hold information relevant to storing an associated event. Note that an
+    UnpersistedEventContext must be converted into an EventContext before it is
+    suitable to send to the database with its associated event.
+
+    Attributes:
+        _storage: storage controllers for interfacing with the database
+        app_service: If the associated event is being sent by a (local) application service, that
+            app service.
+    """
+
+    def __init__(self, storage_controller: "StorageControllers"):
+        self._storage: "StorageControllers" = storage_controller
+        self.app_service: Optional[ApplicationService] = None
+
+    @abstractmethod
+    async def persist(
+        self,
+        event: EventBase,
+    ) -> "EventContext":
+        """
+        A method to convert an UnpersistedEventContext to an EventContext, suitable for
+        sending to the database with the associated event.
+        """
+        pass
+
+    @abstractmethod
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
+        """
+        Gets the room state at the event (i.e. not including the event if the event
+        is a state event).
+
+        Args:
+            state_filter: specifies the type of state event to fetch from the DB,
+                e.g. EventTypes.JoinRules
+        """
+        pass
+
+
 @attr.s(slots=True, auto_attribs=True)
-class EventContext:
+class EventContext(UnpersistedEventContextBase):
     """
     Holds information relevant to persisting an event
 
@@ -77,9 +121,6 @@ class EventContext:
         delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
             and ``state_group``.
 
-        app_service: If this event is being sent by a (local) application service, that
-            app service.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
@@ -122,6 +163,9 @@ class EventContext:
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(storage=storage)
 
+    async def persist(self, event: EventBase) -> "EventContext":
+        return self
+
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
@@ -254,6 +298,128 @@ class EventContext:
         )
 
 
+@attr.s(slots=True, auto_attribs=True)
+class UnpersistedEventContext(UnpersistedEventContextBase):
+    """
+    The event context holds information about the state groups for an event. It is
+    important to remember that an event technically has two state groups: the state
+    group before the event, and the state group after the event. If the event is not
+    a state event, the state group will not change (i.e. the state group before the
+    event will be the same as the state group after the event), but if it is a state
+    event the state group before the event will differ from the state group after
+    the event.
+
+    This is a version of an EventContext before the new state group (if any) has been
+    computed and stored. It contains information about the state before the event
+    (which also may be the information after the event, if the event is not a state
+    event). The UnpersistedEventContext must be converted into an EventContext by
+    calling the method 'persist' on it before it is suitable to be sent to the DB
+    for processing.
+
+    Attributes:
+        state_group_after_event:
+            The state group after the event. This will always be None until it is
+            persisted. If the event is not a state event, this will be the same as
+            state_group_before_event.
+
+        state_group_before_event:
+            The ID of the state group representing the state of the room before this
+            event.
+
+        state_delta_due_to_event:
+            If the event is a state event, then this is the delta of the state between
+            `state_group_after_event` and `state_group_before_event`.
+
+        prev_group_for_state_group_before_event:
+            If it is known, ``state_group_before_event``'s previous state group.
+
+        delta_ids_to_state_group_before_event:
+            If ``prev_group_for_state_group_before_event`` is not None, the state
+            delta between ``prev_group_for_state_group_before_event`` and
+            ``state_group_before_event``.
+
+        partial_state:
+            Whether the event has partial state.
+
+        state_map_before_event:
+            A map of the state before the event, i.e. the state at
+            `state_group_before_event`.
+    """
+
+    _storage: "StorageControllers"
+    state_group_before_event: Optional[int]
+    state_group_after_event: Optional[int]
+    state_delta_due_to_event: Optional[dict]
+    prev_group_for_state_group_before_event: Optional[int]
+    delta_ids_to_state_group_before_event: Optional[StateMap[str]]
+    partial_state: bool
+    state_map_before_event: Optional[StateMap[str]] = None
+
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
+        """
+        Gets the room state map, excluding this event.
+
+        Args:
+            state_filter: specifies the type of state event to fetch from DB
+
+        Returns:
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
+        """
+        if self.state_map_before_event:
+            return self.state_map_before_event
+
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event, state_filter
+        )
+
+    async def persist(self, event: EventBase) -> EventContext:
+        """
+        Creates a full `EventContext` for the event, persisting any referenced state that
+        has not yet been persisted.
+
+        Args:
+            event: event that the EventContext is associated with.
+
+        Returns:
+            An EventContext suitable for sending to the database with the event
+            for persisting.
+        """
+        assert self.partial_state is not None
+
+        # If we have a full set of state for before the event but don't have a state
+        # group for that state, we need to get one
+        if self.state_group_before_event is None:
+            assert self.state_map_before_event
+            state_group_before_event = await self._storage.state.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=self.prev_group_for_state_group_before_event,
+                delta_ids=self.delta_ids_to_state_group_before_event,
+                current_state_ids=self.state_map_before_event,
+            )
+            self.state_group_before_event = state_group_before_event
+
+        # if the event isn't a state event the state group doesn't change
+        if not self.state_delta_due_to_event:
+            state_group_after_event = self.state_group_before_event
+
+        # otherwise if it is a state event we need to get a state group for it
+        else:
+            state_group_after_event = await self._storage.state.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=self.state_group_before_event,
+                delta_ids=self.state_delta_due_to_event,
+                current_state_ids=None,
+            )
+
+        return EventContext.with_state(
+            storage=self._storage,
+            state_group=state_group_after_event,
+            state_group_before_event=self.state_group_before_event,
+            state_delta_due_to_event=self.state_delta_due_to_event,
+            partial_state=self.partial_state,
+            prev_group=self.state_group_before_event,
+            delta_ids=self.state_delta_due_to_event,
+        )
+
+
 def _encode_state_dict(
     state_dict: Optional[StateMap[str]],
 ) -> Optional[List[Tuple[str, str, str]]]:
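
Taken together, these classes split event creation into a two-step flow, which the handler changes below adopt. A minimal sketch of the new call pattern:

    # Step 1: build the event and an UnpersistedEventContext. No state group
    # has been allocated for the event yet.
    event, unpersisted_context = await event_creation_handler.create_new_client_event(
        builder=builder
    )
    # Step 2: allocate/derive the state groups and obtain a full EventContext,
    # suitable for sending to the database with the event.
    context = await unpersisted_context.persist(event)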
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..97c61cc258 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
 
 from synapse.api.errors import ModuleFailedException, SynapseError
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import Requester, StateMap
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ class ThirdPartyEventRules:
             self._on_threepid_bind_callbacks.append(on_threepid_bind)
 
     async def check_event_allowed(
-        self, event: EventBase, context: EventContext
+        self,
+        event: EventBase,
+        context: UnpersistedEventContextBase,
     ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8d36172484..6d99845de5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -23,6 +23,7 @@ from typing import (
     Collection,
     Dict,
     List,
+    Mapping,
     Optional,
     Tuple,
     Union,
@@ -47,6 +48,7 @@ from synapse.api.errors import (
     FederationError,
     IncompatibleRoomVersionError,
     NotFoundError,
+    PartialStateConflictError,
     SynapseError,
     UnsupportedRoomVersionError,
 )
@@ -80,7 +82,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
     ReplicationGetQueryRestServlet,
 )
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
@@ -1512,7 +1513,7 @@ class FederationHandlerRegistry:
 def _get_event_ids_for_partial_state_join(
     join_event: EventBase,
     prev_state_ids: StateMap[str],
-    summary: Dict[str, MemberSummary],
+    summary: Mapping[str, MemberSummary],
 ) -> Collection[str]:
     """Calculate state to be returned in a partial_state send_join
 
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 67e789eef7..797de46dbc 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -343,10 +343,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
                 }
             )
 
-        (
-            account_data,
-            room_account_data,
-        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        account_data = await self.store.get_updated_global_account_data_for_user(
+            user_id, last_stream_id
+        )
+        room_account_data = await self.store.get_updated_room_account_data_for_user(
+            user_id, last_stream_id
+        )
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 30f2d46c3c..57a6854b1e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1593,9 +1593,8 @@ class AuthHandler:
         if medium == "email":
             address = canonicalise_email(address)
 
-        identity_handler = self.hs.get_identity_handler()
-        result = await identity_handler.try_unbind_threepid(
-            user_id, {"medium": medium, "address": address, "id_server": id_server}
+        result = await self.hs.get_identity_handler().try_unbind_threepid(
+            user_id, medium, address, id_server
         )
 
         await self.store.user_delete_threepid(user_id, medium, address)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d74d135c0c..d24f649382 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -106,12 +106,7 @@ class DeactivateAccountHandler:
         for threepid in threepids:
             try:
                 result = await self._identity_handler.try_unbind_threepid(
-                    user_id,
-                    {
-                        "medium": threepid["medium"],
-                        "address": threepid["address"],
-                        "id_server": id_server,
-                    },
+                    user_id, threepid["medium"], threepid["address"], id_server
                 )
                 identity_server_supports_unbinding &= result
             except Exception:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 2ea52257cb..a5798e9483 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
 
 import logging
 import string
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
 
 from typing_extensions import Literal
 
@@ -485,7 +485,8 @@ class DirectoryHandler:
                 )
             )
             if canonical_alias:
-                room_aliases.append(canonical_alias)
+                # Ensure we do not mutate room_aliases.
+                room_aliases = list(room_aliases) + [canonical_alias]
 
             if not self.config.roomdirectory.is_publishing_room_allowed(
                 user_id, room_id, room_aliases
@@ -528,7 +529,7 @@ class DirectoryHandler:
 
     async def get_aliases_for_room(
         self, requester: Requester, room_id: str
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """
         Get a list of the aliases that currently point to this room on this server
         """
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d2188ca08f..43cbece21b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -159,19 +159,22 @@ class E2eKeysHandler:
             # A map of destination -> user ID -> device IDs.
             remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
             if remote_queries:
-                query_list: List[Tuple[str, Optional[str]]] = []
+                user_ids = set()
+                user_and_device_ids: List[Tuple[str, str]] = []
                 for user_id, device_ids in remote_queries.items():
                     if device_ids:
-                        query_list.extend(
+                        user_and_device_ids.extend(
                             (user_id, device_id) for device_id in device_ids
                         )
                     else:
-                        query_list.append((user_id, None))
+                        user_ids.add(user_id)
 
                 (
                     user_ids_not_in_cache,
                     remote_results,
-                ) = await self.store.get_user_devices_from_cache(query_list)
+                ) = await self.store.get_user_devices_from_cache(
+                    user_ids, user_and_device_ids
+                )
 
                 # Check that the homeserver still shares a room with all cached users.
                 # Note that this check may be slightly racy when a remote user leaves a
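
get_user_devices_from_cache now takes device-less queries as a set of user IDs and per-device queries as (user_id, device_id) pairs, instead of a single list with None placeholders. A hedged usage sketch (the IDs are made up):

    user_ids = {"@alice:example.org"}  # fetch all cached devices for these users
    user_and_device_ids = [("@bob:example.org", "DEVICE1")]  # fetch one device
    user_ids_not_in_cache, remote_results = await store.get_user_devices_from_cache(
        user_ids, user_and_device_ids
    )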
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..46dd63c3f0 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -202,7 +202,7 @@ class EventAuthHandler:
         state_ids: StateMap[str],
         room_version: RoomVersion,
         user_id: str,
-        prev_member_event: Optional[EventBase],
+        prev_membership: Optional[str],
     ) -> None:
         """
         Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +214,14 @@ class EventAuthHandler:
             state_ids: The state of the room as it currently is.
             room_version: The room version of the room being joined.
             user_id: The user joining the room.
-            prev_member_event: The current membership event for this user.
+            prev_membership: The current membership state for this user. `None` if the
+                user has never joined the room (equivalent to "leave").
 
         Raises:
             AuthError if the user cannot join the room.
         """
         # If the member is invited or currently joined, then nothing to do.
-        if prev_member_event and (
-            prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
-        ):
+        if prev_membership in (Membership.JOIN, Membership.INVITE):
             return
 
         # This is not a room with a restricted join rule, so we don't need to do the
@@ -255,13 +254,14 @@ class EventAuthHandler:
             )
 
     async def has_restricted_join_rules(
-        self, state_ids: StateMap[str], room_version: RoomVersion
+        self, partial_state_ids: StateMap[str], room_version: RoomVersion
     ) -> bool:
         """
         Return if the room has the proper join rules set for access via rooms.
 
         Args:
-            state_ids: The state of the room as it currently is.
+            partial_state_ids: The state of the room as it currently is. May be full or partial
+                state.
             room_version: The room version of the room to query.
 
         Returns:
@@ -272,7 +272,7 @@ class EventAuthHandler:
             return False
 
         # If there's no join rule, then it defaults to invite (so this doesn't apply).
-        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
         if not join_rules_event_id:
             return False
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7f64130e0a..08727e4857 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.api.errors import (
     FederationPullAttemptBackoffError,
     HttpResponseException,
     NotFoundError,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -56,7 +57,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
@@ -68,7 +69,6 @@ from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import JsonDict, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
@@ -990,7 +990,10 @@ class FederationHandler:
         )
 
         try:
-            event, context = await self.event_creation_handler.create_new_client_event(
+            (
+                event,
+                unpersisted_context,
+            ) = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
         except SynapseError as e:
@@ -998,7 +1001,9 @@ class FederationHandler:
             raise
 
         # Ensure the user can even join the room.
-        await self._federation_event_handler.check_join_restrictions(context, event)
+        await self._federation_event_handler.check_join_restrictions(
+            unpersisted_context, event
+        )
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
@@ -1178,7 +1183,7 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
+        event, _ = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
@@ -1228,12 +1233,13 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
 
         event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
-            event, context
+            event, unpersisted_context
         )
         if not event_allowed:
             logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1412,20 @@ class FederationHandler:
                 try:
                     (
                         event,
-                        context,
+                        unpersisted_context,
                     ) = await self.event_creation_handler.create_new_client_event(
                         builder=builder
                     )
 
-                    event, context = await self.add_display_name_to_third_party_invite(
-                        room_version_obj, event_dict, event, context
+                    (
+                        event,
+                        unpersisted_context,
+                    ) = await self.add_display_name_to_third_party_invite(
+                        room_version_obj, event_dict, event, unpersisted_context
                     )
 
+                    context = await unpersisted_context.persist(event)
+
                     EventValidator().validate_new(event, self.config)
 
                     # We need to tell the transaction queue to send this out, even
@@ -1483,14 +1494,19 @@ class FederationHandler:
             try:
                 (
                     event,
-                    context,
+                    unpersisted_context,
                 ) = await self.event_creation_handler.create_new_client_event(
                     builder=builder
                 )
-                event, context = await self.add_display_name_to_third_party_invite(
-                    room_version_obj, event_dict, event, context
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.add_display_name_to_third_party_invite(
+                    room_version_obj, event_dict, event, unpersisted_context
                 )
 
+                context = await unpersisted_context.persist(event)
+
                 try:
                     validate_event_for_room_version(event)
                     await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1538,8 @@ class FederationHandler:
         room_version_obj: RoomVersion,
         event_dict: JsonDict,
         event: EventBase,
-        context: EventContext,
-    ) -> Tuple[EventBase, EventContext]:
+        context: UnpersistedEventContextBase,
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1573,14 @@ class FederationHandler:
             room_version_obj, event_dict
         )
         EventValidator().validate_builder(builder)
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
         EventValidator().validate_new(event, self.config)
-        return event, context
+        return event, unpersisted_context
 
     async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e037acbca2..b7136f8d1c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
     FederationError,
     FederationPullAttemptBackoffError,
     HttpResponseException,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -58,7 +59,7 @@ from synapse.event_auth import (
     validate_event_for_room_version,
 )
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
 from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import (
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
 from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     PersistedEventPosition,
@@ -426,7 +426,9 @@ class FederationEventHandler:
         return event, context
 
     async def check_join_restrictions(
-        self, context: EventContext, event: EventBase
+        self,
+        context: UnpersistedEventContextBase,
+        event: EventBase,
     ) -> None:
         """Check that restrictions in restricted join rules are matched
 
@@ -439,16 +441,17 @@ class FederationEventHandler:
         # Check if the user is already in the room or invited to the room.
         user_id = event.state_key
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
+        prev_membership = None
         if prev_member_event_id:
             prev_member_event = await self._store.get_event(prev_member_event_id)
+            prev_membership = prev_member_event.membership
 
         # Check if the member should be allowed access via membership in a space.
         await self._event_auth_handler.check_restricted_join_rules(
             prev_state_ids,
             event.room_version,
             user_id,
-            prev_member_event,
+            prev_membership,
         )
 
     @trace
@@ -524,11 +527,57 @@ class FederationEventHandler:
             "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
         )
         with nested_logging_context(suffix=event.event_id):
+            if partial_state:
+                # When handling a second partial state join into a partial state room,
+                # the returned state will exclude the membership from the first join. To
+                # preserve prior memberships, we try to compute the partial state before
+                # the event ourselves if we know about any of the prev events.
+                #
+                # When we don't know about any of the prev events, it's fine to just use
+                # the returned state, since the new join will create a new forward
+                # extremity, and leave the forward extremity containing our prior
+                # memberships alone.
+                prev_event_ids = set(event.prev_event_ids())
+                seen_event_ids = await self._store.have_events_in_timeline(
+                    prev_event_ids
+                )
+                missing_event_ids = prev_event_ids - seen_event_ids
+
+                state_maps_to_resolve: List[StateMap[str]] = []
+
+                # Fetch the state after the prev events that we know about.
+                state_maps_to_resolve.extend(
+                    (
+                        await self._state_storage_controller.get_state_groups_ids(
+                            room_id, seen_event_ids, await_full_state=False
+                        )
+                    ).values()
+                )
+
+                # When there are prev events we do not have the state for, we state
+                # resolve with the state returned by the remote homeserver.
+                if missing_event_ids or len(state_maps_to_resolve) == 0:
+                    state_maps_to_resolve.append(
+                        {(e.type, e.state_key): e.event_id for e in state}
+                    )
+
+                state_ids_before_event = (
+                    await self._state_resolution_handler.resolve_events_with_store(
+                        event.room_id,
+                        room_version.identifier,
+                        state_maps_to_resolve,
+                        event_map=None,
+                        state_res_store=StateResolutionStore(self._store),
+                    )
+                )
+            else:
+                state_ids_before_event = {
+                    (e.type, e.state_key): e.event_id for e in state
+                }
+
             context = await self._state_handler.compute_event_context(
                 event,
-                state_ids_before_event={
-                    (e.type, e.state_key): e.event_id for e in state
-                },
+                state_ids_before_event=state_ids_before_event,
                 partial_state=partial_state,
             )
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 848e46eb9b..bf0f7acf80 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -219,28 +219,31 @@ class IdentityHandler:
             data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
-    async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
-        """Attempt to remove a 3PID from an identity server, or if one is not provided, all
-        identity servers we're aware the binding is present on
+    async def try_unbind_threepid(
+        self, mxid: str, medium: str, address: str, id_server: Optional[str]
+    ) -> bool:
+        """Attempt to remove a 3PID from one or more identity servers.
 
         Args:
             mxid: Matrix user ID of binding to be removed
-            threepid: Dict with medium & address of binding to be
-                removed, and an optional id_server.
+            medium: The medium of the third-party ID.
+            address: The address of the third-party ID.
+            id_server: An identity server to attempt to unbind from. If None,
+                attempt to remove the association from all identity servers
+                known to potentially have it.
 
         Raises:
-            SynapseError: If we failed to contact the identity server
+            SynapseError: If we failed to contact one or more identity servers.
 
         Returns:
-            True on success, otherwise False if the identity
-            server doesn't support unbinding (or no identity server found to
-            contact).
+            True on success, otherwise False if the identity server doesn't
+            support unbinding (or no identity server to contact was found).
         """
-        if threepid.get("id_server"):
-            id_servers = [threepid["id_server"]]
+        if id_server:
+            id_servers = [id_server]
         else:
             id_servers = await self.store.get_id_servers_user_bound(
-                user_id=mxid, medium=threepid["medium"], address=threepid["address"]
+                mxid, medium, address
             )
 
         # We don't know where to unbind, so we don't have a choice but to return
@@ -249,20 +252,21 @@ class IdentityHandler:
 
         changed = True
         for id_server in id_servers:
-            changed &= await self.try_unbind_threepid_with_id_server(
-                mxid, threepid, id_server
+            changed &= await self._try_unbind_threepid_with_id_server(
+                mxid, medium, address, id_server
             )
 
         return changed
 
-    async def try_unbind_threepid_with_id_server(
-        self, mxid: str, threepid: dict, id_server: str
+    async def _try_unbind_threepid_with_id_server(
+        self, mxid: str, medium: str, address: str, id_server: str
     ) -> bool:
         """Removes a binding from an identity server
 
         Args:
             mxid: Matrix user ID of binding to be removed
-            threepid: Dict with medium & address of binding to be removed
+            medium: The medium of the third-party ID
+            address: The address of the third-party ID
             id_server: Identity server to unbind from
 
         Raises:
@@ -286,7 +290,7 @@ class IdentityHandler:
 
         content = {
             "mxid": mxid,
-            "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
+            "threepid": {"medium": medium, "address": address},
         }
 
         # we abuse the federation http client to sign the request, but we have to send it
@@ -319,12 +323,7 @@ class IdentityHandler:
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-        await self.store.remove_user_bound_threepid(
-            user_id=mxid,
-            medium=threepid["medium"],
-            address=threepid["address"],
-            id_server=id_server,
-        )
+        await self.store.remove_user_bound_threepid(mxid, medium, address, id_server)
 
         return changed
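
try_unbind_threepid now takes the medium, address and identity server explicitly rather than a loosely-typed dict. A hedged example of the new call (the values are made up):

    # id_server=None means: attempt the unbind against every identity server
    # we recorded a binding on for this 3PID.
    changed = await identity_handler.try_unbind_threepid(
        "@alice:example.org",  # mxid
        "email",               # medium
        "alice@example.com",   # address
        None,                  # id_server
    )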
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 191529bd8e..1a29abde98 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -154,9 +154,8 @@ class InitialSyncHandler:
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
-        account_data, account_data_by_room = await self.store.get_account_data_for_user(
-            user_id
-        )
+        account_data = await self.store.get_global_account_data_for_user(user_id)
+        account_data_by_room = await self.store.get_room_account_data_for_user(user_id)
 
         public_room_ids = await self.store.get_public_room_ids()
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e688e00575..8f5b658d9d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
     Codes,
     ConsentNotGivenError,
     NotFoundError,
+    PartialStateConflictError,
     ShadowBanError,
     SynapseError,
     UnstableSpecAuthError,
@@ -48,7 +49,7 @@ from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
@@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     MutableStateMap,
@@ -499,9 +499,9 @@ class EventCreationHandler:
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
 
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+        # We limit concurrent event creation for a room to 1. This prevents state resolution
+        # from occurring when sending bursts of events to a local room
+        self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
 
         self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
@@ -708,7 +708,7 @@ class EventCreationHandler:
 
         builder.internal_metadata.historical = historical
 
-        event, context = await self.create_new_client_event(
+        event, unpersisted_context = await self.create_new_client_event(
             builder=builder,
             requester=requester,
             allow_no_prev_events=allow_no_prev_events,
@@ -721,6 +721,8 @@ class EventCreationHandler:
             current_state_group=current_state_group,
         )
 
+        context = await unpersisted_context.persist(event)
+
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
         # behaviour. Another reason is that this code is also evaluated each time a new
@@ -1083,13 +1085,14 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event for a local client. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
         the event using the parameters state_map and current_state_group, thus these parameters
         must be provided in this case if for_batch is True. The subsequently created event
         and context are suitable for being batched up and bulk persisted to the database
-        with other similarly created events.
+        with other similarly created events. Note that this returns an UnpersistedEventContext,
+        which must be converted to an EventContext before it can be sent to the DB.
 
         Args:
             builder:
@@ -1131,7 +1134,7 @@ class EventCreationHandler:
                 batch persisting
 
         Returns:
-            Tuple of created event, context
+            Tuple of created event, UnpersistedEventContext
         """
         # Strip down the state_event_ids to only what we need to auth the event.
         # For example, we don't need extra m.room.member that don't match event.sender
@@ -1192,9 +1195,16 @@ class EventCreationHandler:
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
             )
-            context = await self.state.compute_event_context_for_batched(
-                event, state_map, current_state_group
+
+            context: UnpersistedEventContextBase = (
+                await self.state.calculate_context_info(
+                    event,
+                    state_ids_before_event=state_map,
+                    partial_state=False,
+                    state_group_before_event=current_state_group,
+                )
             )
+
         else:
             event = await builder.build(
                 prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
 
                     state_map_for_event[(data.event_type, data.state_key)] = state_id
 
-                context = await self.state.compute_event_context(
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                context = await self.state.calculate_context_info(
                     event,
                     state_ids_before_event=state_map_for_event,
-                    # TODO(faster_joins): check how MSC2716 works and whether we can have
-                    #   partial state here
-                    #   https://github.com/matrix-org/synapse/issues/13003
                     partial_state=False,
                 )
+
             else:
-                context = await self.state.compute_event_context(event)
+                context = await self.state.calculate_context_info(event)
 
         if requester:
             context.app_service = requester.app_service
@@ -2082,9 +2093,9 @@ class EventCreationHandler:
 
     async def _rebuild_event_after_third_party_rules(
         self, third_party_result: dict, original_event: EventBase
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         # the third_party_event_rules module wants to replace the event.
-        # we do some basic checks, and then return the replacement event and context.
+        # we do some basic checks, and then return the replacement event and its
+        # unpersisted context.
 
         # Construct a new EventBuilder and validate it, which helps with the
         # rest of these checks.
@@ -2138,5 +2149,6 @@ class EventCreationHandler:
 
         # we rebuild the event context, to be on the safe side. If nothing else,
         # delta_ids might need an update.
-        context = await self.state.compute_event_context(event)
+        context = await self.state.calculate_context_info(event)
+
         return event, context
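The note in the docstring above implies a two-step flow: events are created with a cheap, unpersisted context, and the conversion to a full EventContext happens only when writing to the database. A minimal sketch of that shape, with a hypothetical persist() conversion method (the diff only establishes that such a conversion step exists, not its signature):

    from dataclasses import dataclass
    from typing import Dict, Optional, Tuple


    @dataclass
    class EventContext:
        state_group: int


    @dataclass
    class UnpersistedEventContext:
        state_delta_due_to_event: Dict[Tuple[str, str], str]
        state_group_after_event: Optional[int] = None

        def persist(self, next_state_group: int) -> EventContext:
            # Hypothetical conversion step: the state group only exists once
            # we actually allocate it at write time.
            self.state_group_after_event = next_state_group
            return EventContext(state_group=next_state_group)


    # Batch-create events cheaply, then convert their contexts at persist time.
    pending = [UnpersistedEventContext({("m.room.member", "@u:hs"): "$ev1"})]
    contexts = [ctx.persist(next_state_group=42) for ctx in pending]
    print(contexts[0].state_group)  # -> 42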
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 04c61ae3dd..2bacdebfb5 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
 
 from synapse.api.constants import EduTypes, ReceiptTypes
 from synapse.appservice import ApplicationService
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
     @staticmethod
     def filter_out_private_receipts(
-        rooms: List[JsonDict], user_id: str
+        rooms: Sequence[JsonDict], user_id: str
     ) -> List[JsonDict]:
         """
         Filters a list of serialized receipts (as returned by /sync and /initialSync)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..837dabb3b7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
     Codes,
     LimitExceededError,
     NotFoundError,
+    PartialStateConflictError,
     StoreError,
     SynapseError,
 )
@@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
@@ -1076,7 +1076,7 @@ class RoomCreationHandler:
         state_map: MutableStateMap[str] = {}
         # current_state_group of last event created. Used for computing event context of
         # events to be batched
-        current_state_group = None
+        current_state_group: Optional[int] = None
 
         def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
             e = {"type": etype, "content": content}
@@ -1928,6 +1928,6 @@ class RoomShutdownHandler:
         return {
             "kicked_users": kicked_users,
             "failed_to_kick_users": failed_to_kick_users,
-            "local_aliases": aliases_for_room,
+            "local_aliases": list(aliases_for_room),
             "new_room_id": new_room_id,
         }
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d236cc09b5..a965c7ec76 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,7 +26,13 @@ from synapse.api.constants import (
     GuestAccess,
     Membership,
 )
-from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    PartialStateConflictError,
+    ShadowBanError,
+    SynapseError,
+)
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
@@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.logging import opentracing
 from synapse.module_api import NOT_SPAM
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.types import (
     JsonDict,
     Requester,
@@ -56,6 +61,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class NoKnownServersError(SynapseError):
+    """No server already resident to the room was provided to the join/knock operation."""
+
+    def __init__(self, msg: str = "No known servers"):
+        super().__init__(404, msg)
+
+
 class RoomMemberHandler(metaclass=abc.ABCMeta):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
@@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             room_id: Room that we are trying to join
             user: User who is trying to join
             content: A dict that should be used as the content of the join event.
+
+        Raises:
+            NoKnownServersError: if remote_room_hosts does not contain a server joined to
+                the room.
         """
         raise NotImplementedError()
 
@@ -484,7 +500,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
-        user_account_data, _ = await self.store.get_account_data_for_user(user_id)
+        user_account_data = await self.store.get_global_account_data_for_user(user_id)
 
         # Copy direct message state if applicable
         direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
@@ -823,14 +839,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        state_before_join = await self.state_handler.compute_state_after_events(
-            room_id, latest_event_ids
+        is_partial_state_room = await self.store.is_partial_state_room(room_id)
+        partial_state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids, await_full_state=False
         )
+        # `is_partial_state_room` also indicates whether `partial_state_before_join` is
+        # partial.
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+        old_state_id = partial_state_before_join.get(
+            (EventTypes.Member, target.to_string())
+        )
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -881,11 +902,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(state_before_join)
+        is_host_in_room = await self._is_host_in_room(partial_state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(state_before_join)
+                guest_can_join = await self._can_guest_join(partial_state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -927,8 +948,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 room_id,
                 remote_room_hosts,
                 content,
+                is_partial_state_room,
                 is_host_in_room,
-                state_before_join,
+                partial_state_before_join,
             )
             if remote_join:
                 if ratelimit:
@@ -1073,8 +1095,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         room_id: str,
         remote_room_hosts: List[str],
         content: JsonDict,
+        is_partial_state_room: bool,
         is_host_in_room: bool,
-        state_before_join: StateMap[str],
+        partial_state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -1093,9 +1116,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             remote_room_hosts: A list of remote room hosts.
             content: The content to use as the event body of the join. This may
                 be modified.
-            is_host_in_room: True if the host is in the room.
-            state_before_join: The state before the join event (i.e. the resolution of
-                the states after its parent events).
+            is_partial_state_room: `True` if the server currently doesn't hold the full
+                state of the room.
+            is_host_in_room: `True` if the host is in the room.
+            partial_state_before_join: The state before the join event (i.e. the
+                resolution of the states after its parent events). May be full or
+                partial state, depending on `is_partial_state_room`.
 
         Returns:
             A tuple of:
@@ -1109,6 +1135,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if not is_host_in_room:
             return True, remote_room_hosts
 
+        prev_member_event_id = partial_state_before_join.get(
+            (EventTypes.Member, user_id), None
+        )
+        previous_membership = None
+        if prev_member_event_id:
+            prev_member_event = await self.store.get_event(prev_member_event_id)
+            previous_membership = prev_member_event.membership
+
+        # If we are not fully joined yet, and the target is not already in the room,
+        # let's do a remote join so that another server with the full state can
+        # validate that the user has not, for example, been banned.
+        # We could just accept the join and wait for state resolution to sort it out
+        # later, but we would leak room history to this user in the meantime, which
+        # is pretty bad.
+        if is_partial_state_room and previous_membership != Membership.JOIN:
+            return True, remote_room_hosts
+
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
@@ -1116,21 +1159,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            state_before_join, room_version
+            partial_state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
-        if prev_member_event_id:
-            prev_member_event = await self.store.get_event(prev_member_event_id)
-            if prev_member_event.membership in (
-                Membership.JOIN,
-                Membership.INVITE,
-            ):
-                return False, []
+        if previous_membership in (Membership.JOIN, Membership.INVITE):
+            return False, []
+
+        # All the partial state cases are covered above. We have been given the full
+        # state of the room.
+        assert not is_partial_state_room
+        state_before_join = partial_state_before_join
 
         # If the local host has a user who can issue invites, then a local
         # join can be done.
@@ -1154,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            state_before_join, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, previous_membership
         )
 
         # If this is going to be a local join, additional information must
@@ -1304,11 +1345,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
+    async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
+
+        Args:
+            partial_current_state_ids: The current state of the room. May be full or
+                partial state.
         """
-        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+        guest_access_id = partial_current_state_ids.get(
+            (EventTypes.GuestAccess, ""), None
+        )
         if not guest_access_id:
             return False
 
@@ -1634,19 +1681,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
         return event, stream_id
 
-    async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
+    async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
+        """Returns whether the homeserver is in the room based on its current state.
+
+        Args:
+            partial_current_state_ids: The current state of the room. May be full or
+                partial state.
+        """
         # Have we just created the room, and is this about to be the very
         # first member event?
-        create_event_id = current_state_ids.get(("m.room.create", ""))
-        if len(current_state_ids) == 1 and create_event_id:
+        create_event_id = partial_current_state_ids.get(("m.room.create", ""))
+        if len(partial_current_state_ids) == 1 and create_event_id:
             # We can only get here if we're in the process of creating the room
             return True
 
-        for etype, state_key in current_state_ids:
+        for etype, state_key in partial_current_state_ids:
             if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                 continue
 
-            event_id = current_state_ids[(etype, state_key)]
+            event_id = partial_current_state_ids[(etype, state_key)]
             event = await self.store.get_event(event_id, allow_none=True)
             if not event:
                 continue
@@ -1715,8 +1768,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(
-                404,
+            raise NoKnownServersError(
                 "Can't join remote room because no servers "
                 "that are in the room have been provided.",
             )
@@ -1947,7 +1999,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise NoKnownServersError()
 
         return await self.federation_handler.do_knock(
             remote_room_hosts, room_id, user.to_string(), content=content
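For context, the remote-vs-local join decision assembled across the hunks above can be summarised in a standalone sketch (simplified; names and types are illustrative, not the handler's real API):

    from enum import Enum
    from typing import List, Optional, Tuple


    class Membership(str, Enum):
        JOIN = "join"
        INVITE = "invite"


    def should_do_remote_join(
        is_host_in_room: bool,
        is_partial_state_room: bool,
        previous_membership: Optional[Membership],
        has_restricted_join_rules: bool,
        remote_room_hosts: List[str],
    ) -> Tuple[bool, List[str]]:
        # If no local user is in the room, the join must go via a remote server.
        if not is_host_in_room:
            return True, remote_room_hosts

        # With only partial state we cannot rule out e.g. a ban, so defer to a
        # server with full state unless the user is already joined.
        if is_partial_state_room and previous_membership != Membership.JOIN:
            return True, remote_room_hosts

        # Without restricted join rules, a local join is always possible.
        if not has_restricted_join_rules:
            return False, []

        # An already-joined or invited user can be joined locally.
        if previous_membership in (Membership.JOIN, Membership.INVITE):
            return False, []

        # Past this point the real handler runs further checks that need the
        # full room state; by now we know we have it.
        assert not is_partial_state_room
        return True, remote_room_hosts


    # A partial-state room where the user was merely invited forces a remote join:
    print(should_do_remote_join(True, True, Membership.INVITE, True, ["other.example"]))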
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 221552a2a6..ba261702d4 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,8 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler
 from synapse.replication.http.membership import (
     ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
     ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
@@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     ) -> Tuple[str, int]:
         """Implements RoomMemberHandler._remote_join"""
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise NoKnownServersError()
 
         ret = await self._remote_join_client(
             requester=requester,
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4472019fbc..807245160d 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -521,8 +521,8 @@ class RoomSummaryHandler:
 
         It should return true if:
 
-        * The requester is joined or can join the room (per MSC3173).
-        * The origin server has any user that is joined or can join the room.
+        * The requesting user is joined or can join the room (per MSC3173); or
+        * The origin server has any user that is joined or can join the room; or
         * The history visibility is set to world readable.
 
         Args:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3566537894..4e4595312c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -269,6 +269,8 @@ class SyncHandler:
         self._state_storage_controller = self._storage_controllers.state
         self._device_handler = hs.get_device_handler()
 
+        self.should_calculate_push_rules = hs.config.push.enable_push
+
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
         #    that sets 'since' to 'next_batch'), we know that device won't need a
@@ -1288,6 +1290,12 @@ class SyncHandler:
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> RoomNotifCounts:
+        if not self.should_calculate_push_rules:
+            # If push rules have been universally disabled then we know we won't
+            # have any unread counts in the DB, so we may as well skip asking
+            # the DB.
+            return RoomNotifCounts.empty()
+
         with Measure(self.clock, "unread_notifs_for_room_id"):
 
             return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1391,6 +1399,11 @@ class SyncHandler:
                 for room_id, is_partial_state in results.items()
                 if is_partial_state
             )
+            membership_change_events = [
+                event
+                for event in membership_change_events
+                if not results.get(event.room_id, False)
+            ]
 
         # Incremental eager syncs should additionally include rooms that
         # - we are joined to
@@ -1444,9 +1457,9 @@ class SyncHandler:
 
         logger.debug("Fetching account data")
 
-        account_data_by_room = await self._generate_sync_entry_for_account_data(
-            sync_result_builder
-        )
+        # Global account data is included if it is not filtered out.
+        if not sync_config.filter_collection.blocks_all_global_account_data():
+            await self._generate_sync_entry_for_account_data(sync_result_builder)
 
         # Presence data is included if the server has it enabled and not filtered out.
         include_presence_data = bool(
@@ -1472,9 +1485,7 @@ class SyncHandler:
             (
                 newly_joined_rooms,
                 newly_left_rooms,
-            ) = await self._generate_sync_entry_for_rooms(
-                sync_result_builder, account_data_by_room
-            )
+            ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
 
             # Work out which users have joined or left rooms we're in. We use this
             # to build the presence and device_list parts of the sync response in
@@ -1521,7 +1532,7 @@ class SyncHandler:
             one_time_keys_count = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
-            unused_fallback_key_types = (
+            unused_fallback_key_types = list(
                 await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
             )
 
@@ -1717,35 +1728,29 @@ class SyncHandler:
 
     async def _generate_sync_entry_for_account_data(
         self, sync_result_builder: "SyncResultBuilder"
-    ) -> Dict[str, Dict[str, JsonDict]]:
-        """Generates the account data portion of the sync response.
+    ) -> None:
+        """Generates the global account data portion of the sync response.
 
         Account data (called "Client Config" in the spec) can be set either globally
         or for a specific room. Account data consists of a list of events which
         accumulate state, much like a room.
 
-        This function retrieves global and per-room account data. The former is written
-        to the given `sync_result_builder`. The latter is returned directly, to be
-        later written to the `sync_result_builder` on a room-by-room basis.
+        This function retrieves global account data and writes it to the given
+        `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling
+        of per-room account data.
 
         Args:
             sync_result_builder
-
-        Returns:
-            A dictionary whose keys (room ids) map to the per room account data for that
-            room.
         """
         sync_config = sync_result_builder.sync_config
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
 
         if since_token and not sync_result_builder.full_state:
-            # TODO Do not fetch room account data if it will be unused.
-            (
-                global_account_data,
-                account_data_by_room,
-            ) = await self.store.get_updated_account_data_for_user(
-                user_id, since_token.account_data_key
+            global_account_data = (
+                await self.store.get_updated_global_account_data_for_user(
+                    user_id, since_token.account_data_key
+                )
             )
 
             push_rules_changed = await self.store.have_push_rules_changed_for_user(
@@ -1753,31 +1758,31 @@ class SyncHandler:
             )
 
             if push_rules_changed:
+                global_account_data = dict(global_account_data)
                 global_account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
-            # TODO Do not fetch room account data if it will be unused.
-            (
-                global_account_data,
-                account_data_by_room,
-            ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
+            all_global_account_data = await self.store.get_global_account_data_for_user(
+                user_id
+            )
 
+            global_account_data = dict(all_global_account_data)
             global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
-        account_data_for_user = await sync_config.filter_collection.filter_account_data(
-            [
-                {"type": account_data_type, "content": content}
-                for account_data_type, content in global_account_data.items()
-            ]
+        account_data_for_user = (
+            await sync_config.filter_collection.filter_global_account_data(
+                [
+                    {"type": account_data_type, "content": content}
+                    for account_data_type, content in global_account_data.items()
+                ]
+            )
         )
 
         sync_result_builder.account_data = account_data_for_user
 
-        return account_data_by_room
-
     async def _generate_sync_entry_for_presence(
         self,
         sync_result_builder: "SyncResultBuilder",
@@ -1837,9 +1842,7 @@ class SyncHandler:
         sync_result_builder.presence = presence
 
     async def _generate_sync_entry_for_rooms(
-        self,
-        sync_result_builder: "SyncResultBuilder",
-        account_data_by_room: Dict[str, Dict[str, JsonDict]],
+        self, sync_result_builder: "SyncResultBuilder"
     ) -> Tuple[AbstractSet[str], AbstractSet[str]]:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
@@ -1850,7 +1853,6 @@ class SyncHandler:
 
         Args:
             sync_result_builder
-            account_data_by_room: Dictionary of per room account data
 
         Returns:
             Returns a 2-tuple describing rooms the user has joined or left.
@@ -1863,9 +1865,30 @@ class SyncHandler:
         since_token = sync_result_builder.since_token
         user_id = sync_result_builder.sync_config.user.to_string()
 
+        blocks_all_rooms = (
+            sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+        )
+
+        # 0. Start by fetching room account data (if required).
+        if (
+            blocks_all_rooms
+            or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
+        ):
+            account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+        elif since_token and not sync_result_builder.full_state:
+            account_data_by_room = (
+                await self.store.get_updated_room_account_data_for_user(
+                    user_id, since_token.account_data_key
+                )
+            )
+        else:
+            account_data_by_room = await self.store.get_room_account_data_for_user(
+                user_id
+            )
+
         # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
         block_all_room_ephemeral = (
-            sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+            blocks_all_rooms
             or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
         )
         if block_all_room_ephemeral:
@@ -2291,8 +2314,8 @@ class SyncHandler:
         sync_result_builder: "SyncResultBuilder",
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
-        tags: Optional[Dict[str, Dict[str, Any]]],
-        account_data: Dict[str, JsonDict],
+        tags: Optional[Mapping[str, Mapping[str, Any]]],
+        account_data: Mapping[str, JsonDict],
         always_include: bool = False,
     ) -> None:
         """Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2563858f3c..9314454af1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -30,7 +30,6 @@ from typing import (
     Iterable,
     Iterator,
     List,
-    NoReturn,
     Optional,
     Pattern,
     Tuple,
@@ -340,7 +339,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
             return callback_return
 
-        return _unrecognised_request_handler(request)
+        # A request with an unknown method (for a known endpoint) was received.
+        raise UnrecognizedRequestError(code=405)
 
     @abc.abstractmethod
     def _send_response(
@@ -396,7 +396,6 @@ class DirectServeJsonResource(_AsyncResource):
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _PathEntry:
-    pattern: Pattern
     callback: ServletCallback
     servlet_classname: str
 
@@ -425,13 +424,14 @@ class JsonResource(DirectServeJsonResource):
     ):
         super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
-        self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
+        # Map of path regex -> method -> callback.
+        self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {}
         self.hs = hs
 
     def register_paths(
         self,
         method: str,
-        path_patterns: Iterable[Pattern],
+        path_patterns: Iterable[Pattern[str]],
         callback: ServletCallback,
         servlet_classname: str,
     ) -> None:
@@ -455,8 +455,8 @@ class JsonResource(DirectServeJsonResource):
 
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
-            self.path_regexs.setdefault(method_bytes, []).append(
-                _PathEntry(path_pattern, callback, servlet_classname)
+            self._routes.setdefault(path_pattern, {})[method_bytes] = _PathEntry(
+                callback, servlet_classname
             )
 
     def _get_handler_for_request(
@@ -478,14 +478,17 @@ class JsonResource(DirectServeJsonResource):
 
         # Loop through all the registered callbacks to check if the method
         # and path regex match
-        for path_entry in self.path_regexs.get(request_method, []):
-            m = path_entry.pattern.match(request_path)
+        for path_pattern, methods in self._routes.items():
+            m = path_pattern.match(request_path)
             if m:
-                # We found a match!
+                # We found a matching path!
+                path_entry = methods.get(request_method)
+                if not path_entry:
+                    raise UnrecognizedRequestError(code=405)
                 return path_entry.callback, path_entry.servlet_classname, m.groupdict()
 
-        # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        return _unrecognised_request_handler, "unrecognised_request_handler", {}
+        # Huh. No one wanted to handle that? Fiiiiiine.
+        raise UnrecognizedRequestError(code=404)
 
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
@@ -567,19 +570,6 @@ class StaticResource(File):
         return super().render_GET(request)
 
 
-def _unrecognised_request_handler(request: Request) -> NoReturn:
-    """Request handler for unrecognised requests
-
-    This is a request handler suitable for return from
-    _get_handler_for_request. It actually just raises an
-    UnrecognizedRequestError.
-
-    Args:
-        request: Unused, but passed in to match the signature of ServletCallback.
-    """
-    raise UnrecognizedRequestError(code=404)
-
-
 class UnrecognizedRequestResource(resource.Resource):
     """
     Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an
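The reworked routing is a two-level lookup: a path that matches with no matching method yields a 405, and no matching path at all yields a 404. A minimal standalone sketch of the same structure (illustrative names, not the servlet machinery itself):

    import re
    from typing import Callable, Dict, Pattern, Tuple


    class HttpError(Exception):
        def __init__(self, code: int):
            self.code = code


    # path regex -> method -> handler
    Routes = Dict[Pattern[str], Dict[bytes, Callable[..., str]]]


    def get_handler(
        routes: Routes, method: bytes, path: str
    ) -> Tuple[Callable[..., str], dict]:
        for pattern, methods in routes.items():
            m = pattern.match(path)
            if m:
                handler = methods.get(method)
                if handler is None:
                    # Known path, unknown method.
                    raise HttpError(405)
                return handler, m.groupdict()
        # No route matched the path at all.
        raise HttpError(404)


    routes: Routes = {
        re.compile("^/rooms/(?P<room_id>[^/]+)$"): {
            b"GET": lambda room_id: f"room {room_id}"
        }
    }
    handler, groups = get_handler(routes, b"GET", "/rooms/!abc")
    print(handler(**groups))  # -> "room !abc"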
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 8ef9a0dda8..6c7cf1b294 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -466,8 +466,16 @@ def init_tracer(hs: "HomeServer") -> None:
         STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name()
     )
 
+    jaeger_config = hs.config.tracing.jaeger_config
+    tags = jaeger_config.setdefault("tags", {})
+
+    # tag the Synapse instance name so that it's an easy jumping
+    # off point into the logs. Can also be used to filter for an
+    # instance that is under load.
+    tags[SynapseTags.INSTANCE_NAME] = hs.get_instance_name()
+
     config = JaegerConfig(
-        config=hs.config.tracing.jaeger_config,
+        config=jaeger_config,
         service_name=f"{hs.config.server.server_name} {instance_name_by_type}",
         scope_manager=LogContextScopeManager(),
         metrics_factory=PrometheusMetricsFactory(),
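The setdefault dance above only ensures a tags mapping exists before adding the instance-name tag. The same shape in isolation (config keys and the tag name here are illustrative):

    jaeger_config: dict = {"sampler": {"type": "const", "param": 1}}

    # Ensure a `tags` mapping exists, then add the instance name so traces can
    # be correlated with logs and filtered per worker.
    tags = jaeger_config.setdefault("tags", {})
    tags["SYNAPSE_INSTANCE_NAME"] = "generic_worker1"

    assert jaeger_config["tags"]["SYNAPSE_INSTANCE_NAME"] == "generic_worker1"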
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9c0a98f44..f6a5bffb0f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -22,6 +22,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -43,6 +44,7 @@ from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
 from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
+from synapse.types import SimpleJsonValue
 from synapse.types.state import StateFilter
 from synapse.util.caches import register_cache
 from synapse.util.metrics import measure_func
@@ -148,7 +150,7 @@ class BulkPushRuleEvaluator:
         # little, we can skip fetching a huge number of push rules in large rooms.
         # This helps make joins and leaves faster.
         if event.type == EventTypes.Member:
-            local_users = []
+            local_users: Sequence[str] = []
             # We never notify a user about their own actions. This is enforced in
             # `_action_for_event_by_user` in the loop over `rules_by_user`, but we
             # do the same check here to avoid unnecessary DB queries.
@@ -183,7 +185,6 @@ class BulkPushRuleEvaluator:
         if event.type == EventTypes.Member and event.membership == Membership.INVITE:
             invited = event.state_key
             if invited and self.hs.is_mine_id(invited) and invited not in local_users:
-                local_users = list(local_users)
                 local_users.append(invited)
 
         if not local_users:
@@ -256,13 +257,15 @@ class BulkPushRuleEvaluator:
 
         return pl_event.content if pl_event else {}, sender_level
 
-    async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
+    async def _related_events(
+        self, event: EventBase
+    ) -> Dict[str, Dict[str, SimpleJsonValue]]:
         """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
 
         Returns:
             Mapping of relation type to flattened events.
         """
-        related_events: Dict[str, Dict[str, str]] = {}
+        related_events: Dict[str, Dict[str, SimpleJsonValue]] = {}
         if self._related_event_match_enabled:
             related_event_id = event.content.get("m.relates_to", {}).get("event_id")
             relation_type = event.content.get("m.relates_to", {}).get("rel_type")
@@ -271,7 +274,10 @@ class BulkPushRuleEvaluator:
                     related_event_id, allow_none=True
                 )
                 if related_event is not None:
-                    related_events[relation_type] = _flatten_dict(related_event)
+                    related_events[relation_type] = _flatten_dict(
+                        related_event,
+                        msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+                    )
 
             reply_event_id = (
                 event.content.get("m.relates_to", {})
@@ -286,7 +292,10 @@ class BulkPushRuleEvaluator:
                 )
 
                 if related_event is not None:
-                    related_events["m.in_reply_to"] = _flatten_dict(related_event)
+                    related_events["m.in_reply_to"] = _flatten_dict(
+                        related_event,
+                        msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+                    )
 
                     # indicate that this is from a fallback relation.
                     if relation_type == "m.thread" and event.content.get(
@@ -405,7 +414,10 @@ class BulkPushRuleEvaluator:
             room_mention = mentions.get("room") is True
 
         evaluator = PushRuleEvaluator(
-            _flatten_dict(event),
+            _flatten_dict(
+                event,
+                msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+            ),
             has_mentions,
             user_mentions,
             room_mention,
@@ -416,6 +428,7 @@ class BulkPushRuleEvaluator:
             self._related_event_match_enabled,
             event.room_version.msc3931_push_features,
             self.hs.config.experimental.msc1767_enabled,  # MSC3931 flag
+            self.hs.config.experimental.msc3758_exact_event_match,
         )
 
         users = rules_by_user.keys()
@@ -492,13 +505,15 @@ StateGroup = Union[object, int]
 def _flatten_dict(
     d: Union[EventBase, Mapping[str, Any]],
     prefix: Optional[List[str]] = None,
-    result: Optional[Dict[str, str]] = None,
-) -> Dict[str, str]:
+    result: Optional[Dict[str, SimpleJsonValue]] = None,
+    *,
+    msc3783_escape_event_match_key: bool = False,
+) -> Dict[str, SimpleJsonValue]:
     """
     Given a JSON dictionary (or event) which might contain sub dictionaries,
     flatten it into a single layer dictionary by combining the keys & sub-keys.
 
-    Any (non-dictionary), non-string value is dropped.
+    String, integer, boolean, and null values are kept. All others are dropped.
 
     Transforms:
 
@@ -521,11 +536,22 @@ def _flatten_dict(
     if result is None:
         result = {}
     for key, value in d.items():
-        if isinstance(value, str):
-            result[".".join(prefix + [key])] = value.lower()
+        if msc3783_escape_event_match_key:
+            # Escape periods in the key with a backslash (and backslashes with an
+            # extra backslash). This is because a period is used as a separator
+            # between nested fields.
+            key = key.replace("\\", "\\\\").replace(".", "\\.")
+
+        if isinstance(value, (bool, str)) or type(value) is int or value is None:
+            result[".".join(prefix + [key])] = value
         elif isinstance(value, Mapping):
             # do not set `room_version` due to recursion considerations below
-            _flatten_dict(value, prefix=(prefix + [key]), result=result)
+            _flatten_dict(
+                value,
+                prefix=(prefix + [key]),
+                result=result,
+                msc3783_escape_event_match_key=msc3783_escape_event_match_key,
+            )
 
     # `room_version` should only ever be set when looking at the top level of an event
     if (
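To illustrate the new flattening semantics (keys escaped when the experimental flag is on; string, integer, boolean and null values kept verbatim rather than lower-cased), here is a cut-down standalone version, with msc3783_escape_event_match_key shortened to escape_keys:

    from typing import Any, Dict, List, Mapping, Optional, Union

    SimpleJsonValue = Union[str, int, bool, None]


    def flatten_dict(
        d: Mapping[str, Any],
        prefix: Optional[List[str]] = None,
        result: Optional[Dict[str, SimpleJsonValue]] = None,
        *,
        escape_keys: bool = False,
    ) -> Dict[str, SimpleJsonValue]:
        """Flatten nested dicts into dotted keys, keeping only simple JSON values."""
        if prefix is None:
            prefix = []
        if result is None:
            result = {}
        for key, value in d.items():
            if escape_keys:
                # Periods separate nesting levels, so escape them (and backslashes).
                key = key.replace("\\", "\\\\").replace(".", "\\.")
            if isinstance(value, (bool, str)) or type(value) is int or value is None:
                result[".".join(prefix + [key])] = value
            elif isinstance(value, Mapping):
                flatten_dict(value, prefix + [key], result, escape_keys=escape_keys)
        return result


    print(flatten_dict({"m.relates_to": {"rel_type": "m.thread"}}, escape_keys=True))
    # -> {'m\\.relates_to.rel_type': 'm.thread'}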
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0d072c42a7..c134ccfb3d 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,7 +15,7 @@
 
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet):
     timestamp and size.
     """
 
-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
+    PATTERNS = [
+        *admin_patterns("/media/delete$"),
+        # This URL is kept around for legacy reasons; it is undesirable since it
+        # overlaps with the DeleteMediaByID servlet.
+        *admin_patterns("/media/(?P<server_name>[^/]*)/delete$"),
+    ]
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
@@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet):
         self.media_repository = hs.get_media_repository()
 
     async def on_POST(
-        self, request: SynapseRequest, server_name: str
+        self, request: SynapseRequest, server_name: Optional[str] = None
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
@@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        if self.server_name != server_name:
+        # This check is redundant for the new endpoint (server_name is None there);
+        # we keep it for the legacy endpoint only.
+        if server_name is not None and self.server_name != server_name:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         logging.info(
@@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer)
     ProtectMediaByID(hs).register(http_server)
     UnprotectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
-    DeleteMediaByID(hs).register(http_server)
+    # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as
+    #     their URL routes overlap.
     DeleteMediaByDateSize(hs).register(http_server)
+    DeleteMediaByID(hs).register(http_server)
     UserMediaRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b9dca8ef3a..0c0bf540b9 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -1192,7 +1192,8 @@ class AccountDataRestServlet(RestServlet):
         if not await self._store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
 
-        global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+        global_data = await self._store.get_global_account_data_for_user(user_id)
+        by_room_data = await self._store.get_room_account_data_for_user(user_id)
         return HTTPStatus.OK, {
             "account_data": {
                 "global": global_data,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 4373c73662..662f5bf762 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -415,6 +415,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
             request, MsisdnRequestTokenBody
         )
         msisdn = phone_number_to_msisdn(body.country, body.phone_number)
+        logger.info("Request #%s to verify ownership of %s", body.send_attempt, msisdn)
 
         if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
@@ -444,6 +445,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}
 
+            logger.info("MSISDN %s is already in use by %s", msisdn, existing_user_id)
             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
 
         if not self.hs.config.registration.account_threepid_delegate_msisdn:
@@ -468,6 +470,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
             body.send_attempt
         )
+        logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret)
 
         return 200, ret
 
@@ -734,12 +737,7 @@ class ThreepidUnbindRestServlet(RestServlet):
         # Attempt to unbind the threepid from an identity server. If id_server is None, try to
         # unbind from all identity servers this threepid has been added to in the past
         result = await self.identity_handler.try_unbind_threepid(
-            requester.user.to_string(),
-            {
-                "address": body.address,
-                "medium": body.medium,
-                "id_server": body.id_server,
-            },
+            requester.user.to_string(), body.medium, body.address, body.id_server
         )
         return 200, {"id_server_unbind_result": "success" if result else "no-support"}
 
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index f7081f638e..4e7ffdb555 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        """
+        Retrieve the version information about the most current backup version (if any).
+
+        It takes out an exclusive lock on this user's room_key backups, to ensure
+        clients only upload to the current backup.
+
+        Returns 404 if no backup exists.
+
+        GET /room_keys/version HTTP/1.1
+        HTTP/1.1 200 OK
+        {
+            "version": "12345",
+            "algorithm": "m.megolm_backup.v1",
+            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
+        }
+        """
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+
+        try:
+            info = await self.e2e_room_keys_handler.get_version_info(user_id)
+        except SynapseError as e:
+            if e.code == 404:
+                raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
+            # Re-raise anything else: swallowing it would leave `info` unbound
+            # when we fall through to the return below.
+            raise
+        return 200, info
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Create a new backup version for this user's room_keys with the given
@@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet):
 
 
 class RoomKeysVersionServlet(RestServlet):
-    PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
+    PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet):
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
     async def on_GET(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Retrieve the version information about a given version of the user's
-        room_keys backup.  If the version part is missing, returns info about the
-        most current backup version (if any)
+        room_keys backup.
 
         It takes out an exclusive lock on this user's room_key backups, to ensure
         clients only upload to the current backup.
@@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, info
 
     async def on_DELETE(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Delete the information about a given version of the user's
-        room_keys backup.  If the version part is missing, deletes the most
-        current backup version (if any). Doesn't delete the actual room data.
+        room_keys backup. Doesn't delete the actual room data.
 
         DELETE /room_keys/version/12345 HTTP/1.1
         HTTP/1.1 200 OK
         {}
         """
-        if version is None:
-            raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
-
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
 
@@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, {}
 
     async def on_PUT(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Update the information about a given version of the user's room_keys backup.
@@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet):
         user_id = requester.user.to_string()
         info = parse_json_object_from_request(request)
 
-        if version is None:
-            raise SynapseError(
-                400, "No version specified to update", Codes.MISSING_PARAM
-            )
-
         await self.e2e_room_keys_handler.update_version(user_id, version, info)
         return 200, {}
 
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index ca638755c7..dde08417a4 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -34,7 +34,9 @@ class TagListServlet(RestServlet):
     GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
     """
 
-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
+    PATTERNS = client_patterns(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
+    )
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a5c3de192f..db25848744 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -46,10 +46,9 @@ from ._base import FileInfo, Responder
 from .filepath import MediaFilePaths
 
 if TYPE_CHECKING:
+    from synapse.rest.media.v1.storage_provider import StorageProvider
     from synapse.server import HomeServer
 
-    from .storage_provider import StorageProviderWrapper
-
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +67,7 @@ class MediaStorage:
         hs: "HomeServer",
         local_media_directory: str,
         filepaths: MediaFilePaths,
-        storage_providers: Sequence["StorageProviderWrapper"],
+        storage_providers: Sequence["StorageProvider"],
     ):
         self.hs = hs
         self.reactor = hs.get_reactor()
@@ -360,7 +359,7 @@ class ReadableFileWrapper:
     clock: Clock
     path: str
 
-    async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
+    async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
         """Reads the file in chunks and calls the callback with each chunk."""
 
         with open(self.path, "rb") as file:
diff --git a/synapse/server.py b/synapse/server.py
index 9d6d268f49..efc6b5f895 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -21,7 +21,7 @@
 import abc
 import functools
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 from twisted.internet.interfaces import IOpenSSLContextFactory
 from twisted.internet.tcp import Port
@@ -144,10 +144,10 @@ if TYPE_CHECKING:
     from synapse.handlers.saml import SamlHandler
 
 
-T = TypeVar("T", bound=Callable[..., Any])
+T = TypeVar("T")
 
 
-def cache_in_self(builder: T) -> T:
+def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]:
     """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and
     returning if so. If not, calls the given function and sets `self.foo` to it.
 
@@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T:
     building = [False]
 
     @functools.wraps(builder)
-    def _get(self):
+    def _get(self: "HomeServer") -> T:
         try:
             return getattr(self, depname)
         except AttributeError:
@@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T:
 
         return dep
 
-    # We cast here as we need to tell mypy that `_get` has the same signature as
-    # `builder`.
-    return cast(T, _get)
+    return _get
 
 
 class HomeServer(metaclass=abc.ABCMeta):
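The new annotation expresses the decorator's behaviour without a cast: the wrapper takes a HomeServer and returns whatever the builder returns. A self-contained sketch of the same caching idiom (simplified, without the re-entrancy guard the real decorator keeps in `building`):

    import functools
    import time
    from typing import Callable, TypeVar

    T = TypeVar("T")


    def cache_in_self(builder: Callable[["Server"], T]) -> Callable[["Server"], T]:
        # `get_foo` caches its result on the instance as `_foo`.
        depname = "_" + builder.__name__.replace("get_", "")

        @functools.wraps(builder)
        def _get(self: "Server") -> T:
            try:
                return getattr(self, depname)
            except AttributeError:
                pass
            dep = builder(self)
            setattr(self, depname, dep)
            return dep

        return _get


    class Server:
        @cache_in_self
        def get_clock(self) -> float:
            return time.time()


    hs = Server()
    assert hs.get_clock() == hs.get_clock()  # the second call hits the cache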
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..4dc25df67e 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+    EventContext,
+    UnpersistedEventContext,
+    UnpersistedEventContextBase,
+)
 from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
@@ -222,7 +226,7 @@ class StateHandler:
         return await ret.get_state(self._state_storage_controller, state_filter)
 
     async def get_current_user_ids_in_room(
-        self, room_id: str, latest_event_ids: List[str]
+        self, room_id: str, latest_event_ids: Collection[str]
     ) -> Set[str]:
         """
         Get the IDs of the users who are currently in a room.
@@ -262,31 +266,31 @@ class StateHandler:
         state = await entry.get_state(self._state_storage_controller, StateFilter.all())
         return await self.store.get_joined_hosts(room_id, state, entry)
 
-    async def compute_event_context(
+    async def calculate_context_info(
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: Optional[bool] = None,
-    ) -> EventContext:
-        """Build an EventContext structure for a non-outlier event.
-
-        (for an outlier, call EventContext.for_outlier directly)
-
-        This works out what the current state should be for the event, and
-        generates a new state group if necessary.
-
-        Args:
-            event:
-            state_ids_before_event: The event ids of the state before the event if
-                it can't be calculated from existing events. This is normally
-                only specified when receiving an event from federation where we
-                don't have the prev events, e.g. when backfilling.
-            partial_state:
-                `True` if `state_ids_before_event` is partial and omits non-critical
-                membership events.
-                `False` if `state_ids_before_event` is the full state.
-                `None` when `state_ids_before_event` is not provided. In this case, the
-                flag will be calculated based on `event`'s prev events.
+        state_group_before_event: Optional[int] = None,
+    ) -> UnpersistedEventContextBase:
+        """
+        Calculates the contents of an unpersisted event context, other than the
+        current state group (which is either provided or calculated when the event
+        context is persisted).
+
+        state_ids_before_event:
+            The event ids of the full state before the event if
+            it can't be calculated from existing events. This is normally
+            only specified when receiving an event from federation where we
+            don't have the prev events, e.g. when backfilling or when the event
+            is being created for batch persisting.
+        partial_state:
+            `True` if `state_ids_before_event` is partial and omits non-critical
+            membership events.
+            `False` if `state_ids_before_event` is the full state.
+            `None` when `state_ids_before_event` is not provided. In this case, the
+            flag will be calculated based on `event`'s prev events.
+        state_group_before_event:
+            The current state group at the time of the event, if known.
         Returns:
             The unpersisted event context.
 
@@ -294,7 +298,6 @@ class StateHandler:
             RuntimeError if `state_ids_before_event` is not provided and one or more
                 prev events are missing or outliers.
         """
-
         assert not event.internal_metadata.is_outlier()
 
         #
@@ -306,17 +309,6 @@ class StateHandler:
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
 
-            # .. though we need to get a state group for it.
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=None,
-                    delta_ids=None,
-                    current_state_ids=state_ids_before_event,
-                )
-            )
-
             # the partial_state flag must be provided
             assert partial_state is not None
         else:
@@ -345,6 +337,7 @@ class StateHandler:
             logger.debug("calling resolve_state_groups from compute_event_context")
             # we've already taken into account partial state, so no need to wait for
             # complete state here.
+
             entry = await self.resolve_state_groups_for_events(
                 event.room_id,
                 event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
         #
 
         if not event.is_state():
-            return EventContext.with_state(
+            return UnpersistedEventContext(
                 storage=self._storage_controllers,
                 state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
+                state_group_after_event=state_group_before_event,
                 state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
+                prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+                delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
                 partial_state=partial_state,
+                state_map_before_event=state_ids_before_event,
             )
 
         #
-        # otherwise, we'll need to create a new state group for after the event
+        # otherwise, we'll need to set up the creation of a new state group for after the event
         #
 
         key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
 
         delta_ids = {key: event.event_id}
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
-        )
-
-        return EventContext.with_state(
+        return UnpersistedEventContext(
             storage=self._storage_controllers,
-            state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
+            state_group_after_event=None,
             state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
+            prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+            delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
             partial_state=partial_state,
+            state_map_before_event=state_ids_before_event,
         )
 
-    async def compute_event_context_for_batched(
+    async def compute_event_context(
         self,
         event: EventBase,
-        state_ids_before_event: StateMap[str],
-        current_state_group: int,
+        state_ids_before_event: Optional[StateMap[str]] = None,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
-        """
-        Generate an event context for an event that has not yet been persisted to the
-        database. Intended for use with events that are created to be persisted in a batch.
-        Args:
-            event: the event the context is being computed for
-            state_ids_before_event: a state map consisting of the state ids of the events
-            created prior to this event.
-            current_state_group: the current state group before the event.
-        """
-        state_group_before_event_prev_group = None
-        deltas_to_state_group_before_event = None
-
-        state_group_before_event = current_state_group
-
-        # if the event is not state, we are set
-        if not event.is_state():
-            return EventContext.with_state(
-                storage=self._storage_controllers,
-                state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
-                state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
-                partial_state=False,
-            )
+        """Build an EventContext structure for a non-outlier event.
 
-        # otherwise, we'll need to create a new state group for after the event
-        key = (event.type, event.state_key)
+        (for an outlier, call EventContext.for_outlier directly)
 
-        if state_ids_before_event is not None:
-            replaces = state_ids_before_event.get(key)
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
 
-        if replaces and replaces != event.event_id:
-            event.unsigned["replaces_state"] = replaces
+        Args:
+            event: the event the context is being computed for
+            state_ids_before_event: The event ids of the state before the event if
+                it can't be calculated from existing events. This is normally
+                only specified when receiving an event from federation where we
+                don't have the prev events, e.g. when backfilling.
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
+        Returns:
+            The event context.
 
-        delta_ids = {key: event.event_id}
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
+        """
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
+        unpersisted_context = await self.calculate_context_info(
+            event=event,
+            state_ids_before_event=state_ids_before_event,
+            partial_state=partial_state,
         )
 
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group_after_event,
-            state_group_before_event=state_group_before_event,
-            state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
-            partial_state=False,
-        )
+        return await unpersisted_context.persist(event)
 
     @measure_func()
     async def resolve_state_groups_for_events(
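
The net effect of the state.py changes above is that event-context computation is split
into a pure calculation step and an explicit persistence step: `calculate_context_info`
no longer writes any state groups, and `compute_event_context` is reduced to a
convenience wrapper around the two. A minimal sketch of the new calling pattern,
assuming a `StateHandler` instance `state_handler` and a non-outlier `event` (the helper
name `build_context` is illustrative):

    from synapse.events import EventBase
    from synapse.events.snapshot import EventContext
    from synapse.state import StateHandler

    async def build_context(
        state_handler: StateHandler, event: EventBase
    ) -> EventContext:
        # Step 1: work out the state before/after the event; nothing is written
        # to the state_groups tables yet.
        unpersisted_context = await state_handler.calculate_context_info(event)
        # Step 2: store the state group(s) and obtain a full EventContext.
        # compute_event_context() is now exactly these two steps combined.
        return await unpersisted_context.persist(event)
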
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 41d9111019..481fec72fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
+    db_pool: DatabasePool
+
     def __init__(
         self,
         database: DatabasePool,
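
The bare class-level annotation added here (and `engine: BaseDatabaseEngine` on
`DatabasePool` below) declares the attribute's type without assigning a value, so mypy
can resolve `self.db_pool` in mixin classes without each subclass re-annotating it. A
generic sketch of the pattern, with illustrative names:

    class Pool:
        def run(self) -> None: ...

    class Base:
        pool: Pool  # bare annotation: type declared here, value assigned in __init__

        def __init__(self, pool: Pool) -> None:
            self.pool = pool

    class Mixin(Base):
        def f(self) -> None:
            self.pool.run()  # type-checks via the annotation on Base
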
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 52efd4a171..9d7a8a792f 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,6 +14,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Awaitable,
     Callable,
@@ -23,7 +24,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
@@ -527,7 +527,7 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
         """Get current hosts in room based on current state.
 
         Blocks until we have full state for the given room. This only happens for rooms
@@ -584,7 +584,7 @@ class StateStorageController:
 
     async def get_users_in_room_with_profiles(
         self, room_id: str
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """
         Get the current users in the room with their profiles.
         If the room is currently partial-stated, this will block until the room has
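
A recurring pattern in this commit is narrowing the declared return types of `@cached`
methods from mutable containers (`Dict`, `List`, `Set`) to read-only ones (`Mapping`,
`Sequence`, `AbstractSet`). The cache returns the same object to every caller, so a
caller that mutated the result would silently corrupt the cached entry for everyone
else; the read-only types make mypy flag such mutations. A sketch of the hazard,
assuming a store object `store` and a `room_id`:

    async def example(store, room_id: str) -> None:
        # Previously typed Dict[str, ProfileInfo], so this would type-check and
        # then poison the cached value for all future callers:
        profiles = await store.get_users_in_room_with_profiles(room_id)
        profiles.pop("@alice:example.org")  # now a mypy error: Mapping has no pop()
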
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e20c5c5302..feaa6cdd07 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -499,6 +499,7 @@ class DatabasePool:
     """
 
     _TXN_ID = 0
+    engine: BaseDatabaseEngine
 
     def __init__(
         self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..95567826f2 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     cast,
@@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         return self._account_data_id_gen.get_current_token()
 
     @cached()
-    async def get_account_data_for_user(
+    async def get_global_account_data_for_user(
         self, user_id: str
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Mapping[str, JsonDict]:
         """
-        Get all the client account_data for a user.
+        Get all the global client account_data for a user.
 
         If experimental MSC3391 support is enabled, any entries with an empty
         content body are excluded; as this means they have been deleted.
 
         Args:
             user_id: The user to get the account_data for.
+
         Returns:
-            A 2-tuple of a dict of global account_data and a dict mapping from
-            room_id string to per room account_data dicts.
+            The global account_data.
         """
 
-        def get_account_data_for_user_txn(
+        def get_global_account_data_for_user_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+        ) -> Dict[str, JsonDict]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn.execute(sql, (user_id,))
             rows = self.db_pool.cursor_to_dict(txn)
 
-            global_account_data = {
+            return {
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
+        return await self.db_pool.runInteraction(
+            "get_global_account_data_for_user", get_global_account_data_for_user_txn
+        )
+
+    @cached()
+    async def get_room_account_data_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
+        """
+        Get all of the per-room client account_data for a user.
+
+        If experimental MSC3391 support is enabled, any entries with an empty
+        content body are excluded, as this means they have been deleted.
+
+        Args:
+            user_id: The user to get the account_data for.
+
+        Returns:
+            A dict mapping from room_id string to per-room account_data dicts.
+        """
+
+        def get_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
 
-            return global_account_data, by_room
+            return by_room
 
         return await self.db_pool.runInteraction(
-            "get_account_data_for_user", get_account_data_for_user_txn
+            "get_room_account_data_for_user", get_room_account_data_for_user_txn
         )
 
     @cached(num_args=2, max_entries=5000, tree=True)
@@ -215,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
@@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    async def get_updated_account_data_for_user(
+    async def get_updated_global_account_data_for_user(
         self, user_id: str, stream_id: int
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-        """Get all the client account_data for a that's changed for a user
+    ) -> Dict[str, JsonDict]:
+        """Get all the global account_data that's changed for a user.
 
         Args:
             user_id: The user to get the account_data for.
             stream_id: The point in the stream since which to get updates
+
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A dict of global account_data.
         """
 
-        def get_updated_account_data_for_user_txn(
+        def get_updated_global_account_data_for_user_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-            sql = (
-                "SELECT account_data_type, content FROM account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
-
+        ) -> Dict[str, JsonDict]:
+            sql = """
+                SELECT account_data_type, content FROM account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
-            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+            return {row[0]: db_to_json(row[1]) for row in txn}
 
-            sql = (
-                "SELECT room_id, account_data_type, content FROM room_account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
+        changed = self._account_data_stream_cache.has_entity_changed(
+            user_id, int(stream_id)
+        )
+        if not changed:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "get_updated_global_account_data_for_user",
+            get_updated_global_account_data_for_user_txn,
+        )
+
+    async def get_updated_room_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Get all the room account_data that's changed for a user.
 
+        Args:
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
+
+        Returns:
+            A dict mapping from room_id string to per-room account_data dicts.
+        """
+
+        def get_updated_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
+            sql = """
+                SELECT room_id, account_data_type, content FROM room_account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
 
-            return global_account_data, account_data_by_room
+            return account_data_by_room
 
         changed = self._account_data_stream_cache.has_entity_changed(
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return {}
 
         return await self.db_pool.runInteraction(
-            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+            "get_updated_room_account_data_for_user",
+            get_updated_room_account_data_for_user_txn,
         )
 
     @cached(max_entries=5000, iterable=True)
@@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                     self.get_global_account_data_by_type_for_user.invalidate(
                         (row.user_id, row.data_type)
                     )
-                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_global_account_data_for_user.invalidate((row.user_id,))
+                self.get_room_account_data_for_user.invalidate((row.user_id,))
                 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                 self.get_account_data_for_room_and_type.invalidate(
                     (row.user_id, row.room_id, row.data_type)
@@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), content
@@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 return None
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), {}
@@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
                 (user_id, account_data_type)
             )
@@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 return None
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.prefill(
                 (user_id, account_data_type), {}
             )
@@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room_and_type, (user_id,)
         )
         self._invalidate_cache_and_stream(
-            txn, self.get_account_data_for_user, (user_id,)
+            txn, self.get_global_account_data_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_room_account_data_for_user, (user_id,)
         )
         self._invalidate_cache_and_stream(
             txn, self.get_global_account_data_by_type_for_user, (user_id,)
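
Since `get_account_data_for_user` no longer exists, callers that destructured its
2-tuple must switch to the split accessors, which are cached and invalidated
independently. A sketch of the migration, assuming a data store `store`:

    async def fetch_account_data(store, user_id: str) -> None:
        # Before: both kinds came back in one call (method now removed):
        # global_data, by_room = await store.get_account_data_for_user(user_id)

        # After: fetch each kind separately.
        global_data = await store.get_global_account_data_for_user(user_id)
        by_room = await store.get_room_account_data_for_user(user_id)
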
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
         room_id: str,
         app_service: "ApplicationService",
         cache_context: _CacheContext,
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """
         Get all users in a room that the appservice controls.
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..1ca66d57d4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -100,6 +101,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 ("device_lists_outbound_pokes", "stream_id"),
                 ("device_lists_changes_in_room", "stream_id"),
                 ("device_lists_remote_pending", "stream_id"),
+                ("device_lists_changes_converted_stream_position", "stream_id"),
             ],
             is_writer=hs.config.worker.worker_app is None,
         )
@@ -201,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+    async def count_devices_by_users(
+        self, user_ids: Optional[Collection[str]] = None
+    ) -> int:
         """Retrieve number of all devices of given users.
         Only returns number of devices that are not marked as hidden.
 
@@ -212,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
 
         def count_devices_by_users_txn(
-            txn: LoggingTransaction, user_ids: List[str]
+            txn: LoggingTransaction, user_ids: Collection[str]
         ) -> int:
             sql = """
                 SELECT count(*)
@@ -745,42 +749,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @trace
     @cancellable
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, Optional[str]]]
-    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+        self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
-            query_list: List of (user_id, device_ids), if device_ids is
-                falsey then return all device ids for that user.
+            user_ids: users for whom all device IDs should be returned
+            user_and_device_ids: list of (user_id, device_id) pairs for which
+                only those specific devices should be returned
 
         Returns:
             A tuple of (user_ids_not_in_cache, results_map), where
             user_ids_not_in_cache is a set of user_ids and results_map is a
             mapping of user_id -> device_id -> device_info.
         """
-        user_ids = {user_id for user_id, _ in query_list}
-        user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+        user_map = await self.get_device_list_last_stream_id_for_remotes(
+            list(unique_user_ids)
+        )
 
         # We go and check if any of the users need to have their device lists
         # resynced. If they do then we remove them from the cached list.
         users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
-            user_ids
+            unique_user_ids
         )
         user_ids_in_cache = {
             user_id for user_id, stream_id in user_map.items() if stream_id
         } - users_needing_resync
-        user_ids_not_in_cache = user_ids - user_ids_in_cache
-
-        results: Dict[str, Dict[str, JsonDict]] = {}
-        for user_id, device_id in query_list:
-            if user_id not in user_ids_in_cache:
-                continue
+        user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
-            if device_id:
-                device = await self._get_cached_user_device(user_id, device_id)
-                results.setdefault(user_id, {})[device_id] = device
-            else:
+        # First fetch all the users for whom all devices are to be returned.
+        results: Dict[str, Mapping[str, JsonDict]] = {}
+        for user_id in user_ids:
+            if user_id in user_ids_in_cache:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
+        # Then fetch all device-specific requests, but skip users we've already
+        # fetched all devices for.
+        device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+        for user_id, device_id in user_and_device_ids:
+            if user_id in user_ids_in_cache and user_id not in user_ids:
+                device = await self._get_cached_user_device(user_id, device_id)
+                device_specific_results.setdefault(user_id, {})[device_id] = device
+        results.update(device_specific_results)
 
         set_tag("in_cache", str(results))
         set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return db_to_json(content)
 
     @cached()
-    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+    async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
         devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
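
Callers of `get_user_devices_from_cache` previously passed a single `query_list` in
which a falsey device ID meant "return all devices for this user"; that sentinel is now
two explicit arguments. A sketch of adapting a caller, assuming an old-style
`query_list` and a store object `store`:

    from typing import List, Optional, Tuple

    async def query_devices(
        store, query_list: List[Tuple[str, Optional[str]]]
    ) -> None:
        # Split the old mixed query into the two new arguments.
        user_ids = {user_id for user_id, device_id in query_list if not device_id}
        user_and_device_ids = [
            (user_id, device_id) for user_id, device_id in query_list if device_id
        ]
        not_in_cache, devices = await store.get_user_devices_from_cache(
            user_ids, user_and_device_ids
        )
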
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import attr
 
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @cached(max_entries=5000)
-    async def get_aliases_for_room(self, room_id: str) -> List[str]:
+    async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..2c2d145666 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -260,7 +262,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         for batch in batch_iter(signature_query, 50):
             cross_sigs_result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
+                "get_e2e_cross_signing_signatures_for_devices",
                 self._get_e2e_cross_signing_signatures_for_devices_txn,
                 batch,
             )
@@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cached(max_entries=10000)
     async def get_e2e_unused_fallback_key_types(
         self, user_id: str, device_id: str
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """Returns the fallback key types that have an unused key.
 
         Args:
@@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         )
 
         # The `Optional` comes from the `@cachedList` decorator.
-        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+        return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
 
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
@@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
-            result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
-                self._get_e2e_cross_signing_signatures_txn,
-                result,
-                from_user_id,
+            result = cast(
+                Dict[str, Optional[Mapping[str, JsonDict]]],
+                await self.db_pool.runInteraction(
+                    "get_e2e_cross_signing_signatures",
+                    self._get_e2e_cross_signing_signatures_txn,
+                    result,
+                    from_user_id,
+                ),
             )
 
         return result
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ca780cca36 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+    async def get_max_depth_of(
+        self, event_ids: Collection[str]
+    ) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremities of the room at that point in "time".
 
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cached(max_entries=5000, num_args=2)
     async def _get_forward_extremeties_for_room(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremities of the room at that point in "time".
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a0c370fde..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -203,11 +203,18 @@ class RoomNotifCounts:
     # Map of thread ID to the notification counts.
     threads: Dict[str, NotifCounts]
 
+    @staticmethod
+    def empty() -> "RoomNotifCounts":
+        return _EMPTY_ROOM_NOTIF_COUNTS
+
     def __len__(self) -> int:
         # To properly account for the amount of space in any caches.
         return len(self.threads) + 1
 
 
+_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
+
+
 def _serialize_action(
     actions: Collection[Union[Mapping, str]], is_highlight: bool
 ) -> str:
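
`RoomNotifCounts.empty()` hands out a shared module-level instance instead of
allocating a new object per call; "no notifications" is by far the most common result,
so this avoids churn and gives the cache a single small object to account for. The
shared instance must be treated as read-only. A brief usage sketch:

    from synapse.storage.databases.main.event_push_actions import RoomNotifCounts

    counts = RoomNotifCounts.empty()
    # The shared instance has no per-thread counts, so it reports the minimum
    # cache weight:
    assert len(counts) == 1
    # Never mutate it: adding to counts.threads would leak into every other
    # caller holding the shared empty instance.
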
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..7996cbb557 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
 import itertools
 import logging
 from collections import OrderedDict
-from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -26,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
 )
@@ -36,7 +34,7 @@ from prometheus_client import Counter
 
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
@@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -72,24 +70,6 @@ event_counter = Counter(
 )
 
 
-class PartialStateConflictError(SynapseError):
-    """An internal error raised when attempting to persist an event with partial state
-    after the room containing the event has been un-partial stated.
-
-    This error should be handled by recomputing the event context and trying again.
-
-    This error has an HTTP status code so that it can be transported over replication.
-    It should not be exposed to clients.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            HTTPStatus.CONFLICT,
-            msg="Cannot persist partial state event in un-partial stated room",
-            errcode=Codes.UNKNOWN,
-        )
-
-
 @attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -306,7 +286,7 @@ class PersistEventsStore:
 
         # The set of event_ids to return. This includes all soft-failed events
         # and their prev events.
-        existing_prevs = set()
+        existing_prevs: Set[str] = set()
 
         def _get_prevs_before_rejected_txn(
             txn: LoggingTransaction, batch: Collection[str]
@@ -571,7 +551,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
     ) -> None:
         """Calculate the chain cover index for the given events.
 
@@ -865,7 +845,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..584536111d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
 
 import attr
 
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            cast(Dict[str, Sequence[str]], event_to_auth_chain),
+            cast(Dict[str, StrCollection], event_to_auth_chain),
         )
 
         return _CalculateChainCover(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
         return await self.db_pool.runInteraction("count_users", _count_users)
 
     @cached(num_args=0)
-    async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+    async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
         """Generates current count of monthly active users broken down by service.
         A service is typically an appservice but also includes native matrix users.
         Since the `monthly_active_users` table is populated from the `user_ips` table
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..dddf49c2d5 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     cast,
 )
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> Sequence[JsonDict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[JsonDict]:
+    ) -> Sequence[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def _get_linearized_receipts_for_rooms(
         self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, Sequence[JsonDict]]:
         if not room_ids:
             return {}
 
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..9a55e17624 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import attr
 
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
 
     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
         """Deprecated: use get_userinfo_by_id instead"""
 
         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..fa3266c081 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+    ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
@@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
         return result is not None
 
     @cached()
-    async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+    async def get_aggregation_groups_for_event(
+        self, event_id: str
+    ) -> Sequence[JsonDict]:
         raise NotImplementedError()
 
     @cachedList(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return self._known_servers_count
 
     @cached(max_entries=100000, iterable=True)
-    async def get_users_in_room(self, room_id: str) -> List[str]:
+    async def get_users_in_room(self, room_id: str) -> Sequence[str]:
         """Returns a list of users in the room.
 
         Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
-    def get_user_in_room_with_profile(
-        self, room_id: str, user_id: str
-    ) -> Dict[str, ProfileInfo]:
+    def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
         raise NotImplementedError()
 
     @cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room_with_profiles(
         self, room_id: str
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """Get a mapping from user ID to profile information for all users in a given room.
 
         The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached(max_entries=100000)
-    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+    async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
         """Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached()
     async def get_invited_rooms_for_local_user(
         self, user_id: str
-    ) -> List[RoomsForUser]:
+    ) -> Sequence[RoomsForUser]:
         """Get all the rooms the *local* user is invited to.
 
         Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(iterable=True)
-    async def get_local_users_in_room(self, room_id: str) -> List[str]:
+    async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
         """
         Retrieves a list of the current roommembers who are local to the server.
         """
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of users who share a room with `user_id`"""
         room_ids = await self.get_rooms_for_user(user_id)
 
-        user_who_share_room = set()
+        user_who_share_room: Set[str] = set()
         for room_id in room_ids:
             user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
         """Get current hosts in room based on current state."""
 
         # First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
 
 from unpaddedbase64 import encode_base64
 
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 
 class SignatureWorkerStore(EventsWorkerStore):
     @cached()
-    def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+    def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
         # This is a dummy function to allow get_event_reference_hashes
         # to use its cache
         raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
     )
     async def get_event_reference_hashes(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Dict[str, bytes]]:
+    ) -> Mapping[str, Mapping[str, bytes]]:
         """Get all hashes for given events.
 
         Args:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
 
 class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
-    async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+    async def get_tags_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for a user.
 
 
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, Dict[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..f6a6fd4079 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,9 +16,9 @@ import logging
 import re
 from typing import (
     TYPE_CHECKING,
-    Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+    async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
         return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3acdb39da7..6c335a9315 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
 from synapse.storage.types import Cursor
 
@@ -108,9 +108,14 @@ def prepare_database(
         # so we start one before running anything. This ensures that any upgrades
         # are either applied completely, or not at all.
         #
-        # (psycopg2 automatically starts a transaction as soon as we run any statements
-        # at all, so this is redundant but harmless there.)
-        cur.execute("BEGIN TRANSACTION")
+        # psycopg2 does not automatically start transactions when in autocommit mode.
+        # While it is technically harmless to nest transactions in postgres, doing so
+        # results in a warning in Postgres' logs per query, which we would rather
+        # avoid.
+        if isinstance(database_engine, Sqlite3Engine) or (
+            isinstance(database_engine, PostgresEngine) and db_conn.autocommit
+        ):
+            cur.execute("BEGIN TRANSACTION")
 
         logger.info("%r: Checking existing schema version", databases)
         version_info = _get_or_create_schema_state(cur, database_engine)
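
For background: psycopg2 only opens implicit transactions when autocommit is off; in
autocommit mode every statement commits on its own unless an explicit BEGIN is issued,
which is what the guard above relies on. A standalone sketch of that behaviour (the DSN
is a placeholder):

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
    conn.autocommit = True  # psycopg2 issues no implicit BEGIN in this mode
    cur = conn.cursor()
    cur.execute("BEGIN TRANSACTION")  # explicit BEGIN is therefore required...
    cur.execute("SELECT 1")           # ...to make subsequent statements atomic
    cur.execute("COMMIT")
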
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index f82d1cfc29..52e366c8ae 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -69,6 +69,8 @@ StateMap = Mapping[StateKey, T]
 MutableStateMap = MutableMapping[StateKey, T]
 
 # JSON types. These could be made stronger, but will do for now.
+# A "simple" (canonical) JSON value.
+SimpleJsonValue = Optional[Union[str, int, bool]]
 # A JSON-serialisable dict.
 JsonDict = Dict[str, Any]
 # A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
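
`SimpleJsonValue` covers the scalar values permitted in Matrix's canonical JSON
(strings, integers, booleans and null; floats are excluded). A hedged usage sketch:

    from synapse.types import SimpleJsonValue

    leaf: SimpleJsonValue = "m.push_rules"  # ok
    leaf = 3                                # ok
    leaf = None                             # ok
    leaf = 1.5                              # rejected by mypy: floats aren't "simple"
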