diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 83c377824b..ff67a73749 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -20,7 +20,10 @@ from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
@@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase):
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def test_background_update(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ user_id = self.register_user("foo", "pass")
+ token = self.login("foo", "pass")
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ requester = create_requester(user_id)
+
+ store = self.hs.get_datastore()
+
+ # Mark the room as not having a chain cover index
+ self.get_success(
+ store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ desc="test",
+ )
+ )
+
+ # Create a fork in the DAG with different events.
+ event_handler = self.hs.get_event_creation_handler()
+ latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+ event, context = self.get_success(
+ event_handler.create_event(
+ requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(requester, event, context)
+ )
+ state1 = list(self.get_success(context.get_current_state_ids()).values())
+
+ event, context = self.get_success(
+ event_handler.create_event(
+ requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(requester, event, context)
+ )
+ state2 = list(self.get_success(context.get_current_state_ids()).values())
+
+ # Delete the chain cover info.
+
+ def _delete_tables(txn):
+ txn.execute("DELETE FROM event_auth_chains")
+ txn.execute("DELETE FROM event_auth_chain_links")
+
+ self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+
+ # Insert and run the background update.
+ self.get_success(
+ store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ store.db_pool.runInteraction(
+ "test",
+ store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ [state1, state2],
+ )
+ )
|