summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_event_chain.py6
-rw-r--r--tests/storage/test_state.py126
2 files changed, 130 insertions, 2 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 73d11e7786..e39b63edac 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_prev_events_for_room(room_id)
         )
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_handler.create_event(
                 self.requester,
                 {
@@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 prev_event_ids=latest_event_ids,
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
@@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         assert state_ids1 is not None
         state1 = set(state_ids1.values())
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_handler.create_event(
                 self.requester,
                 {
@@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 prev_event_ids=latest_event_ids,
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index e82c03f597..62aed6af0a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+    def test_batched_state_group_storing(self) -> None:
+        creation_event = self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Create, "", {}
+        )
+        state_to_event = self.get_success(
+            self.storage.state.get_state_groups(
+                self.room.to_string(), [creation_event.event_id]
+            )
+        )
+        current_state_group = list(state_to_event.keys())[0]
+
+        # create some unpersisted events and event contexts to store against room
+        events_and_context = []
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Name,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"name": "first rename of room"},
+            },
+        )
+
+        event1, unpersisted_context1 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+        events_and_context.append((event1, unpersisted_context1))
+
+        builder2 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.JoinRules,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"join_rule": "private"},
+            },
+        )
+
+        event2, unpersisted_context2 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder2)
+        )
+        events_and_context.append((event2, unpersisted_context2))
+
+        builder3 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Message,
+                "sender": self.u_alice.to_string(),
+                "room_id": self.room.to_string(),
+                "content": {"body": "hello from event 3", "msgtype": "m.text"},
+            },
+        )
+
+        event3, unpersisted_context3 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder3)
+        )
+        events_and_context.append((event3, unpersisted_context3))
+
+        builder4 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.JoinRules,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"join_rule": "public"},
+            },
+        )
+
+        event4, unpersisted_context4 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder4)
+        )
+        events_and_context.append((event4, unpersisted_context4))
+
+        processed_events_and_context = self.get_success(
+            self.hs.get_datastores().state.store_state_deltas_for_batched(
+                events_and_context, self.room.to_string(), current_state_group
+            )
+        )
+
+        # check that only state events are in state_groups, and all state events are in state_groups
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="state_groups",
+                keyvalues=None,
+                retcols=("event_id",),
+            )
+        )
+
+        events = []
+        for result in res:
+            self.assertNotIn(event3.event_id, result)
+            events.append(result.get("event_id"))
+
+        for event, _ in processed_events_and_context:
+            if event.is_state():
+                self.assertIn(event.event_id, events)
+
+        # check that each unique state has state group in state_groups_state and that the
+        # type/state key is correct, and check that each state event's state group
+        # has an entry and prev event in state_group_edges
+        for event, context in processed_events_and_context:
+            if event.is_state():
+                state = self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="state_groups_state",
+                        keyvalues={"state_group": context.state_group_after_event},
+                        retcols=("type", "state_key"),
+                    )
+                )
+                self.assertEqual(event.type, state[0].get("type"))
+                self.assertEqual(event.state_key, state[0].get("state_key"))
+
+                groups = self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="state_group_edges",
+                        keyvalues={"state_group": str(context.state_group_after_event)},
+                        retcols=("*",),
+                    )
+                )
+                self.assertEqual(
+                    context.state_group_before_event, groups[0].get("prev_state_group")
+                )