Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/auth.py                        |   7
-rw-r--r--  synapse/config/experimental.py             |   5
-rw-r--r--  synapse/events/snapshot.py                 | 174
-rw-r--r--  synapse/events/third_party_rules.py        |   6
-rw-r--r--  synapse/handlers/directory.py              |   3
-rw-r--r--  synapse/handlers/e2e_keys.py               |  11
-rw-r--r--  synapse/handlers/federation.py             |  59
-rw-r--r--  synapse/handlers/federation_event.py       |   6
-rw-r--r--  synapse/handlers/message.py                |  48
-rw-r--r--  synapse/handlers/room.py                   |   2
-rw-r--r--  synapse/handlers/sync.py                   |   2
-rw-r--r--  synapse/http/server.py                     |  40
-rw-r--r--  synapse/logging/opentracing.py             |  10
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py   |  30
-rw-r--r--  synapse/rest/admin/media.py                |  18
-rw-r--r--  synapse/rest/client/room_keys.py           |  48
-rw-r--r--  synapse/rest/client/tags.py                |   4
-rw-r--r--  synapse/rest/media/v1/media_storage.py     |   7
-rw-r--r--  synapse/server.py                          |  12
-rw-r--r--  synapse/state/__init__.py                  | 176
-rw-r--r--  synapse/storage/_base.py                   |   2
-rw-r--r--  synapse/storage/database.py                |   1
-rw-r--r--  synapse/storage/databases/main/devices.py  |  32
-rw-r--r--  synapse/storage/databases/main/events.py   |   2
-rw-r--r--  synapse/storage/prepare_database.py        |  13
25 files changed, 472 insertions(+), 246 deletions(-)
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3d7f986ac7..66e869bc2d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,6 @@ from synapse.appservice import ApplicationService
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import (
-    SynapseTags,
     active_span,
     force_tracing,
     start_active_span,
@@ -162,12 +161,6 @@ class Auth:
                 parent_span.set_tag(
                     "authenticated_entity", requester.authenticated_entity
                 )
-                # We tag the Synapse instance name so that it's an easy jumping
-                # off point into the logs. Can also be used to filter for an
-                # instance that is under load.
-                parent_span.set_tag(
-                    SynapseTags.INSTANCE_NAME, self.hs.get_instance_name()
-                )
                 parent_span.set_tag("user_id", requester.user.to_string())
                 if requester.device_id is not None:
                     parent_span.set_tag("device_id", requester.device_id)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53c0682dfd..5e3a889081 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -169,6 +169,11 @@ class ExperimentalConfig(Config):
         # MSC3925: do not replace events with their edits
         self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False)
 
+        # MSC3873: Disambiguate event_match keys.
+        self.msc3783_escape_event_match_key = experimental.get(
+            "msc3783_escape_event_match_key", False
+        )
+
         # MSC3952: Intentional mentions
         self.msc3952_intentional_mentions = experimental.get(
             "msc3952_intentional_mentions", False
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6eaef8b57a..e0d82ad81c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
@@ -26,8 +27,51 @@ if TYPE_CHECKING:
     from synapse.types.state import StateFilter
 
 
+class UnpersistedEventContextBase(ABC):
+    """
+    This is a base class for EventContext and UnpersistedEventContext, objects which
+    hold information relevant to storing an associated event. Note that an
+    UnpersistedEventContext must be converted into an EventContext before it is
+    suitable to send to the db with its associated event.
+
+    Attributes:
+        _storage: storage controllers for interfacing with the database
+        app_service: If the associated event is being sent by a (local) application service, that
+            app service.
+    """
+
+    def __init__(self, storage_controller: "StorageControllers"):
+        self._storage: "StorageControllers" = storage_controller
+        self.app_service: Optional[ApplicationService] = None
+
+    @abstractmethod
+    async def persist(
+        self,
+        event: EventBase,
+    ) -> "EventContext":
+        """
+        A method to convert an UnpersistedEventContext to an EventContext, suitable for
+        sending to the database with the associated event.
+        """
+        pass
+
+    @abstractmethod
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
+        """
+        Gets the room state at the event (ie not including the event if the event is a
+        state event).
+
+        Args:
+            state_filter: specifies the type of state event to fetch from DB, example:
+            EventTypes.JoinRules
+        """
+        pass
+
+
 @attr.s(slots=True, auto_attribs=True)
-class EventContext:
+class EventContext(UnpersistedEventContextBase):
     """
     Holds information relevant to persisting an event
 
@@ -77,9 +121,6 @@ class EventContext:
         delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
             and ``state_group``.
 
-        app_service: If this event is being sent by a (local) application service, that
-            app service.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
@@ -122,6 +163,9 @@ class EventContext:
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(storage=storage)
 
+    async def persist(self, event: EventBase) -> "EventContext":
+        return self
+
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
@@ -254,6 +298,128 @@ class EventContext:
         )
 
 
+@attr.s(slots=True, auto_attribs=True)
+class UnpersistedEventContext(UnpersistedEventContextBase):
+    """
+    The event context holds information about the state groups for an event. It is important
+    to remember that an event technically has two state groups: the state group before the
+    event, and the state group after the event. If the event is not a state event, the state
+    group will not change (ie the state group before the event will be the same as the state
+    group after the event), but if it is a state event the state group before the event
+    will differ from the state group after the event.
+    This is a version of an EventContext before the new state group (if any) has been
+    computed and stored. It contains information about the state before the event (which
+    also may be the information after the event, if the event is not a state event). The
+    UnpersistedEventContext must be converted into an EventContext by calling the method
+    'persist' on it before it is suitable to be sent to the DB for processing.
+
+        state_group_after_event:
+             The state group after the event. This will always be None until it is persisted.
+             If the event is not a state event, this will be the same as
+             state_group_before_event.
+
+        state_group_before_event:
+            The ID of the state group representing the state of the room before this event.
+
+        state_delta_due_to_event:
+            If the event is a state event, then this is the delta of the state between
+            `state_group_after_event` and `state_group_before_event`
+
+        prev_group_for_state_group_before_event:
+            If it is known, ``state_group_before_event``'s previous state group.
+
+        delta_ids_to_state_group_before_event:
+             If ``prev_group_for_state_group_before_event`` is not None, the state delta
+             between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``.
+
+        partial_state:
+            Whether the event has partial state.
+
+        state_map_before_event:
+            A map of the state before the event, i.e. the state at `state_group_before_event`
+    """
+
+    _storage: "StorageControllers"
+    state_group_before_event: Optional[int]
+    state_group_after_event: Optional[int]
+    state_delta_due_to_event: Optional[dict]
+    prev_group_for_state_group_before_event: Optional[int]
+    delta_ids_to_state_group_before_event: Optional[StateMap[str]]
+    partial_state: bool
+    state_map_before_event: Optional[StateMap[str]] = None
+
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
+        """
+        Gets the room state map, excluding this event.
+
+        Args:
+            state_filter: specifies the type of state event to fetch from DB
+
+        Returns:
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
+        """
+        if self.state_map_before_event:
+            return self.state_map_before_event
+
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event, state_filter
+        )
+
+    async def persist(self, event: EventBase) -> EventContext:
+        """
+        Creates a full `EventContext` for the event, persisting any referenced state that
+        has not yet been persisted.
+
+        Args:
+             event: event that the EventContext is associated with.
+
+        Returns: An EventContext suitable for sending to the database with the event
+        for persisting
+        """
+        assert self.partial_state is not None
+
+        # If we have a full set of state for before the event but don't have a state
+        # group for that state, we need to get one
+        if self.state_group_before_event is None:
+            assert self.state_map_before_event
+            state_group_before_event = await self._storage.state.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=self.prev_group_for_state_group_before_event,
+                delta_ids=self.delta_ids_to_state_group_before_event,
+                current_state_ids=self.state_map_before_event,
+            )
+            self.state_group_before_event = state_group_before_event
+
+        # if the event isn't a state event the state group doesn't change
+        if not self.state_delta_due_to_event:
+            state_group_after_event = self.state_group_before_event
+
+        # otherwise if it is a state event we need to get a state group for it
+        else:
+            state_group_after_event = await self._storage.state.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=self.state_group_before_event,
+                delta_ids=self.state_delta_due_to_event,
+                current_state_ids=None,
+            )
+
+        return EventContext.with_state(
+            storage=self._storage,
+            state_group=state_group_after_event,
+            state_group_before_event=self.state_group_before_event,
+            state_delta_due_to_event=self.state_delta_due_to_event,
+            partial_state=self.partial_state,
+            prev_group=self.state_group_before_event,
+            delta_ids=self.state_delta_due_to_event,
+        )
+
+
 def _encode_state_dict(
     state_dict: Optional[StateMap[str]],
 ) -> Optional[List[Tuple[str, str, str]]]:
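
For orientation, the usage pattern that the rest of this commit converges on looks roughly like the sketch below. It is illustrative only and not runnable outside Synapse; the handler and builder names are taken from the call sites changed later in this diff.

# Illustrative sketch of the new two-step flow: event creation now returns an
# UnpersistedEventContext, and state groups are only allocated when persist()
# is called, just before the event is handed to the persistence code.
async def create_and_persist(event_creation_handler, builder):
    event, unpersisted_context = await event_creation_handler.create_new_client_event(
        builder=builder
    )
    # Modules and third-party rules can inspect the context through the shared
    # UnpersistedEventContextBase interface before any state group exists.
    context = await unpersisted_context.persist(event)
    return event, context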
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..97c61cc258 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
 
 from synapse.api.errors import ModuleFailedException, SynapseError
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import Requester, StateMap
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ class ThirdPartyEventRules:
             self._on_threepid_bind_callbacks.append(on_threepid_bind)
 
     async def check_event_allowed(
-        self, event: EventBase, context: EventContext
+        self,
+        event: EventBase,
+        context: UnpersistedEventContextBase,
     ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 2ea52257cb..d31b0fbb17 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -485,7 +485,8 @@ class DirectoryHandler:
                 )
             )
             if canonical_alias:
-                room_aliases.append(canonical_alias)
+                # Ensure we do not mutate room_aliases.
+                room_aliases = room_aliases + [canonical_alias]
 
             if not self.config.roomdirectory.is_publishing_room_allowed(
                 user_id, room_id, room_aliases
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d2188ca08f..43cbece21b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -159,19 +159,22 @@ class E2eKeysHandler:
             # A map of destination -> user ID -> device IDs.
             remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
             if remote_queries:
-                query_list: List[Tuple[str, Optional[str]]] = []
+                user_ids = set()
+                user_and_device_ids: List[Tuple[str, str]] = []
                 for user_id, device_ids in remote_queries.items():
                     if device_ids:
-                        query_list.extend(
+                        user_and_device_ids.extend(
                             (user_id, device_id) for device_id in device_ids
                         )
                     else:
-                        query_list.append((user_id, None))
+                        user_ids.add(user_id)
 
                 (
                     user_ids_not_in_cache,
                     remote_results,
-                ) = await self.store.get_user_devices_from_cache(query_list)
+                ) = await self.store.get_user_devices_from_cache(
+                    user_ids, user_and_device_ids
+                )
 
                 # Check that the homeserver still shares a room with all cached users.
                 # Note that this check may be slightly racy when a remote user leaves a
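
A minimal, self-contained sketch of the query split made above; the shape of `remote_queries` is assumed from the surrounding code, and the user and device IDs are invented for illustration.

from typing import Dict, Iterable, List, Set, Tuple

remote_queries: Dict[str, Iterable[str]] = {
    "@alice:remote": [],                 # empty -> all of alice's devices
    "@bob:remote": ["DEV1", "DEV2"],     # specific devices only
}

# Partition into "all devices" users and specific (user, device) pairs,
# mirroring the new get_user_devices_from_cache signature.
user_ids: Set[str] = set()
user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
    if device_ids:
        user_and_device_ids.extend((user_id, d) for d in device_ids)
    else:
        user_ids.add(user_id)

print(user_ids)              # {'@alice:remote'}
print(user_and_device_ids)   # [('@bob:remote', 'DEV1'), ('@bob:remote', 'DEV2')]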
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7f64130e0a..43ed4a3dd1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -56,7 +56,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
@@ -990,7 +990,10 @@ class FederationHandler:
         )
 
         try:
-            event, context = await self.event_creation_handler.create_new_client_event(
+            (
+                event,
+                unpersisted_context,
+            ) = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
         except SynapseError as e:
@@ -998,7 +1001,9 @@ class FederationHandler:
             raise
 
         # Ensure the user can even join the room.
-        await self._federation_event_handler.check_join_restrictions(context, event)
+        await self._federation_event_handler.check_join_restrictions(
+            unpersisted_context, event
+        )
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
@@ -1178,7 +1183,7 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
+        event, _ = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
@@ -1228,12 +1233,13 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
 
         event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
-            event, context
+            event, unpersisted_context
         )
         if not event_allowed:
             logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1412,20 @@ class FederationHandler:
                 try:
                     (
                         event,
-                        context,
+                        unpersisted_context,
                     ) = await self.event_creation_handler.create_new_client_event(
                         builder=builder
                     )
 
-                    event, context = await self.add_display_name_to_third_party_invite(
-                        room_version_obj, event_dict, event, context
+                    (
+                        event,
+                        unpersisted_context,
+                    ) = await self.add_display_name_to_third_party_invite(
+                        room_version_obj, event_dict, event, unpersisted_context
                     )
 
+                    context = await unpersisted_context.persist(event)
+
                     EventValidator().validate_new(event, self.config)
 
                     # We need to tell the transaction queue to send this out, even
@@ -1483,14 +1494,19 @@ class FederationHandler:
             try:
                 (
                     event,
-                    context,
+                    unpersisted_context,
                 ) = await self.event_creation_handler.create_new_client_event(
                     builder=builder
                 )
-                event, context = await self.add_display_name_to_third_party_invite(
-                    room_version_obj, event_dict, event, context
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.add_display_name_to_third_party_invite(
+                    room_version_obj, event_dict, event, unpersisted_context
                 )
 
+                context = await unpersisted_context.persist(event)
+
                 try:
                     validate_event_for_room_version(event)
                     await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1538,8 @@ class FederationHandler:
         room_version_obj: RoomVersion,
         event_dict: JsonDict,
         event: EventBase,
-        context: EventContext,
-    ) -> Tuple[EventBase, EventContext]:
+        context: UnpersistedEventContextBase,
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1573,14 @@ class FederationHandler:
             room_version_obj, event_dict
         )
         EventValidator().validate_builder(builder)
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
         EventValidator().validate_new(event, self.config)
-        return event, context
+        return event, unpersisted_context
 
     async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e037acbca2..3561f2f1de 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -58,7 +58,7 @@ from synapse.event_auth import (
     validate_event_for_room_version,
 )
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
 from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import (
@@ -426,7 +426,9 @@ class FederationEventHandler:
         return event, context
 
     async def check_join_restrictions(
-        self, context: EventContext, event: EventBase
+        self,
+        context: UnpersistedEventContextBase,
+        event: EventBase,
     ) -> None:
         """Check that restrictions in restricted join rules are matched
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e688e00575..3e30f52e4d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -48,7 +48,7 @@ from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
@@ -499,9 +499,9 @@ class EventCreationHandler:
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
 
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+        # We limit concurrent event creation for a room to 1. This prevents state resolution
+        # from occurring when sending bursts of events to a local room
+        self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
 
         self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
@@ -708,7 +708,7 @@ class EventCreationHandler:
 
         builder.internal_metadata.historical = historical
 
-        event, context = await self.create_new_client_event(
+        event, unpersisted_context = await self.create_new_client_event(
             builder=builder,
             requester=requester,
             allow_no_prev_events=allow_no_prev_events,
@@ -721,6 +721,8 @@ class EventCreationHandler:
             current_state_group=current_state_group,
         )
 
+        context = await unpersisted_context.persist(event)
+
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
         # behaviour. Another reason is that this code is also evaluated each time a new
@@ -1083,13 +1085,14 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event for a local client. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
         the event using the parameters state_map and current_state_group, thus these parameters
         must be provided in this case if for_batch is True. The subsequently created event
         and context are suitable for being batched up and bulk persisted to the database
-        with other similarly created events.
+        with other similarly created events. Note that this returns an UnpersistedEventContext,
+        which must be converted to an EventContext before it can be sent to the DB.
 
         Args:
             builder:
@@ -1131,7 +1134,7 @@ class EventCreationHandler:
                 batch persisting
 
         Returns:
-            Tuple of created event, context
+            Tuple of created event, UnpersistedEventContext
         """
         # Strip down the state_event_ids to only what we need to auth the event.
         # For example, we don't need extra m.room.member that don't match event.sender
@@ -1192,9 +1195,16 @@ class EventCreationHandler:
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
             )
-            context = await self.state.compute_event_context_for_batched(
-                event, state_map, current_state_group
+
+            context: UnpersistedEventContextBase = (
+                await self.state.calculate_context_info(
+                    event,
+                    state_ids_before_event=state_map,
+                    partial_state=False,
+                    state_group_before_event=current_state_group,
+                )
             )
+
         else:
             event = await builder.build(
                 prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
 
                     state_map_for_event[(data.event_type, data.state_key)] = state_id
 
-                context = await self.state.compute_event_context(
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                context = await self.state.calculate_context_info(
                     event,
                     state_ids_before_event=state_map_for_event,
-                    # TODO(faster_joins): check how MSC2716 works and whether we can have
-                    #   partial state here
-                    #   https://github.com/matrix-org/synapse/issues/13003
                     partial_state=False,
                 )
+
             else:
-                context = await self.state.compute_event_context(event)
+                context = await self.state.calculate_context_info(event)
 
         if requester:
             context.app_service = requester.app_service
@@ -2082,9 +2093,9 @@ class EventCreationHandler:
 
     async def _rebuild_event_after_third_party_rules(
         self, third_party_result: dict, original_event: EventBase
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         # the third_party_event_rules want to replace the event.
-        # we do some basic checks, and then return the replacement event and context.
+        # we do some basic checks, and then return the replacement event.
 
         # Construct a new EventBuilder and validate it, which helps with the
         # rest of these checks.
@@ -2138,5 +2149,6 @@ class EventCreationHandler:
 
         # we rebuild the event context, to be on the safe side. If nothing else,
         # delta_ids might need an update.
-        context = await self.state.compute_event_context(event)
+        context = await self.state.calculate_context_info(event)
+
         return event, context
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..0e759b8a5d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1076,7 +1076,7 @@ class RoomCreationHandler:
         state_map: MutableStateMap[str] = {}
         # current_state_group of last event created. Used for computing event context of
         # events to be batched
-        current_state_group = None
+        current_state_group: Optional[int] = None
 
         def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
             e = {"type": etype, "content": content}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3566537894..202b35eee6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1753,6 +1753,7 @@ class SyncHandler:
             )
 
             if push_rules_changed:
+                global_account_data = dict(global_account_data)
                 global_account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
@@ -1763,6 +1764,7 @@ class SyncHandler:
                 account_data_by_room,
             ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
 
+            global_account_data = dict(global_account_data)
             global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
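
A small self-contained sketch of why the copy matters here; the cached mapping below is a stand-in for the value returned by the account-data store.

# The account data mapping comes from a cache, so mutating it in place would
# corrupt the cached value shared with other requests. Copying first keeps
# the cache intact.
cached = {"m.direct": {"@friend:example.org": ["!room:example.org"]}}  # pretend cache entry

global_account_data = dict(cached)                # shallow copy, as in the diff
global_account_data["m.push_rules"] = {"global": {}}

print("m.push_rules" in cached)               # False: cache untouched
print("m.push_rules" in global_account_data)  # True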
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2563858f3c..9314454af1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -30,7 +30,6 @@ from typing import (
     Iterable,
     Iterator,
     List,
-    NoReturn,
     Optional,
     Pattern,
     Tuple,
@@ -340,7 +339,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
             return callback_return
 
-        return _unrecognised_request_handler(request)
+        # A request with an unknown method (for a known endpoint) was received.
+        raise UnrecognizedRequestError(code=405)
 
     @abc.abstractmethod
     def _send_response(
@@ -396,7 +396,6 @@ class DirectServeJsonResource(_AsyncResource):
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _PathEntry:
-    pattern: Pattern
     callback: ServletCallback
     servlet_classname: str
 
@@ -425,13 +424,14 @@ class JsonResource(DirectServeJsonResource):
     ):
         super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
-        self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
+        # Map of path regex -> method -> callback.
+        self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {}
         self.hs = hs
 
     def register_paths(
         self,
         method: str,
-        path_patterns: Iterable[Pattern],
+        path_patterns: Iterable[Pattern[str]],
         callback: ServletCallback,
         servlet_classname: str,
     ) -> None:
@@ -455,8 +455,8 @@ class JsonResource(DirectServeJsonResource):
 
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
-            self.path_regexs.setdefault(method_bytes, []).append(
-                _PathEntry(path_pattern, callback, servlet_classname)
+            self._routes.setdefault(path_pattern, {})[method_bytes] = _PathEntry(
+                callback, servlet_classname
             )
 
     def _get_handler_for_request(
@@ -478,14 +478,17 @@ class JsonResource(DirectServeJsonResource):
 
         # Loop through all the registered callbacks to check if the method
         # and path regex match
-        for path_entry in self.path_regexs.get(request_method, []):
-            m = path_entry.pattern.match(request_path)
+        for path_pattern, methods in self._routes.items():
+            m = path_pattern.match(request_path)
             if m:
-                # We found a match!
+                # We found a matching path!
+                path_entry = methods.get(request_method)
+                if not path_entry:
+                    raise UnrecognizedRequestError(code=405)
                 return path_entry.callback, path_entry.servlet_classname, m.groupdict()
 
-        # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        return _unrecognised_request_handler, "unrecognised_request_handler", {}
+        # Huh. No one wanted to handle that? Fiiiiiine.
+        raise UnrecognizedRequestError(code=404)
 
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
@@ -567,19 +570,6 @@ class StaticResource(File):
         return super().render_GET(request)
 
 
-def _unrecognised_request_handler(request: Request) -> NoReturn:
-    """Request handler for unrecognised requests
-
-    This is a request handler suitable for return from
-    _get_handler_for_request. It actually just raises an
-    UnrecognizedRequestError.
-
-    Args:
-        request: Unused, but passed in to match the signature of ServletCallback.
-    """
-    raise UnrecognizedRequestError(code=404)
-
-
 class UnrecognizedRequestResource(resource.Resource):
     """
     Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an
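
A self-contained sketch of the new routing shape (path regex -> method -> handler), returning 405 for a known path with an unknown method and 404 otherwise. The endpoint path and handler below are invented for illustration.

import re
from typing import Callable, Dict, Pattern, Tuple

# Route table: compiled path regex -> HTTP method -> handler.
routes: Dict[Pattern[str], Dict[bytes, Callable[..., str]]] = {}

def register(method: str, pattern: Pattern[str], handler: Callable[..., str]) -> None:
    routes.setdefault(pattern, {})[method.encode("ascii")] = handler

def lookup(method: bytes, path: str) -> Tuple[int, str]:
    for pattern, methods in routes.items():
        m = pattern.match(path)
        if m:
            handler = methods.get(method)
            if handler is None:
                return 405, "unrecognised method"   # known path, unknown method
            return 200, handler(**m.groupdict())
    return 404, "unrecognised request"              # unknown path

whoami = re.compile(r"^/_matrix/client/v3/account/whoami$")
register("GET", whoami, lambda: "whoami")

print(lookup(b"GET", "/_matrix/client/v3/account/whoami"))  # (200, 'whoami')
print(lookup(b"PUT", "/_matrix/client/v3/account/whoami"))  # (405, 'unrecognised method')
print(lookup(b"GET", "/not/registered"))                    # (404, 'unrecognised request')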
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 8ef9a0dda8..6c7cf1b294 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -466,8 +466,16 @@ def init_tracer(hs: "HomeServer") -> None:
         STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name()
     )
 
+    jaeger_config = hs.config.tracing.jaeger_config
+    tags = jaeger_config.setdefault("tags", {})
+
+    # tag the Synapse instance name so that it's an easy jumping
+    # off point into the logs. Can also be used to filter for an
+    # instance that is under load.
+    tags[SynapseTags.INSTANCE_NAME] = hs.get_instance_name()
+
     config = JaegerConfig(
-        config=hs.config.tracing.jaeger_config,
+        config=jaeger_config,
         service_name=f"{hs.config.server.server_name} {instance_name_by_type}",
         scope_manager=LogContextScopeManager(),
         metrics_factory=PrometheusMetricsFactory(),
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9c0a98f44..39d2f88f03 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -271,7 +271,10 @@ class BulkPushRuleEvaluator:
                     related_event_id, allow_none=True
                 )
                 if related_event is not None:
-                    related_events[relation_type] = _flatten_dict(related_event)
+                    related_events[relation_type] = _flatten_dict(
+                        related_event,
+                        msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+                    )
 
             reply_event_id = (
                 event.content.get("m.relates_to", {})
@@ -286,7 +289,10 @@ class BulkPushRuleEvaluator:
                 )
 
                 if related_event is not None:
-                    related_events["m.in_reply_to"] = _flatten_dict(related_event)
+                    related_events["m.in_reply_to"] = _flatten_dict(
+                        related_event,
+                        msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+                    )
 
                     # indicate that this is from a fallback relation.
                     if relation_type == "m.thread" and event.content.get(
@@ -405,7 +411,10 @@ class BulkPushRuleEvaluator:
             room_mention = mentions.get("room") is True
 
         evaluator = PushRuleEvaluator(
-            _flatten_dict(event),
+            _flatten_dict(
+                event,
+                msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+            ),
             has_mentions,
             user_mentions,
             room_mention,
@@ -493,6 +502,8 @@ def _flatten_dict(
     d: Union[EventBase, Mapping[str, Any]],
     prefix: Optional[List[str]] = None,
     result: Optional[Dict[str, str]] = None,
+    *,
+    msc3783_escape_event_match_key: bool = False,
 ) -> Dict[str, str]:
     """
     Given a JSON dictionary (or event) which might contain sub dictionaries,
@@ -521,11 +532,22 @@ def _flatten_dict(
     if result is None:
         result = {}
     for key, value in d.items():
+        if msc3783_escape_event_match_key:
+            # Escape periods in the key with a backslash (and backslashes with an
+            # extra backslash). This is since a period is used as a separator between
+            # nested fields.
+            key = key.replace("\\", "\\\\").replace(".", "\\.")
+
         if isinstance(value, str):
             result[".".join(prefix + [key])] = value.lower()
         elif isinstance(value, Mapping):
             # do not set `room_version` due to recursion considerations below
-            _flatten_dict(value, prefix=(prefix + [key]), result=result)
+            _flatten_dict(
+                value,
+                prefix=(prefix + [key]),
+                result=result,
+                msc3783_escape_event_match_key=msc3783_escape_event_match_key,
+            )
 
     # `room_version` should only ever be set when looking at the top level of an event
     if (
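
A runnable, self-contained sketch of the escaping behaviour added to _flatten_dict, reimplemented here in miniature (the room_version special-casing is omitted).

from collections.abc import Mapping
from typing import Dict, List, Optional

def flatten(
    d: Mapping,
    prefix: Optional[List[str]] = None,
    result: Optional[Dict[str, str]] = None,
    *,
    escape: bool = False,
) -> Dict[str, str]:
    prefix = prefix or []
    result = {} if result is None else result
    for key, value in d.items():
        if escape:
            # Backslashes are doubled, then periods escaped, because "."
            # separates nested field names in the flattened keys.
            key = key.replace("\\", "\\\\").replace(".", "\\.")
        if isinstance(value, str):
            result[".".join(prefix + [key])] = value.lower()
        elif isinstance(value, Mapping):
            flatten(value, prefix + [key], result, escape=escape)
    return result

print(flatten({"m.relates_to": {"rel_type": "m.thread"}}))
# {'m.relates_to.rel_type': 'm.thread'}
print(flatten({"m.relates_to": {"rel_type": "m.thread"}}, escape=True))
# {'m\\.relates_to.rel_type': 'm.thread'}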
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0d072c42a7..c134ccfb3d 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,7 +15,7 @@
 
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet):
     timestamp and size.
     """
 
-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
+    PATTERNS = [
+        *admin_patterns("/media/delete$"),
+        # This URL is kept around for legacy reasons; it is undesirable since it
+        # overlaps with the DeleteMediaByID servlet.
+        *admin_patterns("/media/(?P<server_name>[^/]*)/delete$"),
+    ]
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
@@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet):
         self.media_repository = hs.get_media_repository()
 
     async def on_POST(
-        self, request: SynapseRequest, server_name: str
+        self, request: SynapseRequest, server_name: Optional[str] = None
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
@@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        if self.server_name != server_name:
+        # This check is useless; we keep it for the legacy endpoint only.
+        if server_name is not None and self.server_name != server_name:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         logging.info(
@@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer)
     ProtectMediaByID(hs).register(http_server)
     UnprotectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
-    DeleteMediaByID(hs).register(http_server)
+    # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as
+    #     their URL routes overlap.
     DeleteMediaByDateSize(hs).register(http_server)
+    DeleteMediaByID(hs).register(http_server)
     UserMediaRestServlet(hs).register(http_server)
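
To see why the registration order matters, the two overlapping routes can be approximated as below; the admin URL prefix and the DeleteMediaByID pattern are assumptions for illustration only.

import re

delete_by_date_size = re.compile(r"^/_synapse/admin/v1/media/(?P<server_name>[^/]*)/delete$")
delete_by_id = re.compile(r"^/_synapse/admin/v1/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")

path = "/_synapse/admin/v1/media/example.org/delete"
print(bool(delete_by_date_size.match(path)))  # True
print(bool(delete_by_id.match(path)))         # True -> whichever is registered first wins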
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index f7081f638e..4e7ffdb555 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        """
+        Retrieve the version information about the most current backup version (if any)
+
+        It takes out an exclusive lock on this user's room_key backups, to ensure
+        clients only upload to the current backup.
+
+        Returns 404 if no backup version exists.
+
+        GET /room_keys/version HTTP/1.1
+        {
+            "version": "12345",
+            "algorithm": "m.megolm_backup.v1",
+            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
+        }
+        """
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+
+        try:
+            info = await self.e2e_room_keys_handler.get_version_info(user_id)
+        except SynapseError as e:
+            if e.code == 404:
+                raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
+            raise
+        return 200, info
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Create a new backup version for this user's room_keys with the given
@@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet):
 
 
 class RoomKeysVersionServlet(RestServlet):
-    PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
+    PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet):
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
     async def on_GET(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Retrieve the version information about a given version of the user's
-        room_keys backup.  If the version part is missing, returns info about the
-        most current backup version (if any)
+        room_keys backup.
 
         It takes out an exclusive lock on this user's room_key backups, to ensure
         clients only upload to the current backup.
@@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, info
 
     async def on_DELETE(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Delete the information about a given version of the user's
-        room_keys backup.  If the version part is missing, deletes the most
-        current backup version (if any). Doesn't delete the actual room data.
+        room_keys backup. Doesn't delete the actual room data.
 
         DELETE /room_keys/version/12345 HTTP/1.1
         HTTP/1.1 200 OK
         {}
         """
-        if version is None:
-            raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
-
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
 
@@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, {}
 
     async def on_PUT(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Update the information about a given version of the user's room_keys backup.
@@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet):
         user_id = requester.user.to_string()
         info = parse_json_object_from_request(request)
 
-        if version is None:
-            raise SynapseError(
-                400, "No version specified to update", Codes.MISSING_PARAM
-            )
-
         await self.e2e_room_keys_handler.update_version(user_id, version, info)
         return 200, {}
 
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index ca638755c7..dde08417a4 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -34,7 +34,9 @@ class TagListServlet(RestServlet):
     GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
     """
 
-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
+    PATTERNS = client_patterns(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
+    )
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a5c3de192f..db25848744 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -46,10 +46,9 @@ from ._base import FileInfo, Responder
 from .filepath import MediaFilePaths
 
 if TYPE_CHECKING:
+    from synapse.rest.media.v1.storage_provider import StorageProvider
     from synapse.server import HomeServer
 
-    from .storage_provider import StorageProviderWrapper
-
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +67,7 @@ class MediaStorage:
         hs: "HomeServer",
         local_media_directory: str,
         filepaths: MediaFilePaths,
-        storage_providers: Sequence["StorageProviderWrapper"],
+        storage_providers: Sequence["StorageProvider"],
     ):
         self.hs = hs
         self.reactor = hs.get_reactor()
@@ -360,7 +359,7 @@ class ReadableFileWrapper:
     clock: Clock
     path: str
 
-    async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
+    async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
         """Reads the file in chunks and calls the callback with each chunk."""
 
         with open(self.path, "rb") as file:
diff --git a/synapse/server.py b/synapse/server.py
index 9d6d268f49..efc6b5f895 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -21,7 +21,7 @@
 import abc
 import functools
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 from twisted.internet.interfaces import IOpenSSLContextFactory
 from twisted.internet.tcp import Port
@@ -144,10 +144,10 @@ if TYPE_CHECKING:
     from synapse.handlers.saml import SamlHandler
 
 
-T = TypeVar("T", bound=Callable[..., Any])
+T = TypeVar("T")
 
 
-def cache_in_self(builder: T) -> T:
+def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]:
     """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and
     returning if so. If not, calls the given function and sets `self.foo` to it.
 
@@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T:
     building = [False]
 
     @functools.wraps(builder)
-    def _get(self):
+    def _get(self: "HomeServer") -> T:
         try:
             return getattr(self, depname)
         except AttributeError:
@@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T:
 
         return dep
 
-    # We cast here as we need to tell mypy that `_get` has the same signature as
-    # `builder`.
-    return cast(T, _get)
+    return _get
 
 
 class HomeServer(metaclass=abc.ABCMeta):
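
A minimal, self-contained sketch of the cache_in_self pattern with the new typing: the decorator maps a builder of type Callable[["HomeServer"], T] to a getter of the same shape, caching the built value on the instance. The cyclic-dependency guard of the real implementation is omitted, and get_answer is an invented example builder.

import functools
from typing import Callable, TypeVar

T = TypeVar("T")

def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]:
    depname = builder.__name__.replace("get_", "_")

    @functools.wraps(builder)
    def _get(self: "HomeServer") -> T:
        try:
            return getattr(self, depname)
        except AttributeError:
            pass
        dep = builder(self)
        setattr(self, depname, dep)   # cache on the instance
        return dep

    return _get

class HomeServer:
    @cache_in_self
    def get_answer(self) -> int:
        print("building")            # only printed on the first call
        return 42

hs = HomeServer()
print(hs.get_answer(), hs.get_answer())   # building, then 42 42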
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..e877e6f1a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+    EventContext,
+    UnpersistedEventContext,
+    UnpersistedEventContextBase,
+)
 from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
@@ -262,31 +266,31 @@ class StateHandler:
         state = await entry.get_state(self._state_storage_controller, StateFilter.all())
         return await self.store.get_joined_hosts(room_id, state, entry)
 
-    async def compute_event_context(
+    async def calculate_context_info(
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: Optional[bool] = None,
-    ) -> EventContext:
-        """Build an EventContext structure for a non-outlier event.
-
-        (for an outlier, call EventContext.for_outlier directly)
-
-        This works out what the current state should be for the event, and
-        generates a new state group if necessary.
-
-        Args:
-            event:
-            state_ids_before_event: The event ids of the state before the event if
-                it can't be calculated from existing events. This is normally
-                only specified when receiving an event from federation where we
-                don't have the prev events, e.g. when backfilling.
-            partial_state:
-                `True` if `state_ids_before_event` is partial and omits non-critical
-                membership events.
-                `False` if `state_ids_before_event` is the full state.
-                `None` when `state_ids_before_event` is not provided. In this case, the
-                flag will be calculated based on `event`'s prev events.
+        state_group_before_event: Optional[int] = None,
+    ) -> UnpersistedEventContextBase:
+        """
+        Calculates the contents of an unpersisted event context, other than the current
+        state group (which is either provided or calculated when the event context is persisted).
+
+        state_ids_before_event:
+            The event ids of the full state before the event if
+            it can't be calculated from existing events. This is normally
+            only specified when receiving an event from federation where we
+            don't have the prev events, e.g. when backfilling or when the event
+            is being created for batch persisting.
+        partial_state:
+            `True` if `state_ids_before_event` is partial and omits non-critical
+            membership events.
+            `False` if `state_ids_before_event` is the full state.
+            `None` when `state_ids_before_event` is not provided. In this case, the
+            flag will be calculated based on `event`'s prev events.
+        state_group_before_event:
+            the current state group at the time of event, if known
         Returns:
             The event context.
 
@@ -294,7 +298,6 @@ class StateHandler:
             RuntimeError if `state_ids_before_event` is not provided and one or more
                 prev events are missing or outliers.
         """
-
         assert not event.internal_metadata.is_outlier()
 
         #
@@ -306,17 +309,6 @@ class StateHandler:
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
 
-            # .. though we need to get a state group for it.
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=None,
-                    delta_ids=None,
-                    current_state_ids=state_ids_before_event,
-                )
-            )
-
             # the partial_state flag must be provided
             assert partial_state is not None
         else:
@@ -345,6 +337,7 @@ class StateHandler:
             logger.debug("calling resolve_state_groups from compute_event_context")
             # we've already taken into account partial state, so no need to wait for
             # complete state here.
+
             entry = await self.resolve_state_groups_for_events(
                 event.room_id,
                 event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
         #
 
         if not event.is_state():
-            return EventContext.with_state(
+            return UnpersistedEventContext(
                 storage=self._storage_controllers,
                 state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
+                state_group_after_event=state_group_before_event,
                 state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
+                prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+                delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
                 partial_state=partial_state,
+                state_map_before_event=state_ids_before_event,
             )
 
         #
-        # otherwise, we'll need to create a new state group for after the event
+        # otherwise, we'll need to set up creating a new state group for after the event
         #
 
         key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
 
         delta_ids = {key: event.event_id}
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
-        )
-
-        return EventContext.with_state(
+        return UnpersistedEventContext(
             storage=self._storage_controllers,
-            state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
+            state_group_after_event=None,
             state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
+            prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+            delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
             partial_state=partial_state,
+            state_map_before_event=state_ids_before_event,
         )
 
-    async def compute_event_context_for_batched(
+    async def compute_event_context(
         self,
         event: EventBase,
-        state_ids_before_event: StateMap[str],
-        current_state_group: int,
+        state_ids_before_event: Optional[StateMap[str]] = None,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
-        """
-        Generate an event context for an event that has not yet been persisted to the
-        database. Intended for use with events that are created to be persisted in a batch.
-        Args:
-            event: the event the context is being computed for
-            state_ids_before_event: a state map consisting of the state ids of the events
-            created prior to this event.
-            current_state_group: the current state group before the event.
-        """
-        state_group_before_event_prev_group = None
-        deltas_to_state_group_before_event = None
-
-        state_group_before_event = current_state_group
-
-        # if the event is not state, we are set
-        if not event.is_state():
-            return EventContext.with_state(
-                storage=self._storage_controllers,
-                state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
-                state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
-                partial_state=False,
-            )
+        """Build an EventContext structure for a non-outlier event.
 
-        # otherwise, we'll need to create a new state group for after the event
-        key = (event.type, event.state_key)
+        (for an outlier, call EventContext.for_outlier directly)
 
-        if state_ids_before_event is not None:
-            replaces = state_ids_before_event.get(key)
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
 
-        if replaces and replaces != event.event_id:
-            event.unsigned["replaces_state"] = replaces
+        Args:
+            event:
+            state_ids_before_event: The event ids of the state before the event if
+                it can't be calculated from existing events. This is normally
+                only specified when receiving an event from federation where we
+                don't have the prev events, e.g. when backfilling.
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
+        Returns:
+            The event context.
 
-        delta_ids = {key: event.event_id}
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
+        """
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
+        unpersisted_context = await self.calculate_context_info(
+            event=event,
+            state_ids_before_event=state_ids_before_event,
+            partial_state=partial_state,
         )
 
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group_after_event,
-            state_group_before_event=state_group_before_event,
-            state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
-            partial_state=False,
-        )
+        return await unpersisted_context.persist(event)
 
     @measure_func()
     async def resolve_state_groups_for_events(
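The state-handler changes above replace the eager store_state_group call with a deferred one: calculate_context_info now returns an UnpersistedEventContext, and compute_event_context becomes a thin wrapper that builds that object and immediately calls persist(event) on it. A minimal, self-contained sketch of that build-now/persist-later shape, using toy classes rather than Synapse's real EventContext or storage API, might look like:

# Hypothetical, simplified sketch of the build-now/persist-later shape above;
# these are toy classes, not Synapse's UnpersistedEventContext or storage API.
import asyncio
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

StateKey = Tuple[str, str]  # (event type, state key)


class ToyStateStore:
    """Stands in for the state storage controller (assumption)."""

    def __init__(self) -> None:
        self._next_group = 0
        self.groups: Dict[int, Dict[StateKey, str]] = {}

    async def store_state_group(
        self, prev: Dict[StateKey, str], delta: Dict[StateKey, str]
    ) -> int:
        self._next_group += 1
        self.groups[self._next_group] = {**prev, **delta}
        return self._next_group


@dataclass
class UnpersistedContextSketch:
    """Holds everything needed to create the post-event state group later."""

    store: ToyStateStore
    state_before_event: Dict[StateKey, str]
    state_delta_due_to_event: Dict[StateKey, str]
    state_group_after_event: Optional[int] = None

    async def persist(self) -> int:
        # Only now do we touch the database and allocate the new state group;
        # batched event creation can build many of these cheaply and persist
        # them together, which is why compute_event_context_for_batched goes away.
        self.state_group_after_event = await self.store.store_state_group(
            self.state_before_event, self.state_delta_due_to_event
        )
        return self.state_group_after_event


async def main() -> None:
    ctx = UnpersistedContextSketch(
        store=ToyStateStore(),
        state_before_event={("m.room.create", ""): "$create"},
        state_delta_due_to_event={("m.room.topic", ""): "$topic"},
    )
    print(await ctx.persist())  # -> 1


asyncio.run(main())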
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 41d9111019..481fec72fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
+    db_pool: DatabasePool
+
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e20c5c5302..feaa6cdd07 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -499,6 +499,7 @@ class DatabasePool:
     """
 
     _TXN_ID = 0
+    engine: BaseDatabaseEngine
 
     def __init__(
         self,
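The two storage hunks above (db_pool on SQLBaseStore and engine on DatabasePool) only add class-scope annotations for attributes that are assigned in __init__; this lets a type checker such as mypy resolve their types on any subclass or mixin. A tiny illustration of the pattern, using made-up classes rather than Synapse's:

# Made-up classes illustrating class-scope instance-attribute annotations; a
# type checker then knows the attribute's type in every subclass without
# tracing each __init__.
class Engine:
    def name(self) -> str:
        return "sqlite"


class Pool:
    engine: Engine  # declared here, assigned in __init__

    def __init__(self, engine: Engine) -> None:
        self.engine = engine


class WorkerStore(Pool):
    def describe(self) -> str:
        # The class-level annotation on Pool is what tells the type checker
        # that self.engine is an Engine here.
        return self.engine.name()


print(WorkerStore(Engine()).describe())  # -> "sqlite"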
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..85c1778a81 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -100,6 +100,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 ("device_lists_outbound_pokes", "stream_id"),
                 ("device_lists_changes_in_room", "stream_id"),
                 ("device_lists_remote_pending", "stream_id"),
+                ("device_lists_changes_converted_stream_position", "stream_id"),
             ],
             is_writer=hs.config.worker.worker_app is None,
         )
@@ -745,42 +746,45 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @trace
     @cancellable
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, Optional[str]]]
+        self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
-            query_list: List of (user_id, device_ids), if device_ids is
-                falsey then return all device ids for that user.
+            user_ids: users for which all device IDs should be returned
+            user_and_device_ids: List of (user_id, device_id) pairs to return specific devices for
 
         Returns:
             A tuple of (user_ids_not_in_cache, results_map), where
             user_ids_not_in_cache is a set of user_ids and results_map is a
             mapping of user_id -> device_id -> device_info.
         """
-        user_ids = {user_id for user_id, _ in query_list}
-        user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+        user_map = await self.get_device_list_last_stream_id_for_remotes(
+            list(unique_user_ids)
+        )
 
         # We go and check if any of the users need to have their device lists
         # resynced. If they do then we remove them from the cached list.
         users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
-            user_ids
+            unique_user_ids
         )
         user_ids_in_cache = {
             user_id for user_id, stream_id in user_map.items() if stream_id
         } - users_needing_resync
-        user_ids_not_in_cache = user_ids - user_ids_in_cache
+        user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
+        # First fetch the users for which all devices are to be returned.
         results: Dict[str, Dict[str, JsonDict]] = {}
-        for user_id, device_id in query_list:
-            if user_id not in user_ids_in_cache:
-                continue
-
-            if device_id:
+        for user_id in user_ids:
+            if user_id in user_ids_in_cache:
+                results[user_id] = await self.get_cached_devices_for_user(user_id)
+        # Then fetch all device-specific requests, but skip users we've already
+        # fetched all devices for.
+        for user_id, device_id in user_and_device_ids:
+            if user_id in user_ids_in_cache and user_id not in user_ids:
                 device = await self._get_cached_user_device(user_id, device_id)
                 results.setdefault(user_id, {})[device_id] = device
-            else:
-                results[user_id] = await self.get_cached_devices_for_user(user_id)
 
         set_tag("in_cache", str(results))
         set_tag("not_in_cache", str(user_ids_not_in_cache))
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..cb66376fb4 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -306,7 +306,7 @@ class PersistEventsStore:
 
         # The set of event_ids to return. This includes all soft-failed events
         # and their prev events.
-        existing_prevs = set()
+        existing_prevs: Set[str] = set()
 
         def _get_prevs_before_rejected_txn(
             txn: LoggingTransaction, batch: Collection[str]
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3acdb39da7..6c335a9315 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
 from synapse.storage.types import Cursor
 
@@ -108,9 +108,14 @@ def prepare_database(
         # so we start one before running anything. This ensures that any upgrades
         # are either applied completely, or not at all.
         #
-        # (psycopg2 automatically starts a transaction as soon as we run any statements
-        # at all, so this is redundant but harmless there.)
-        cur.execute("BEGIN TRANSACTION")
+        # psycopg2 does not automatically start transactions when in autocommit mode.
+        # While it is technically harmless to nest transactions in Postgres, doing so
+        # results in a warning in Postgres' logs for every query, and we'd rather
+        # avoid that.
+        if isinstance(database_engine, Sqlite3Engine) or (
+            isinstance(database_engine, PostgresEngine) and db_conn.autocommit
+        ):
+            cur.execute("BEGIN TRANSACTION")
 
         logger.info("%r: Checking existing schema version", databases)
         version_info = _get_or_create_schema_state(cur, database_engine)
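The prepare_database hunk makes the explicit BEGIN TRANSACTION conditional: always for SQLite, and for Postgres only when the connection is in autocommit mode (where psycopg2 will not have opened an implicit transaction). A standalone sketch of the same decision, with stand-in engine classes rather than Synapse's real ones:

# Stand-in engine classes; only the decision logic mirrors the hunk above and
# nothing here touches a real database.
class BaseEngine:
    pass


class Sqlite3Engine(BaseEngine):
    pass


class PostgresEngine(BaseEngine):
    pass


def needs_explicit_begin(engine: BaseEngine, autocommit: bool) -> bool:
    # The SQLite connection is assumed to be in autocommit mode, so we always
    # open the transaction ourselves.  psycopg2 normally begins one implicitly
    # on the first statement, so we only issue BEGIN when autocommit has
    # disabled that, avoiding nested-transaction warnings in Postgres' logs.
    if isinstance(engine, Sqlite3Engine):
        return True
    return isinstance(engine, PostgresEngine) and autocommit


print(needs_explicit_begin(Sqlite3Engine(), autocommit=False))   # True
print(needs_explicit_begin(PostgresEngine(), autocommit=False))  # False
print(needs_explicit_begin(PostgresEngine(), autocommit=True))   # True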