summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-08-24 15:49:11 +0100
committerErik Johnston <erik@matrix.org>2020-08-24 15:49:22 +0100
commitc4ca1de6a013321c8dfc4a817b61ffd3e2559c94 (patch)
tree0b044317dd8fece5755cce5caac6010677ae1adc
parentClose the database connection we create during startup (#8131) (diff)
downloadsynapse-c4ca1de6a013321c8dfc4a817b61ffd3e2559c94.tar.xz
add get_persisted_upto_position
-rw-r--r--synapse/storage/util/id_generators.py55
-rw-r--r--tests/storage/test_id_generators.py36
2 files changed, 90 insertions, 1 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0bf772d4d1..b63a71ed97 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 
 import contextlib
+import heapq
 import threading
 from collections import deque
-from typing import Dict, Set
+from typing import Dict, List, Set
 
 from typing_extensions import Deque
 
@@ -210,6 +211,15 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # We track the max position where we know everything before has been
+        # persisted. This is done by a) looking at the min across all instances
+        # and b) noting that if we have seen a run of persisted positions
+        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+        self._persisted_upto_position = (
+            min(self._current_positions.values()) if self._current_positions else 0
+        )
+        self._known_persisted_positions = []  # type: List[int]
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
     def _load_current_ids(
@@ -326,3 +336,46 @@ class MultiWriterIdGenerator:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
             )
+
+            self._add_persisted_position(new_id)
+
+    def get_persisted_upto_position(self) -> int:
+        """Get the max position where all previous positions have been
+        persisted.
+        """
+
+        with self._lock:
+            return self._persisted_upto_position
+
+    def _add_persisted_position(self, new_id: int):
+        """Record that we have persisted a position.
+
+        This is used to keep the `_current_positions` up to date.
+        """
+
+        # We require that the lock is locked by caller
+        assert self._lock.locked()
+
+        heapq.heappush(self._known_persisted_positions, new_id)
+
+        # We move the current min position up if the minimum current positions
+        # of all instances is higher (since by definition all positions less
+        # that that have been persisted).
+        min_curr = min(self._current_positions.values())
+        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+        # We now iterate through the seen positions, discarding those that are
+        # less than the current min positions, and incrementing the min position
+        # if its exactly one greater.
+        while self._known_persisted_positions:
+            if self._known_persisted_positions[0] <= self._persisted_upto_position:
+                heapq.heappop(self._known_persisted_positions)
+            elif (
+                self._known_persisted_positions[0] == self._persisted_upto_position + 1
+            ):
+                heapq.heappop(self._known_persisted_positions)
+                self._persisted_upto_position += 1
+            else:
+                # There was a gap in seen positions, so there is nothing more to
+                # do.
+                break
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 7a05194653..f42849ad16 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+    def test_get_persisted_upto_position(self):
+        """Test that `get_persisted_upto_position` correctly tracks updates to
+        positions.
+        """
+
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 5)
+
+        id_gen = self._create_id_generator("first")
+
+        # Min is 3 and there is a gap between 5, so we expect it to be 3.
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        # We advance "first" straight to 6. Min is now 5 but there is no gap so
+        # we expect it to be 6
+        id_gen.advance("first", 6)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # No gap, so we expect 7.
+        id_gen.advance("second", 7)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # We haven't seen 8 yet, so we expect 7 still.
+        id_gen.advance("second", 9)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # Now that we've seen 7, 8 and 9 we can got straight to 9.
+        id_gen.advance("first", 8)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+        # Jmp forward with gaps. The minimum is 11, even though we haven't seen
+        # 10 we know that everything before 11 must be persisted.
+        id_gen.advance("first", 11)
+        id_gen.advance("second", 15)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 11)