summary refs log tree commit diff
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2023-02-24 13:15:29 -0800
committerGitHub <noreply@github.com>2023-02-24 13:15:29 -0800
commit1c95ddd09bbc46046a3412e7bb03a87aa3b6f65a (patch)
tree7941d82e136933cb271ce0b9cad24c749aa9b435
parentImprove handling of non-ASCII characters in user directory search (#15143) (diff)
downloadsynapse-1c95ddd09bbc46046a3412e7bb03a87aa3b6f65a.tar.xz
Batch up storing state groups when creating new room (#14918)
Diffstat (limited to '')
-rw-r--r--changelog.d/14918.misc1
-rw-r--r--synapse/events/snapshot.py49
-rw-r--r--synapse/handlers/message.py16
-rw-r--r--synapse/handlers/room.py37
-rw-r--r--synapse/handlers/room_batch.py4
-rw-r--r--synapse/handlers/room_member.py13
-rw-r--r--synapse/storage/databases/state/store.py119
-rw-r--r--tests/handlers/test_message.py25
-rw-r--r--tests/handlers/test_register.py3
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py13
-rw-r--r--tests/rest/client/test_rooms.py4
-rw-r--r--tests/storage/test_event_chain.py6
-rw-r--r--tests/storage/test_state.py126
-rw-r--r--tests/unittest.py4
14 files changed, 371 insertions, 49 deletions
diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc
new file mode 100644
index 0000000000..828794354a
--- /dev/null
+++ b/changelog.d/14918.misc
@@ -0,0 +1 @@
+Batch up storing state groups when creating a new room.
\ No newline at end of file
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e0d82ad81c..a91a5d1e3c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
     from synapse.storage.controllers import StorageControllers
+    from synapse.storage.databases import StateGroupDataStore
     from synapse.storage.databases.main import DataStore
     from synapse.types.state import StateFilter
 
@@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
     partial_state: bool
     state_map_before_event: Optional[StateMap[str]] = None
 
+    @classmethod
+    async def batch_persist_unpersisted_contexts(
+        cls,
+        events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
+        room_id: str,
+        last_known_state_group: int,
+        datastore: "StateGroupDataStore",
+    ) -> List[Tuple[EventBase, EventContext]]:
+        """
+        Takes a list of events and their associated unpersisted contexts and persists
+        the unpersisted contexts, returning a list of events and persisted contexts.
+        Note that all the events must be in a linear chain (ie a <- b <- c).
+
+        Args:
+            events_and_context: A list of events and their unpersisted contexts
+            room_id: the room_id for the events
+            last_known_state_group: the last persisted state group
+            datastore: a state datastore
+        """
+        amended_events_and_context = await datastore.store_state_deltas_for_batched(
+            events_and_context, room_id, last_known_state_group
+        )
+
+        events_and_persisted_context = []
+        for event, unpersisted_context in amended_events_and_context:
+            if event.is_state():
+                context = EventContext(
+                    storage=unpersisted_context._storage,
+                    state_group=unpersisted_context.state_group_after_event,
+                    state_group_before_event=unpersisted_context.state_group_before_event,
+                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                    partial_state=unpersisted_context.partial_state,
+                    prev_group=unpersisted_context.state_group_before_event,
+                    delta_ids=unpersisted_context.state_delta_due_to_event,
+                )
+            else:
+                context = EventContext(
+                    storage=unpersisted_context._storage,
+                    state_group=unpersisted_context.state_group_after_event,
+                    state_group_before_event=unpersisted_context.state_group_before_event,
+                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                    partial_state=unpersisted_context.partial_state,
+                    prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+                    delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+                )
+            events_and_persisted_context.append((event, context))
+        return events_and_persisted_context
+
     async def get_prev_state_ids(
         self, state_filter: Optional["StateFilter"] = None
     ) -> StateMap[str]:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index aa90d0000d..e433d6b01f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -574,7 +574,7 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """
         Given a dict from a client, create a new event. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
@@ -721,8 +721,6 @@ class EventCreationHandler:
             current_state_group=current_state_group,
         )
 
-        context = await unpersisted_context.persist(event)
-
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
         # behaviour. Another reason is that this code is also evaluated each time a new
@@ -739,7 +737,7 @@ class EventCreationHandler:
                 assert state_map is not None
                 prev_event_id = state_map.get((EventTypes.Member, event.sender))
             else:
-                prev_state_ids = await context.get_prev_state_ids(
+                prev_state_ids = await unpersisted_context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
                 prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -764,8 +762,7 @@ class EventCreationHandler:
                 )
 
         self.validator.validate_new(event, self.config)
-
-        return event, context
+        return event, unpersisted_context
 
     async def _is_exempt_from_privacy_policy(
         self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1002,7 @@ class EventCreationHandler:
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.create_event(
+                event, unpersisted_context = await self.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1016,6 +1013,7 @@ class EventCreationHandler:
                     historical=historical,
                     depth=depth,
                 )
+                context = await unpersisted_context.persist(event)
 
                 assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
                     event.sender,
@@ -1190,7 +1188,6 @@ class EventCreationHandler:
         if for_batch:
             assert prev_event_ids is not None
             assert state_map is not None
-            assert current_state_group is not None
             auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@@ -2046,7 +2043,7 @@ class EventCreationHandler:
                 max_retries = 5
                 for i in range(max_retries):
                     try:
-                        event, context = await self.create_event(
+                        event, unpersisted_context = await self.create_event(
                             requester,
                             {
                                 "type": EventTypes.Dummy,
@@ -2055,6 +2052,7 @@ class EventCreationHandler:
                                 "sender": user_id,
                             },
                         )
+                        context = await unpersisted_context.persist(event)
 
                         event.internal_metadata.proactively_send = False
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a26ec02284..b1784638f4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
@@ -211,7 +212,7 @@ class RoomCreationHandler:
                 # the required power level to send the tombstone event.
                 (
                     tombstone_event,
-                    tombstone_context,
+                    tombstone_unpersisted_context,
                 ) = await self.event_creation_handler.create_event(
                     requester,
                     {
@@ -225,6 +226,9 @@ class RoomCreationHandler:
                         },
                     },
                 )
+                tombstone_context = await tombstone_unpersisted_context.persist(
+                    tombstone_event
+                )
                 validate_event_for_room_version(tombstone_event)
                 await self._event_auth_handler.check_auth_rules_from_context(
                     tombstone_event
@@ -1092,7 +1096,7 @@ class RoomCreationHandler:
             content: JsonDict,
             for_batch: bool,
             **kwargs: Any,
-        ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+        ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
             """
             Creates an event and associated event context.
             Args:
@@ -1111,20 +1115,23 @@ class RoomCreationHandler:
 
             event_dict = create_event_dict(etype, content, **kwargs)
 
-            new_event, new_context = await self.event_creation_handler.create_event(
+            (
+                new_event,
+                new_unpersisted_context,
+            ) = await self.event_creation_handler.create_event(
                 creator,
                 event_dict,
                 prev_event_ids=prev_event,
                 depth=depth,
                 state_map=state_map,
                 for_batch=for_batch,
-                current_state_group=current_state_group,
             )
+
             depth += 1
             prev_event = [new_event.event_id]
             state_map[(new_event.type, new_event.state_key)] = new_event.event_id
 
-            return new_event, new_context
+            return new_event, new_unpersisted_context
 
         try:
             config = self._presets_dict[preset_config]
@@ -1134,10 +1141,10 @@ class RoomCreationHandler:
             )
 
         creation_content.update({"creator": creator_id})
-        creation_event, creation_context = await create_event(
+        creation_event, unpersisted_creation_context = await create_event(
             EventTypes.Create, creation_content, False
         )
-
+        creation_context = await unpersisted_creation_context.persist(creation_event)
         logger.debug("Sending %s in new room", EventTypes.Member)
         ev = await self.event_creation_handler.handle_new_client_event(
             requester=creator,
@@ -1181,7 +1188,6 @@ class RoomCreationHandler:
             power_event, power_context = await create_event(
                 EventTypes.PowerLevels, pl_content, True
             )
-            current_state_group = power_context._state_group
             events_to_send.append((power_event, power_context))
         else:
             power_level_content: JsonDict = {
@@ -1230,14 +1236,12 @@ class RoomCreationHandler:
                 power_level_content,
                 True,
             )
-            current_state_group = pl_context._state_group
             events_to_send.append((pl_event, pl_context))
 
         if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
             room_alias_event, room_alias_context = await create_event(
                 EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
             )
-            current_state_group = room_alias_context._state_group
             events_to_send.append((room_alias_event, room_alias_context))
 
         if (EventTypes.JoinRules, "") not in initial_state:
@@ -1246,7 +1250,6 @@ class RoomCreationHandler:
                 {"join_rule": config["join_rules"]},
                 True,
             )
-            current_state_group = join_rules_context._state_group
             events_to_send.append((join_rules_event, join_rules_context))
 
         if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@@ -1255,7 +1258,6 @@ class RoomCreationHandler:
                 {"history_visibility": config["history_visibility"]},
                 True,
             )
-            current_state_group = visibility_context._state_group
             events_to_send.append((visibility_event, visibility_context))
 
         if config["guest_can_join"]:
@@ -1265,14 +1267,12 @@ class RoomCreationHandler:
                     {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
                     True,
                 )
-                current_state_group = guest_access_context._state_group
                 events_to_send.append((guest_access_event, guest_access_context))
 
         for (etype, state_key), content in initial_state.items():
             event, context = await create_event(
                 etype, content, True, state_key=state_key
             )
-            current_state_group = context._state_group
             events_to_send.append((event, context))
 
         if config["encrypted"]:
@@ -1284,9 +1284,16 @@ class RoomCreationHandler:
             )
             events_to_send.append((encryption_event, encryption_context))
 
+        datastore = self.hs.get_datastores().state
+        events_and_context = (
+            await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+                events_to_send, room_id, current_state_group, datastore
+            )
+        )
+
         last_event = await self.event_creation_handler.handle_new_client_event(
             creator,
-            events_to_send,
+            events_and_context,
             ignore_shadow_ban=True,
             ratelimit=False,
         )
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 5d4ca0e2d2..bf9df60218 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,7 @@ class RoomBatchHandler:
             # Mark all events as historical
             event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
 
-            event, context = await self.event_creation_handler.create_event(
+            event, unpersisted_context = await self.event_creation_handler.create_event(
                 await self.create_requester_for_user_id_from_app_service(
                     ev["sender"], app_service_requester.app_service
                 ),
@@ -345,7 +345,7 @@ class RoomBatchHandler:
                 historical=True,
                 depth=inherited_depth,
             )
-
+            context = await unpersisted_context.persist(event)
             assert context._state_group
 
             # Normally this is done when persisting the event but we have to
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a965c7ec76..de7476f300 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.event_creation_handler.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.event_creation_handler.create_event(
                     requester,
                     {
                         "type": EventTypes.Member,
@@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     outlier=outlier,
                     historical=historical,
                 )
-
+                context = await unpersisted_context.persist(event)
                 prev_state_ids = await context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
@@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.event_creation_handler.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.event_creation_handler.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                     auth_event_ids=auth_event_ids,
                     outlier=True,
                 )
+                context = await unpersisted_context.persist(event)
                 event.internal_metadata.out_of_band_membership = True
 
                 result_event = (
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 89b1faa6c8..bf4cdfdf29 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
 import attr
 
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
+    async def store_state_deltas_for_batched(
+        self,
+        events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+        room_id: str,
+        prev_group: int,
+    ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+        """Generate and store state deltas for a group of events and contexts created to be
+        batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
+
+        Args:
+            events_and_context: the events to generate and store a state groups for
+            and their associated contexts
+            room_id: the id of the room the events were created for
+            prev_group: the state group of the last event persisted before the batched events
+            were created
+        """
+
+        def insert_deltas_group_txn(
+            txn: LoggingTransaction,
+            events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+            prev_group: int,
+        ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+            """Generate and store state groups for the provided events and contexts.
+
+            Requires that we have the state as a delta from the last persisted state group.
+
+            Returns:
+                A list of state groups
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            num_state_groups = sum(
+                1 for event, _ in events_and_context if event.is_state()
+            )
+
+            state_groups = self._state_group_seq_gen.get_next_mult_txn(
+                txn, num_state_groups
+            )
+
+            sg_before = prev_group
+            state_group_iter = iter(state_groups)
+            for event, context in events_and_context:
+                if not event.is_state():
+                    context.state_group_after_event = sg_before
+                    context.state_group_before_event = sg_before
+                    continue
+
+                sg_after = next(state_group_iter)
+                context.state_group_after_event = sg_after
+                context.state_group_before_event = sg_before
+                context.state_delta_due_to_event = {
+                    (event.type, event.state_key): event.event_id
+                }
+                sg_before = sg_after
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups",
+                keys=("id", "room_id", "event_id"),
+                values=[
+                    (context.state_group_after_event, room_id, event.event_id)
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_group_edges",
+                keys=("state_group", "prev_state_group"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        context.state_group_before_event,
+                    )
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        room_id,
+                        key[0],
+                        key[1],
+                        state_id,
+                    )
+                    for event, context in events_and_context
+                    if context.state_delta_due_to_event is not None
+                    for key, state_id in context.state_delta_due_to_event.items()
+                ],
+            )
+            return events_and_context
+
+        return await self.db_pool.runInteraction(
+            "store_state_deltas_for_batched.insert_deltas_group",
+            insert_deltas_group_txn,
+            events_and_context,
+            prev_group,
+        )
+
     async def store_state_group(
         self,
         event_id: str,
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 69d384442f..9691d66b48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
@@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         return memberEvent, memberEventContext
 
-    def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+    def _create_duplicate_event(
+        self, txn_id: str
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event with the given transaction ID. All events produced
         by this method will be considered duplicates.
         """
@@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         txn_id = "something_suitably_random"
 
-        event1, context = self._create_duplicate_event(txn_id)
+        event1, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event1))
 
         ret_event1 = self.get_success(
             self.handler.handle_new_client_event(
@@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(event1.event_id, ret_event1.event_id)
 
-        event2, context = self._create_duplicate_event(txn_id)
+        event2, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event2))
 
         # We want to test that the deduplication at the persit event end works,
         # so we want to make sure we test with different events.
@@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         # Let's test that calling `persist_event` directly also does the right
         # thing.
-        event3, context = self._create_duplicate_event(txn_id)
+        event3, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event3))
+
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         ret_event3, event_pos3, _ = self.get_success(
@@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         # Let's test that calling `persist_events` directly also does the right
         # thing.
-        event4, context = self._create_duplicate_event(txn_id)
+        event4, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event4))
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         events, _ = self.get_success(
@@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         txn_id = "something_else_suitably_random"
 
         # Create two duplicate events to persist at the same time
-        event1, context1 = self._create_duplicate_event(txn_id)
-        event2, context2 = self._create_duplicate_event(txn_id)
+        event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+        context1 = self.get_success(unpersisted_context1.persist(event1))
+        event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+        context2 = self.get_success(unpersisted_context2.persist(event2))
 
         # Ensure their event IDs are different to start with
         self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1db99b3c00..aff1ec4758 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         # Lower the permissions of the inviter.
         event_creation_handler = self.hs.get_event_creation_handler()
         requester = create_requester(inviter)
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creation_handler.create_event(
                 requester,
                 {
@@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_creation_handler.handle_new_client_event(
                 requester, events_and_context=[(event, context)]
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index dce6899e78..1458076a90 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
 
         # Create a new message event, and try to evaluate it under the dodgy
         # power level event.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 prev_event_ids=[pl_event_id],
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
         # should not raise
@@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         """Ensure that push rules are not calculated when disabled in the config"""
 
         # Create a new message event which should cause a notification.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
         # Mock the method which calculates push rules -- we do this instead of
@@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
     ) -> bool:
         """Returns true iff the `mentions` trigger an event push action."""
         # Create a new message event which should cause a notification.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 },
             )
         )
-
+        context = self.get_success(unpersisted_context.persist(event))
         # Execute the push rule machinery.
         self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
 
@@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
 
         # Create & persist an event to use as the parent of the relation.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             self.event_creation_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4dd763096d..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(33, channel.resource_usage.db_txn_count)
+        self.assertEqual(30, channel.resource_usage.db_txn_count)
 
     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(36, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)
 
     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 73d11e7786..e39b63edac 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_prev_events_for_room(room_id)
         )
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_handler.create_event(
                 self.requester,
                 {
@@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 prev_event_ids=latest_event_ids,
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
@@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         assert state_ids1 is not None
         state1 = set(state_ids1.values())
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_handler.create_event(
                 self.requester,
                 {
@@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 prev_event_ids=latest_event_ids,
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index e82c03f597..62aed6af0a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+    def test_batched_state_group_storing(self) -> None:
+        creation_event = self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Create, "", {}
+        )
+        state_to_event = self.get_success(
+            self.storage.state.get_state_groups(
+                self.room.to_string(), [creation_event.event_id]
+            )
+        )
+        current_state_group = list(state_to_event.keys())[0]
+
+        # create some unpersisted events and event contexts to store against room
+        events_and_context = []
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Name,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"name": "first rename of room"},
+            },
+        )
+
+        event1, unpersisted_context1 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+        events_and_context.append((event1, unpersisted_context1))
+
+        builder2 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.JoinRules,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"join_rule": "private"},
+            },
+        )
+
+        event2, unpersisted_context2 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder2)
+        )
+        events_and_context.append((event2, unpersisted_context2))
+
+        builder3 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Message,
+                "sender": self.u_alice.to_string(),
+                "room_id": self.room.to_string(),
+                "content": {"body": "hello from event 3", "msgtype": "m.text"},
+            },
+        )
+
+        event3, unpersisted_context3 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder3)
+        )
+        events_and_context.append((event3, unpersisted_context3))
+
+        builder4 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.JoinRules,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"join_rule": "public"},
+            },
+        )
+
+        event4, unpersisted_context4 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder4)
+        )
+        events_and_context.append((event4, unpersisted_context4))
+
+        processed_events_and_context = self.get_success(
+            self.hs.get_datastores().state.store_state_deltas_for_batched(
+                events_and_context, self.room.to_string(), current_state_group
+            )
+        )
+
+        # check that only state events are in state_groups, and all state events are in state_groups
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="state_groups",
+                keyvalues=None,
+                retcols=("event_id",),
+            )
+        )
+
+        events = []
+        for result in res:
+            self.assertNotIn(event3.event_id, result)
+            events.append(result.get("event_id"))
+
+        for event, _ in processed_events_and_context:
+            if event.is_state():
+                self.assertIn(event.event_id, events)
+
+        # check that each unique state has state group in state_groups_state and that the
+        # type/state key is correct, and check that each state event's state group
+        # has an entry and prev event in state_group_edges
+        for event, context in processed_events_and_context:
+            if event.is_state():
+                state = self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="state_groups_state",
+                        keyvalues={"state_group": context.state_group_after_event},
+                        retcols=("type", "state_key"),
+                    )
+                )
+                self.assertEqual(event.type, state[0].get("type"))
+                self.assertEqual(event.state_key, state[0].get("state_key"))
+
+                groups = self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="state_group_edges",
+                        keyvalues={"state_group": str(context.state_group_after_event)},
+                        retcols=("*",),
+                    )
+                )
+                self.assertEqual(
+                    context.state_group_before_event, groups[0].get("prev_state_group")
+                )
diff --git a/tests/unittest.py b/tests/unittest.py
index b21e7f1221..f9160faa1d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
         event_creator = self.hs.get_event_creation_handler()
         requester = create_requester(user)
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creator.create_event(
                 requester,
                 {
@@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase):
                 prev_event_ids=prev_event_ids,
             )
         )
-
+        context = self.get_success(unpersisted_context.persist(event))
         if soft_failed:
             event.internal_metadata.soft_failed = True