summary refs log tree commit diff
path: root/tests/storage/test_event_chain.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_event_chain.py')
-rw-r--r--tests/storage/test_event_chain.py217
1 files changed, 186 insertions, 31 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index ff67a73749..0c46ad595b 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
 
 from twisted.trial import unittest
 
@@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def test_background_update(self):
-        """Test that the background update to calculate auth chains for historic
-        rooms works correctly.
-        """
-
-        # Create a room
-        user_id = self.register_user("foo", "pass")
-        token = self.login("foo", "pass")
-        room_id = self.helper.create_room_as(user_id, tok=token)
-        requester = create_requester(user_id)
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.user_id = self.register_user("foo", "pass")
+        self.token = self.login("foo", "pass")
+        self.requester = create_requester(self.user_id)
 
-        store = self.hs.get_datastore()
+    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+        """Insert a room without a chain cover index.
+        """
+        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
         # Mark the room as not having a chain cover index
         self.get_success(
-            store.db_pool.simple_update(
+            self.store.db_pool.simple_update(
                 table="rooms",
                 keyvalues={"room_id": room_id},
                 updatevalues={"has_auth_chain_index": False},
@@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
 
         # Create a fork in the DAG with different events.
         event_handler = self.hs.get_event_creation_handler()
-        latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+        latest_event_ids = self.get_success(
+            self.store.get_prev_events_for_room(room_id)
+        )
         event, context = self.get_success(
             event_handler.create_event(
-                requester,
+                self.requester,
                 {
                     "type": "some_state_type",
                     "state_key": "",
                     "content": {},
                     "room_id": room_id,
-                    "sender": user_id,
+                    "sender": self.user_id,
                 },
                 prev_event_ids=latest_event_ids,
             )
         )
         self.get_success(
-            event_handler.handle_new_client_event(requester, event, context)
+            event_handler.handle_new_client_event(self.requester, event, context)
         )
-        state1 = list(self.get_success(context.get_current_state_ids()).values())
+        state1 = set(self.get_success(context.get_current_state_ids()).values())
 
         event, context = self.get_success(
             event_handler.create_event(
-                requester,
+                self.requester,
                 {
                     "type": "some_state_type",
                     "state_key": "",
                     "content": {},
                     "room_id": room_id,
-                    "sender": user_id,
+                    "sender": self.user_id,
                 },
                 prev_event_ids=latest_event_ids,
             )
         )
         self.get_success(
-            event_handler.handle_new_client_event(requester, event, context)
+            event_handler.handle_new_client_event(self.requester, event, context)
         )
-        state2 = list(self.get_success(context.get_current_state_ids()).values())
+        state2 = set(self.get_success(context.get_current_state_ids()).values())
 
         # Delete the chain cover info.
 
@@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
             txn.execute("DELETE FROM event_auth_chains")
             txn.execute("DELETE FROM event_auth_chain_links")
 
-        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+        return room_id, [state1, state2]
+
+    def test_background_update_single_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
 
         # Insert and run the background update.
         self.get_success(
-            store.db_pool.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "chain_cover", "progress_json": "{}"},
             )
         )
 
         # Ugh, have to reset this flag
-        store.db_pool.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            store.db_pool.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Test that the `has_auth_chain_index` has been set
-        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
 
         # Test that calculating the auth chain difference using the newly
         # calculated chain cover works.
         self.get_success(
-            store.db_pool.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test",
-                store._get_auth_chain_difference_using_cover_index_txn,
+                self.store._get_auth_chain_difference_using_cover_index_txn,
                 room_id,
-                [state1, state2],
+                states,
+            )
+        )
+
+    def test_background_update_multiple_rooms(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create a room
+        room_id1, states1 = self._generate_room()
+        room_id2, states2 = self._generate_room()
+        room_id3, states2 = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id1,
+                states1,
             )
         )
+
+    def test_background_update_single_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create the rooms
+        room_id1, _ = self._generate_room()
+        room_id2, _ = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id1, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id2, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))