diff options
Diffstat (limited to '')
-rw-r--r-- | changelog.d/8257.misc | 1 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 43 | ||||
-rw-r--r-- | tests/storage/test_id_generators.py | 50 |
3 files changed, 88 insertions, 6 deletions
diff --git a/changelog.d/8257.misc b/changelog.d/8257.misc new file mode 100644 index 0000000000..47ac583eb4 --- /dev/null +++ b/changelog.d/8257.misc @@ -0,0 +1 @@ +Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b7eb4f8ac9..2a66b3ad4e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -224,6 +224,10 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # Set of local IDs that we've processed that are larger than the current + # position, due to there being smaller unpersisted IDs. + self._finished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances # and b) noting that if we have seen a run of persisted positions @@ -348,17 +352,44 @@ class MultiWriterIdGenerator: def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the - current poistion if possible. + current position if possible. """ with self._lock: self._unfinished_ids.discard(next_id) + self._finished_ids.add(next_id) + + new_cur = None + + if self._unfinished_ids: + # If there are unfinished IDs then the new position will be the + # largest finished ID less than the minimum unfinished ID. + + finished = set() + + min_unfinshed = min(self._unfinished_ids) + for s in self._finished_ids: + if s < min_unfinshed: + if new_cur is None or new_cur < s: + new_cur = s + else: + finished.add(s) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids = finished + else: + # There are no unfinished IDs so the new position is simply the + # largest finished one. + new_cur = max(self._finished_ids) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids.clear() - # Figure out if its safe to advance the position by checking there - # aren't any lower allocated IDs that are yet to finish. - if all(c > next_id for c in self._unfinished_ids): + if new_cur: curr = self._current_positions.get(self._instance_name, 0) - self._current_positions[self._instance_name] = max(curr, next_id) + self._current_positions[self._instance_name] = max(curr, new_cur) self._add_persisted_position(next_id) @@ -428,7 +459,7 @@ class MultiWriterIdGenerator: # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values()) + min_curr = min(self._current_positions.values(), default=0) self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index f0a8e32f1e..20636fc400 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -122,6 +122,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + def test_out_of_order_finish(self): + """Test that IDs persisted out of order are correctly handled + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + ctx3 = self.get_success(id_gen.get_next()) + ctx4 = self.get_success(id_gen.get_next()) + + s1 = ctx1.__enter__() + s2 = ctx2.__enter__() + s3 = ctx3.__enter__() + s4 = ctx4.__enter__() + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + self.assertEqual(s3, 10) + self.assertEqual(s4, 11) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx2.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx4.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx3.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 11}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) + def test_multi_instance(self): """Test that reads and writes from multiple processes are handled correctly. |