summary refs log tree commit diff
path: root/tests/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/databases')
-rw-r--r--tests/storage/databases/main/test_events_worker.py139
1 files changed, 138 insertions, 1 deletions
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index a649e8c618..5ae491ff5a 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -12,11 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from contextlib import contextmanager
+from typing import Generator
 
+from twisted.enterprise.adbapi import ConnectionPool
+from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import EventFormatVersions, RoomVersions
 from synapse.logging.context import LoggingContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import (
+    EVENT_QUEUE_THREADS,
+    EventsWorkerStore,
+)
+from synapse.storage.types import Connection
+from synapse.util import Clock
 from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
@@ -144,3 +157,127 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
 
             # We should have fetched the event from the DB
             self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+
+class DatabaseOutageTestCase(unittest.HomeserverTestCase):
+    """Test event fetching during a database outage."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        self.store: EventsWorkerStore = hs.get_datastore()
+
+        self.room_id = f"!room:{hs.hostname}"
+        self.event_ids = [f"event{i}" for i in range(20)]
+
+        self._populate_events()
+
+    def _populate_events(self) -> None:
+        """Ensure that there are test events in the database.
+
+        When testing with the in-memory SQLite database, all the events are lost during
+        the simulated outage.
+
+        To ensure consistency between `room_id`s and `event_id`s before and after the
+        outage, rows are built and inserted manually.
+
+        Upserts are used to handle the non-SQLite case where events are not lost.
+        """
+        self.get_success(
+            self.store.db_pool.simple_upsert(
+                "rooms",
+                {"room_id": self.room_id},
+                {"room_version": RoomVersions.V4.identifier},
+            )
+        )
+
+        self.event_ids = [f"event{i}" for i in range(20)]
+        for idx, event_id in enumerate(self.event_ids):
+            self.get_success(
+                self.store.db_pool.simple_upsert(
+                    "events",
+                    {"event_id": event_id},
+                    {
+                        "event_id": event_id,
+                        "room_id": self.room_id,
+                        "topological_ordering": idx,
+                        "stream_ordering": idx,
+                        "type": "test",
+                        "processed": True,
+                        "outlier": False,
+                    },
+                )
+            )
+            self.get_success(
+                self.store.db_pool.simple_upsert(
+                    "event_json",
+                    {"event_id": event_id},
+                    {
+                        "room_id": self.room_id,
+                        "json": json.dumps({"type": "test", "room_id": self.room_id}),
+                        "internal_metadata": "{}",
+                        "format_version": EventFormatVersions.V3,
+                    },
+                )
+            )
+
+    @contextmanager
+    def _outage(self) -> Generator[None, None, None]:
+        """Simulate a database outage.
+
+        Returns:
+            A context manager. While the context is active, any attempts to connect to
+            the database will fail.
+        """
+        connection_pool = self.store.db_pool._db_pool
+
+        # Close all connections and shut down the database `ThreadPool`.
+        connection_pool.close()
+
+        # Restart the database `ThreadPool`.
+        connection_pool.start()
+
+        original_connection_factory = connection_pool.connectionFactory
+
+        def connection_factory(_pool: ConnectionPool) -> Connection:
+            raise Exception("Could not connect to the database.")
+
+        connection_pool.connectionFactory = connection_factory  # type: ignore[assignment]
+        try:
+            yield
+        finally:
+            connection_pool.connectionFactory = original_connection_factory
+
+            # If the in-memory SQLite database is being used, all the events are gone.
+            # Restore the test data.
+            self._populate_events()
+
+    def test_failure(self) -> None:
+        """Test that event fetches do not get stuck during a database outage."""
+        with self._outage():
+            failure = self.get_failure(
+                self.store.get_event(self.event_ids[0]), Exception
+            )
+            self.assertEqual(str(failure.value), "Could not connect to the database.")
+
+    def test_recovery(self) -> None:
+        """Test that event fetchers recover after a database outage."""
+        with self._outage():
+            # Kick off a bunch of event fetches but do not pump the reactor
+            event_deferreds = []
+            for event_id in self.event_ids:
+                event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
+
+            # We should have maxed out on event fetcher threads
+            self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
+
+            # All the event fetchers will fail
+            self.pump()
+            self.assertEqual(self.store._event_fetch_ongoing, 0)
+
+            for event_deferred in event_deferreds:
+                failure = self.get_failure(event_deferred, Exception)
+                self.assertEqual(
+                    str(failure.value), "Could not connect to the database."
+                )
+
+        # This next event fetch should succeed
+        self.get_success(self.store.get_event(self.event_ids[0]))