summary refs log tree commit diff
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2023-02-09 13:05:02 -0800
committerGitHub <noreply@github.com>2023-02-09 13:05:02 -0800
commit03bccd542bcffe3ea12cd35108740a7d62dd38ab (patch)
treebadcff6446d1230bccd7e623c4c08f8ccbef780e
parentDo not always start a db txn on Postgres (#14840) (diff)
downloadsynapse-03bccd542bcffe3ea12cd35108740a7d62dd38ab.tar.xz
Add a class UnpersistedEventContext to allow for the batching up of storing state groups (#14675)
* add class UnpersistedEventContext

* modify create new client event to create unpersistedeventcontexts

* persist event contexts after creation

* fix tests to persist unpersisted event contexts

* cleanup

* misc lints + cleanup

* changelog + fix comments

* lints

* fix batch insertion?

* reduce redundant calculation

* add unpersisted event classes

* rework compute_event_context, split into function that returns unpersisted event context and then persists it

* use calculate_context_info to create unpersisted event contexts

* update typing

* $%#^&*

* black

* fix comments and consolidate classes, use attr.s for class

* requested changes

* lint

* requested changes

* requested changes

* refactor to be stupidly explicit

* clearer renaming and flow

* make partial state non-optional

* update docstrings

---------

Co-authored-by: Erik Johnston <erik@matrix.org>
Diffstat (limited to '')
-rw-r--r--changelog.d/14675.misc1
-rw-r--r--synapse/events/snapshot.py174
-rw-r--r--synapse/events/third_party_rules.py6
-rw-r--r--synapse/handlers/federation.py59
-rw-r--r--synapse/handlers/federation_event.py6
-rw-r--r--synapse/handlers/message.py42
-rw-r--r--synapse/state/__init__.py176
-rw-r--r--tests/handlers/test_user_directory.py4
-rw-r--r--tests/rest/admin/test_user.py4
-rw-r--r--tests/storage/test_redaction.py24
-rw-r--r--tests/storage/test_state.py4
-rw-r--r--tests/test_utils/event_injection.py7
-rw-r--r--tests/test_visibility.py9
-rw-r--r--tests/utils.py5
14 files changed, 359 insertions, 162 deletions
diff --git a/changelog.d/14675.misc b/changelog.d/14675.misc
new file mode 100644
index 0000000000..bc1ac1c82a
--- /dev/null
+++ b/changelog.d/14675.misc
@@ -0,0 +1 @@
+Add a class UnpersistedEventContext to allow for the batching up of storing state groups.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6eaef8b57a..e0d82ad81c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
@@ -26,8 +27,51 @@ if TYPE_CHECKING:
     from synapse.types.state import StateFilter
 
 
+class UnpersistedEventContextBase(ABC):
+    """
+    This is a base class for EventContext and UnpersistedEventContext, objects which
+    hold information relevant to storing an associated event. Note that an
+    UnpersistedEventContexts must be converted into an EventContext before it is
+    suitable to send to the db with its associated event.
+
+    Attributes:
+        _storage: storage controllers for interfacing with the database
+        app_service: If the associated event is being sent by a (local) application service, that
+            app service.
+    """
+
+    def __init__(self, storage_controller: "StorageControllers"):
+        self._storage: "StorageControllers" = storage_controller
+        self.app_service: Optional[ApplicationService] = None
+
+    @abstractmethod
+    async def persist(
+        self,
+        event: EventBase,
+    ) -> "EventContext":
+        """
+        A method to convert an UnpersistedEventContext to an EventContext, suitable for
+        sending to the database with the associated event.
+        """
+        pass
+
+    @abstractmethod
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
+        """
+        Gets the room state at the event (ie not including the event if the event is a
+        state event).
+
+        Args:
+            state_filter: specifies the type of state event to fetch from DB, example:
+            EventTypes.JoinRules
+        """
+        pass
+
+
 @attr.s(slots=True, auto_attribs=True)
-class EventContext:
+class EventContext(UnpersistedEventContextBase):
     """
     Holds information relevant to persisting an event
 
@@ -77,9 +121,6 @@ class EventContext:
         delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
             and ``state_group``.
 
-        app_service: If this event is being sent by a (local) application service, that
-            app service.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
@@ -122,6 +163,9 @@ class EventContext:
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(storage=storage)
 
+    async def persist(self, event: EventBase) -> "EventContext":
+        return self
+
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
@@ -254,6 +298,128 @@ class EventContext:
         )
 
 
+@attr.s(slots=True, auto_attribs=True)
+class UnpersistedEventContext(UnpersistedEventContextBase):
+    """
+    The event context holds information about the state groups for an event. It is important
+    to remember that an event technically has two state groups: the state group before the
+    event, and the state group after the event. If the event is not a state event, the state
+    group will not change (ie the state group before the event will be the same as the state
+    group after the event), but if it is a state event the state group before the event
+    will differ from the state group after the event.
+    This is a version of an EventContext before the new state group (if any) has been
+    computed and stored. It contains information about the state before the event (which
+    also may be the information after the event, if the event is not a state event). The
+    UnpersistedEventContext must be converted into an EventContext by calling the method
+    'persist' on it before it is suitable to be sent to the DB for processing.
+
+        state_group_after_event:
+             The state group after the event. This will always be None until it is persisted.
+             If the event is not a state event, this will be the same as
+             state_group_before_event.
+
+        state_group_before_event:
+            The ID of the state group representing the state of the room before this event.
+
+        state_delta_due_to_event:
+            If the event is a state event, then this is the delta of the state between
+             `state_group` and `state_group_before_event`
+
+        prev_group_for_state_group_before_event:
+            If it is known, ``state_group_before_event``'s previous state group.
+
+        delta_ids_to_state_group_before_event:
+             If ``prev_group_for_state_group_before_event`` is not None, the state delta
+             between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``.
+
+        partial_state:
+            Whether the event has partial state.
+
+        state_map_before_event:
+            A map of the state before the event, i.e. the state at `state_group_before_event`
+    """
+
+    _storage: "StorageControllers"
+    state_group_before_event: Optional[int]
+    state_group_after_event: Optional[int]
+    state_delta_due_to_event: Optional[dict]
+    prev_group_for_state_group_before_event: Optional[int]
+    delta_ids_to_state_group_before_event: Optional[StateMap[str]]
+    partial_state: bool
+    state_map_before_event: Optional[StateMap[str]] = None
+
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
+        """
+        Gets the room state map, excluding this event.
+
+        Args:
+            state_filter: specifies the type of state event to fetch from DB
+
+        Returns:
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
+        """
+        if self.state_map_before_event:
+            return self.state_map_before_event
+
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event, state_filter
+        )
+
+    async def persist(self, event: EventBase) -> EventContext:
+        """
+        Creates a full `EventContext` for the event, persisting any referenced state that
+        has not yet been persisted.
+
+        Args:
+             event: event that the EventContext is associated with.
+
+        Returns: An EventContext suitable for sending to the database with the event
+        for persisting
+        """
+        assert self.partial_state is not None
+
+        # If we have a full set of state for before the event but don't have a state
+        # group for that state, we need to get one
+        if self.state_group_before_event is None:
+            assert self.state_map_before_event
+            state_group_before_event = await self._storage.state.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=self.prev_group_for_state_group_before_event,
+                delta_ids=self.delta_ids_to_state_group_before_event,
+                current_state_ids=self.state_map_before_event,
+            )
+            self.state_group_before_event = state_group_before_event
+
+        # if the event isn't a state event the state group doesn't change
+        if not self.state_delta_due_to_event:
+            state_group_after_event = self.state_group_before_event
+
+        # otherwise if it is a state event we need to get a state group for it
+        else:
+            state_group_after_event = await self._storage.state.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=self.state_group_before_event,
+                delta_ids=self.state_delta_due_to_event,
+                current_state_ids=None,
+            )
+
+        return EventContext.with_state(
+            storage=self._storage,
+            state_group=state_group_after_event,
+            state_group_before_event=self.state_group_before_event,
+            state_delta_due_to_event=self.state_delta_due_to_event,
+            partial_state=self.partial_state,
+            prev_group=self.state_group_before_event,
+            delta_ids=self.state_delta_due_to_event,
+        )
+
+
 def _encode_state_dict(
     state_dict: Optional[StateMap[str]],
 ) -> Optional[List[Tuple[str, str, str]]]:
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..97c61cc258 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
 
 from synapse.api.errors import ModuleFailedException, SynapseError
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import Requester, StateMap
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ class ThirdPartyEventRules:
             self._on_threepid_bind_callbacks.append(on_threepid_bind)
 
     async def check_event_allowed(
-        self, event: EventBase, context: EventContext
+        self,
+        event: EventBase,
+        context: UnpersistedEventContextBase,
     ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7f64130e0a..43ed4a3dd1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -56,7 +56,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
@@ -990,7 +990,10 @@ class FederationHandler:
         )
 
         try:
-            event, context = await self.event_creation_handler.create_new_client_event(
+            (
+                event,
+                unpersisted_context,
+            ) = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
         except SynapseError as e:
@@ -998,7 +1001,9 @@ class FederationHandler:
             raise
 
         # Ensure the user can even join the room.
-        await self._federation_event_handler.check_join_restrictions(context, event)
+        await self._federation_event_handler.check_join_restrictions(
+            unpersisted_context, event
+        )
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
@@ -1178,7 +1183,7 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
+        event, _ = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
@@ -1228,12 +1233,13 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
 
         event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
-            event, context
+            event, unpersisted_context
         )
         if not event_allowed:
             logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1412,20 @@ class FederationHandler:
                 try:
                     (
                         event,
-                        context,
+                        unpersisted_context,
                     ) = await self.event_creation_handler.create_new_client_event(
                         builder=builder
                     )
 
-                    event, context = await self.add_display_name_to_third_party_invite(
-                        room_version_obj, event_dict, event, context
+                    (
+                        event,
+                        unpersisted_context,
+                    ) = await self.add_display_name_to_third_party_invite(
+                        room_version_obj, event_dict, event, unpersisted_context
                     )
 
+                    context = await unpersisted_context.persist(event)
+
                     EventValidator().validate_new(event, self.config)
 
                     # We need to tell the transaction queue to send this out, even
@@ -1483,14 +1494,19 @@ class FederationHandler:
             try:
                 (
                     event,
-                    context,
+                    unpersisted_context,
                 ) = await self.event_creation_handler.create_new_client_event(
                     builder=builder
                 )
-                event, context = await self.add_display_name_to_third_party_invite(
-                    room_version_obj, event_dict, event, context
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.add_display_name_to_third_party_invite(
+                    room_version_obj, event_dict, event, unpersisted_context
                 )
 
+                context = await unpersisted_context.persist(event)
+
                 try:
                     validate_event_for_room_version(event)
                     await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1538,8 @@ class FederationHandler:
         room_version_obj: RoomVersion,
         event_dict: JsonDict,
         event: EventBase,
-        context: EventContext,
-    ) -> Tuple[EventBase, EventContext]:
+        context: UnpersistedEventContextBase,
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1573,14 @@ class FederationHandler:
             room_version_obj, event_dict
         )
         EventValidator().validate_builder(builder)
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
         EventValidator().validate_new(event, self.config)
-        return event, context
+        return event, unpersisted_context
 
     async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e037acbca2..3561f2f1de 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -58,7 +58,7 @@ from synapse.event_auth import (
     validate_event_for_room_version,
 )
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
 from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import (
@@ -426,7 +426,9 @@ class FederationEventHandler:
         return event, context
 
     async def check_join_restrictions(
-        self, context: EventContext, event: EventBase
+        self,
+        context: UnpersistedEventContextBase,
+        event: EventBase,
     ) -> None:
         """Check that restrictions in restricted join rules are matched
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5f6da2943f..3e30f52e4d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -48,7 +48,7 @@ from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
@@ -708,7 +708,7 @@ class EventCreationHandler:
 
         builder.internal_metadata.historical = historical
 
-        event, context = await self.create_new_client_event(
+        event, unpersisted_context = await self.create_new_client_event(
             builder=builder,
             requester=requester,
             allow_no_prev_events=allow_no_prev_events,
@@ -721,6 +721,8 @@ class EventCreationHandler:
             current_state_group=current_state_group,
         )
 
+        context = await unpersisted_context.persist(event)
+
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
         # behaviour. Another reason is that this code is also evaluated each time a new
@@ -1083,13 +1085,14 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event for a local client. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
         the event using the parameters state_map and current_state_group, thus these parameters
         must be provided in this case if for_batch is True. The subsequently created event
         and context are suitable for being batched up and bulk persisted to the database
-        with other similarly created events.
+        with other similarly created events. Note that this returns an UnpersistedEventContext,
+        which must be converted to an EventContext before it can be sent to the DB.
 
         Args:
             builder:
@@ -1131,7 +1134,7 @@ class EventCreationHandler:
                 batch persisting
 
         Returns:
-            Tuple of created event, context
+            Tuple of created event, UnpersistedEventContext
         """
         # Strip down the state_event_ids to only what we need to auth the event.
         # For example, we don't need extra m.room.member that don't match event.sender
@@ -1192,9 +1195,16 @@ class EventCreationHandler:
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
             )
-            context = await self.state.compute_event_context_for_batched(
-                event, state_map, current_state_group
+
+            context: UnpersistedEventContextBase = (
+                await self.state.calculate_context_info(
+                    event,
+                    state_ids_before_event=state_map,
+                    partial_state=False,
+                    state_group_before_event=current_state_group,
+                )
             )
+
         else:
             event = await builder.build(
                 prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
 
                     state_map_for_event[(data.event_type, data.state_key)] = state_id
 
-                context = await self.state.compute_event_context(
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                context = await self.state.calculate_context_info(
                     event,
                     state_ids_before_event=state_map_for_event,
-                    # TODO(faster_joins): check how MSC2716 works and whether we can have
-                    #   partial state here
-                    #   https://github.com/matrix-org/synapse/issues/13003
                     partial_state=False,
                 )
+
             else:
-                context = await self.state.compute_event_context(event)
+                context = await self.state.calculate_context_info(event)
 
         if requester:
             context.app_service = requester.app_service
@@ -2082,9 +2093,9 @@ class EventCreationHandler:
 
     async def _rebuild_event_after_third_party_rules(
         self, third_party_result: dict, original_event: EventBase
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         # the third_party_event_rules want to replace the event.
-        # we do some basic checks, and then return the replacement event and context.
+        # we do some basic checks, and then return the replacement event.
 
         # Construct a new EventBuilder and validate it, which helps with the
         # rest of these checks.
@@ -2138,5 +2149,6 @@ class EventCreationHandler:
 
         # we rebuild the event context, to be on the safe side. If nothing else,
         # delta_ids might need an update.
-        context = await self.state.compute_event_context(event)
+        context = await self.state.calculate_context_info(event)
+
         return event, context
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..e877e6f1a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+    EventContext,
+    UnpersistedEventContext,
+    UnpersistedEventContextBase,
+)
 from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
@@ -262,31 +266,31 @@ class StateHandler:
         state = await entry.get_state(self._state_storage_controller, StateFilter.all())
         return await self.store.get_joined_hosts(room_id, state, entry)
 
-    async def compute_event_context(
+    async def calculate_context_info(
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: Optional[bool] = None,
-    ) -> EventContext:
-        """Build an EventContext structure for a non-outlier event.
-
-        (for an outlier, call EventContext.for_outlier directly)
-
-        This works out what the current state should be for the event, and
-        generates a new state group if necessary.
-
-        Args:
-            event:
-            state_ids_before_event: The event ids of the state before the event if
-                it can't be calculated from existing events. This is normally
-                only specified when receiving an event from federation where we
-                don't have the prev events, e.g. when backfilling.
-            partial_state:
-                `True` if `state_ids_before_event` is partial and omits non-critical
-                membership events.
-                `False` if `state_ids_before_event` is the full state.
-                `None` when `state_ids_before_event` is not provided. In this case, the
-                flag will be calculated based on `event`'s prev events.
+        state_group_before_event: Optional[int] = None,
+    ) -> UnpersistedEventContextBase:
+        """
+        Calulates the contents of an unpersisted event context, other than the current
+        state group (which is either provided or calculated when the event context is persisted)
+
+        state_ids_before_event:
+            The event ids of the full state before the event if
+            it can't be calculated from existing events. This is normally
+            only specified when receiving an event from federation where we
+            don't have the prev events, e.g. when backfilling or when the event
+            is being created for batch persisting.
+        partial_state:
+            `True` if `state_ids_before_event` is partial and omits non-critical
+            membership events.
+            `False` if `state_ids_before_event` is the full state.
+            `None` when `state_ids_before_event` is not provided. In this case, the
+            flag will be calculated based on `event`'s prev events.
+        state_group_before_event:
+            the current state group at the time of event, if known
         Returns:
             The event context.
 
@@ -294,7 +298,6 @@ class StateHandler:
             RuntimeError if `state_ids_before_event` is not provided and one or more
                 prev events are missing or outliers.
         """
-
         assert not event.internal_metadata.is_outlier()
 
         #
@@ -306,17 +309,6 @@ class StateHandler:
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
 
-            # .. though we need to get a state group for it.
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=None,
-                    delta_ids=None,
-                    current_state_ids=state_ids_before_event,
-                )
-            )
-
             # the partial_state flag must be provided
             assert partial_state is not None
         else:
@@ -345,6 +337,7 @@ class StateHandler:
             logger.debug("calling resolve_state_groups from compute_event_context")
             # we've already taken into account partial state, so no need to wait for
             # complete state here.
+
             entry = await self.resolve_state_groups_for_events(
                 event.room_id,
                 event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
         #
 
         if not event.is_state():
-            return EventContext.with_state(
+            return UnpersistedEventContext(
                 storage=self._storage_controllers,
                 state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
+                state_group_after_event=state_group_before_event,
                 state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
+                prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+                delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
                 partial_state=partial_state,
+                state_map_before_event=state_ids_before_event,
             )
 
         #
-        # otherwise, we'll need to create a new state group for after the event
+        # otherwise, we'll need to set up creating a new state group for after the event
         #
 
         key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
 
         delta_ids = {key: event.event_id}
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
-        )
-
-        return EventContext.with_state(
+        return UnpersistedEventContext(
             storage=self._storage_controllers,
-            state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
+            state_group_after_event=None,
             state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
+            prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+            delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
             partial_state=partial_state,
+            state_map_before_event=state_ids_before_event,
         )
 
-    async def compute_event_context_for_batched(
+    async def compute_event_context(
         self,
         event: EventBase,
-        state_ids_before_event: StateMap[str],
-        current_state_group: int,
+        state_ids_before_event: Optional[StateMap[str]] = None,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
-        """
-        Generate an event context for an event that has not yet been persisted to the
-        database. Intended for use with events that are created to be persisted in a batch.
-        Args:
-            event: the event the context is being computed for
-            state_ids_before_event: a state map consisting of the state ids of the events
-            created prior to this event.
-            current_state_group: the current state group before the event.
-        """
-        state_group_before_event_prev_group = None
-        deltas_to_state_group_before_event = None
-
-        state_group_before_event = current_state_group
-
-        # if the event is not state, we are set
-        if not event.is_state():
-            return EventContext.with_state(
-                storage=self._storage_controllers,
-                state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
-                state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
-                partial_state=False,
-            )
+        """Build an EventContext structure for a non-outlier event.
 
-        # otherwise, we'll need to create a new state group for after the event
-        key = (event.type, event.state_key)
+        (for an outlier, call EventContext.for_outlier directly)
 
-        if state_ids_before_event is not None:
-            replaces = state_ids_before_event.get(key)
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
 
-        if replaces and replaces != event.event_id:
-            event.unsigned["replaces_state"] = replaces
+        Args:
+            event:
+            state_ids_before_event: The event ids of the state before the event if
+                it can't be calculated from existing events. This is normally
+                only specified when receiving an event from federation where we
+                don't have the prev events, e.g. when backfilling.
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
+            entry:
+                A state cache entry for the resolved state across the prev events. We may
+                have already calculated this, so if it's available pass it in
+        Returns:
+            The event context.
 
-        delta_ids = {key: event.event_id}
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
+        """
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
+        unpersisted_context = await self.calculate_context_info(
+            event=event,
+            state_ids_before_event=state_ids_before_event,
+            partial_state=partial_state,
         )
 
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group_after_event,
-            state_group_before_event=state_group_before_event,
-            state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
-            partial_state=False,
-        )
+        return await unpersisted_context.persist(event)
 
     @measure_func()
     async def resolve_state_groups_for_events(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..e9be5fb504 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(
             self.hs.get_storage_controllers().persistence.persist_event(event, context)
         )
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..b50406e129 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(storage_controllers.persistence.persist_event(event, context))
 
         # Now get rooms
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             def internal_metadata(self) -> _EventInternalMetadata:
                 return self._base_builder.internal_metadata
 
-        event_1, context_1 = self.get_success(
+        event_1, unpersisted_context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
+        context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
         self.get_success(self._persistence.persist_event(event_1, context_1))
 
-        event_2, context_2 = self.get_success(
+        event_2, unpersisted_context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
+
+        context_2 = self.get_success(unpersisted_context_2.persist(event_2))
         self.get_success(self._persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        redaction_event, context = self.get_success(
+        redaction_event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(redaction_event))
+
         self.get_success(self._persistence.persist_event(redaction_event, context))
 
         # Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         assert self.storage.persistence is not None
         self.get_success(self.storage.persistence.persist_event(event, context))
 
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 1a50c2acf1..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -92,8 +92,13 @@ async def create_event(
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
-    event, context = await hs.get_event_creation_handler().create_new_client_event(
+    (
+        event,
+        unpersisted_context,
+    ) = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
 
+    context = await unpersisted_context.persist(event)
+
     return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 875e37988f..36d6b37aa4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             self._storage_controllers.persistence.persist_event(event, context)
         )
@@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         self.get_success(
             self._storage_controllers.persistence.persist_event(event, context)
@@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         self.get_success(
             self._storage_controllers.persistence.persist_event(event, context)
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..15fabbc2d0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
         },
     )
 
-    event, context = await event_creation_handler.create_new_client_event(builder)
+    event, unpersisted_context = await event_creation_handler.create_new_client_event(
+        builder
+    )
+    context = await unpersisted_context.persist(event)
 
     await persistence_store.persist_event(event, context)