diff --git a/changelog.d/11376.bugfix b/changelog.d/11376.bugfix
new file mode 100644
index 0000000000..639e48b59b
--- /dev/null
+++ b/changelog.d/11376.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection; the previous fix in 1.47.0 was incomplete. Also fix a race condition introduced by that previous fix.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..c6bcfe1c32 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,7 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
@@ -602,7 +602,7 @@ class EventsWorkerStore(SQLBaseStore):
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -736,35 +736,118 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after roughly
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -806,9 +889,7 @@ class EventsWorkerStore(SQLBaseStore):
# We only want to resolve deferreds from the main thread
def fire(evs, exc):
for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ d.errback(exc)
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
@@ -983,20 +1064,9 @@ class EventsWorkerStore(SQLBaseStore):
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index a649e8c618..5ae491ff5a 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -12,11 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from contextlib import contextmanager
+from typing import Generator
+from twisted.enterprise.adbapi import ConnectionPool
+from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import EventFormatVersions, RoomVersions
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import (
+ EVENT_QUEUE_THREADS,
+ EventsWorkerStore,
+)
+from synapse.storage.types import Connection
+from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
@@ -144,3 +157,127 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+
+class DatabaseOutageTestCase(unittest.HomeserverTestCase):
+ """Test event fetching during a database outage."""
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.store: EventsWorkerStore = hs.get_datastore()
+
+ self.room_id = f"!room:{hs.hostname}"
+ self.event_ids = [f"event{i}" for i in range(20)]
+
+ self._populate_events()
+
+ def _populate_events(self) -> None:
+ """Ensure that there are test events in the database.
+
+ When testing with the in-memory SQLite database, all the events are lost during
+ the simulated outage.
+
+ To ensure consistency between `room_id`s and `event_id`s before and after the
+ outage, rows are built and inserted manually.
+
+ Upserts are used to handle the non-SQLite case where events are not lost.
+ """
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "rooms",
+ {"room_id": self.room_id},
+ {"room_version": RoomVersions.V4.identifier},
+ )
+ )
+
+ self.event_ids = [f"event{i}" for i in range(20)]
+ for idx, event_id in enumerate(self.event_ids):
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "events",
+ {"event_id": event_id},
+ {
+ "event_id": event_id,
+ "room_id": self.room_id,
+ "topological_ordering": idx,
+ "stream_ordering": idx,
+ "type": "test",
+ "processed": True,
+ "outlier": False,
+ },
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ "event_json",
+ {"event_id": event_id},
+ {
+ "room_id": self.room_id,
+ "json": json.dumps({"type": "test", "room_id": self.room_id}),
+ "internal_metadata": "{}",
+ "format_version": EventFormatVersions.V3,
+ },
+ )
+ )
+
+ @contextmanager
+ def _outage(self) -> Generator[None, None, None]:
+ """Simulate a database outage.
+
+ Returns:
+ A context manager. While the context is active, any attempts to connect to
+ the database will fail.
+ """
+ connection_pool = self.store.db_pool._db_pool
+
+ # Close all connections and shut down the database `ThreadPool`.
+ connection_pool.close()
+
+ # Restart the database `ThreadPool`.
+ connection_pool.start()
+
+ original_connection_factory = connection_pool.connectionFactory
+
+ def connection_factory(_pool: ConnectionPool) -> Connection:
+ raise Exception("Could not connect to the database.")
+
+ connection_pool.connectionFactory = connection_factory # type: ignore[assignment]
+ try:
+ yield
+ finally:
+ connection_pool.connectionFactory = original_connection_factory
+
+ # If the in-memory SQLite database is being used, all the events are gone.
+ # Restore the test data.
+ self._populate_events()
+
+ def test_failure(self) -> None:
+ """Test that event fetches do not get stuck during a database outage."""
+ with self._outage():
+ failure = self.get_failure(
+ self.store.get_event(self.event_ids[0]), Exception
+ )
+ self.assertEqual(str(failure.value), "Could not connect to the database.")
+
+ def test_recovery(self) -> None:
+ """Test that event fetchers recover after a database outage."""
+ with self._outage():
+ # Kick off a bunch of event fetches but do not pump the reactor
+ event_deferreds = []
+ for event_id in self.event_ids:
+ event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
+
+ # We should have maxed out on event fetcher threads
+ self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
+
+ # All the event fetchers will fail
+ self.pump()
+ self.assertEqual(self.store._event_fetch_ongoing, 0)
+
+ for event_deferred in event_deferreds:
+ failure = self.get_failure(event_deferred, Exception)
+ self.assertEqual(
+ str(failure.value), "Could not connect to the database."
+ )
+
+ # This next event fetch should succeed
+ self.get_success(self.store.get_event(self.event_ids[0]))
|