summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11036.misc1
-rw-r--r--changelog.d/11045.bugfix1
-rw-r--r--synapse/storage/util/id_generators.py82
-rw-r--r--tests/config/test_cache.py24
4 files changed, 80 insertions, 28 deletions
diff --git a/changelog.d/11036.misc b/changelog.d/11036.misc
new file mode 100644
index 0000000000..aae5ee62b2
--- /dev/null
+++ b/changelog.d/11036.misc
@@ -0,0 +1 @@
+Ensure that cache config tests do not share state.
diff --git a/changelog.d/11045.bugfix b/changelog.d/11045.bugfix
new file mode 100644
index 0000000000..d712dc946a
--- /dev/null
+++ b/changelog.d/11045.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 852bd79fee..670811611f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -36,7 +36,7 @@ from typing import (
 )
 
 import attr
-from sortedcontainers import SortedSet
+from sortedcontainers import SortedList, SortedSet
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import (
@@ -265,6 +265,15 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids: SortedSet[int] = SortedSet()
 
+        # We also need to track when we've requested some new stream IDs but
+        # they haven't yet been added to the `_unfinished_ids` set. Every time
+        # we request a new stream ID we add the current max stream ID to the
+        # list, and remove it once we've added the newly allocated IDs to the
+        # `_unfinished_ids` set. This means that we *may* be allocated stream
+        # IDs above those in the list, and so we can't advance the local current
+        # position beyond the minimum stream ID in this list.
+        self._in_flight_fetches: SortedList[int] = SortedList()
+
         # Set of local IDs that we've processed that are larger than the current
         # position, due to there being smaller unpersisted IDs.
         self._finished_ids: Set[int] = set()
@@ -290,6 +299,9 @@ class MultiWriterIdGenerator:
         )
         self._known_persisted_positions: List[int] = []
 
+        # The maximum stream ID that we have seen been allocated across any writer.
+        self._max_seen_allocated_stream_id = 1
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
@@ -305,6 +317,10 @@ class MultiWriterIdGenerator:
         # This goes and fills out the above state from the database.
         self._load_current_ids(db_conn, tables)
 
+        self._max_seen_allocated_stream_id = max(
+            self._current_positions.values(), default=1
+        )
+
     def _load_current_ids(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -411,10 +427,32 @@ class MultiWriterIdGenerator:
         cur.close()
 
     def _load_next_id_txn(self, txn: Cursor) -> int:
-        return self._sequence_gen.get_next_id_txn(txn)
+        stream_ids = self._load_next_mult_id_txn(txn, 1)
+        return stream_ids[0]
 
     def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
-        return self._sequence_gen.get_next_mult_txn(txn, n)
+        # We need to track that we've requested some more stream IDs, and what
+        # the current max allocated stream ID is. This is to prevent a race
+        # where we've been allocated stream IDs but they have not yet been added
+        # to the `_unfinished_ids` set, allowing the current position to advance
+        # past them.
+        with self._lock:
+            current_max = self._max_seen_allocated_stream_id
+            self._in_flight_fetches.add(current_max)
+
+        try:
+            stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
+
+            with self._lock:
+                self._unfinished_ids.update(stream_ids)
+                self._max_seen_allocated_stream_id = max(
+                    self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
+                )
+        finally:
+            with self._lock:
+                self._in_flight_fetches.remove(current_max)
+
+        return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
         """
@@ -463,9 +501,6 @@ class MultiWriterIdGenerator:
 
         next_id = self._load_next_id_txn(txn)
 
-        with self._lock:
-            self._unfinished_ids.add(next_id)
-
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
@@ -497,15 +532,27 @@ class MultiWriterIdGenerator:
 
             new_cur: Optional[int] = None
 
-            if self._unfinished_ids:
+            if self._unfinished_ids or self._in_flight_fetches:
                 # If there are unfinished IDs then the new position will be the
-                # largest finished ID less than the minimum unfinished ID.
+                # largest finished ID strictly less than the minimum unfinished
+                # ID.
+
+                # The minimum unfinished ID needs to take account of both
+                # `_unfinished_ids` and `_in_flight_fetches`.
+                if self._unfinished_ids and self._in_flight_fetches:
+                    # `_in_flight_fetches` stores the maximum safe stream ID, so
+                    # we add one to make it equivalent to the minimum unsafe ID.
+                    min_unfinished = min(
+                        self._unfinished_ids[0], self._in_flight_fetches[0] + 1
+                    )
+                elif self._in_flight_fetches:
+                    min_unfinished = self._in_flight_fetches[0] + 1
+                else:
+                    min_unfinished = self._unfinished_ids[0]
 
                 finished = set()
-
-                min_unfinshed = self._unfinished_ids[0]
                 for s in self._finished_ids:
-                    if s < min_unfinshed:
+                    if s < min_unfinished:
                         if new_cur is None or new_cur < s:
                             new_cur = s
                     else:
@@ -575,6 +622,10 @@ class MultiWriterIdGenerator:
                 new_id, self._current_positions.get(instance_name, 0)
             )
 
+            self._max_seen_allocated_stream_id = max(
+                self._max_seen_allocated_stream_id, new_id
+            )
+
             self._add_persisted_position(new_id)
 
     def get_persisted_upto_position(self) -> int:
@@ -605,7 +656,11 @@ class MultiWriterIdGenerator:
         # to report a recent position when asked, rather than a potentially old
         # one (if this instance hasn't written anything for a while).
         our_current_position = self._current_positions.get(self._instance_name)
-        if our_current_position and not self._unfinished_ids:
+        if (
+            our_current_position
+            and not self._unfinished_ids
+            and not self._in_flight_fetches
+        ):
             self._current_positions[self._instance_name] = max(
                 our_current_position, new_id
             )
@@ -697,9 +752,6 @@ class _MultiWriterCtxManager:
             db_autocommit=True,
         )
 
-        with self.id_gen._lock:
-            self.id_gen._unfinished_ids.update(self.stream_ids)
-
         if self.multiple_ids is None:
             return self.stream_ids[0] * self.id_gen._return_factor
         else:
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 79d417568d..4bb82e810e 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -12,25 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import patch
-
 from synapse.config.cache import CacheConfig, add_resizable_cache
 from synapse.util.caches.lrucache import LruCache
 
 from tests.unittest import TestCase
 
 
-# Patch the global _CACHES so that each test runs against its own state.
-@patch("synapse.config.cache._CACHES", new_callable=dict)
 class CacheConfigTests(TestCase):
     def setUp(self):
-        # Reset caches before each test
+        # Reset caches before each test since there's global state involved.
         self.config = CacheConfig()
+        self.config.reset()
 
     def tearDown(self):
+        # Also reset the caches after each test to leave state pristine.
         self.config.reset()
 
-    def test_individual_caches_from_environ(self, _caches):
+    def test_individual_caches_from_environ(self):
         """
         Individual cache factors will be loaded from the environment.
         """
@@ -43,7 +41,7 @@ class CacheConfigTests(TestCase):
 
         self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
 
-    def test_config_overrides_environ(self, _caches):
+    def test_config_overrides_environ(self):
         """
         Individual cache factors defined in the environment will take precedence
         over those in the config.
@@ -60,7 +58,7 @@ class CacheConfigTests(TestCase):
             {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
         )
 
-    def test_individual_instantiated_before_config_load(self, _caches):
+    def test_individual_instantiated_before_config_load(self):
         """
         If a cache is instantiated before the config is read, it will be given
         the default cache size in the interim, and then resized once the config
@@ -76,7 +74,7 @@ class CacheConfigTests(TestCase):
 
         self.assertEqual(cache.max_size, 300)
 
-    def test_individual_instantiated_after_config_load(self, _caches):
+    def test_individual_instantiated_after_config_load(self):
         """
         If a cache is instantiated after the config is read, it will be
         immediately resized to the correct size given the per_cache_factor if
@@ -89,7 +87,7 @@ class CacheConfigTests(TestCase):
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 200)
 
-    def test_global_instantiated_before_config_load(self, _caches):
+    def test_global_instantiated_before_config_load(self):
         """
         If a cache is instantiated before the config is read, it will be given
         the default cache size in the interim, and then resized to the new
@@ -104,7 +102,7 @@ class CacheConfigTests(TestCase):
 
         self.assertEqual(cache.max_size, 400)
 
-    def test_global_instantiated_after_config_load(self, _caches):
+    def test_global_instantiated_after_config_load(self):
         """
         If a cache is instantiated after the config is read, it will be
         immediately resized to the correct size given the global factor if there
@@ -117,7 +115,7 @@ class CacheConfigTests(TestCase):
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 150)
 
-    def test_cache_with_asterisk_in_name(self, _caches):
+    def test_cache_with_asterisk_in_name(self):
         """Some caches have asterisks in their name, test that they are set correctly."""
 
         config = {
@@ -143,7 +141,7 @@ class CacheConfigTests(TestCase):
         add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
         self.assertEqual(cache_c.max_size, 200)
 
-    def test_apply_cache_factor_from_config(self, _caches):
+    def test_apply_cache_factor_from_config(self):
         """Caches can disable applying cache factor updates, mainly used by
         event cache size.
         """