summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8203.misc1
-rw-r--r--synapse/storage/util/id_generators.py39
-rw-r--r--tests/storage/test_id_generators.py105
3 files changed, 134 insertions, 11 deletions
diff --git a/changelog.d/8203.misc b/changelog.d/8203.misc
new file mode 100644
index 0000000000..9fe2224aaa
--- /dev/null
+++ b/changelog.d/8203.misc
@@ -0,0 +1 @@
+Make `MultiWriterIDGenerator` work for streams that use negative values.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b27a4843d0..9f3d23f0a5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        positive: Whether the IDs are positive (true) or negative (false).
+            When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
 
     def __init__(
@@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        positive: bool = True,
     ):
         self._db = db
         self._instance_name = instance_name
+        self._positive = positive
+        self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
         self._lock = threading.Lock()
 
+        # Note: If we are a negative stream then we still store all the IDs as
+        # positive to make life easier for us, and simply negate the IDs when we
+        # return them.
         self._current_positions = self._load_current_ids(
             db_conn, table, instance_column, id_column
         )
@@ -233,13 +241,16 @@ class MultiWriterIdGenerator:
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ) -> Dict[str, int]:
+        # If positive stream aggregate via MAX. For negative stream use MIN
+        # *and* negate the result to get a positive number.
         sql = """
-            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
             GROUP BY %(instance)s
         """ % {
             "instance": instance_column,
             "id": id_column,
             "table": table,
+            "agg": "MAX" if self._positive else "-MIN",
         }
 
         cur = db_conn.cursor()
@@ -269,15 +280,16 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token_for_writer(self._instance_name) < next_id
-
         with self._lock:
+            assert self._current_positions.get(self._instance_name, 0) < next_id
+
             self._unfinished_ids.add(next_id)
 
         @contextlib.contextmanager
         def manager():
             try:
-                yield next_id
+                # Multiply by the return factor so that the ID has correct sign.
+                yield self._return_factor * next_id
             finally:
                 self._mark_id_as_finished(next_id)
 
@@ -296,15 +308,15 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than any ID we've already
         # seen. If not, then the sequence and table have got out of sync
         # somehow.
-        assert max(self.get_positions().values(), default=0) < min(next_ids)
-
         with self._lock:
+            assert max(self._current_positions.values(), default=0) < min(next_ids)
+
             self._unfinished_ids.update(next_ids)
 
         @contextlib.contextmanager
         def manager():
             try:
-                yield next_ids
+                yield [self._return_factor * i for i in next_ids]
             finally:
                 for i in next_ids:
                     self._mark_id_as_finished(i)
@@ -327,7 +339,7 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
-        return next_id
+        return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
         """The ID has finished being processed so we should advance the
@@ -359,20 +371,25 @@ class MultiWriterIdGenerator:
         """
 
         with self._lock:
-            return self._current_positions.get(instance_name, 0)
+            return self._return_factor * self._current_positions.get(instance_name, 0)
 
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
         """
 
         with self._lock:
-            return dict(self._current_positions)
+            return {
+                name: self._return_factor * i
+                for name, i in self._current_positions.items()
+            }
 
     def advance(self, instance_name: str, new_id: int):
         """Advance the postion of the named writer to the given ID, if greater
         than existing entry.
         """
 
+        new_id *= self._return_factor
+
         with self._lock:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
@@ -390,7 +407,7 @@ class MultiWriterIdGenerator:
         """
 
         with self._lock:
-            return self._persisted_upto_position
+            return self._return_factor * self._persisted_upto_position
 
     def _add_persisted_position(self, new_id: int):
         """Record that we have persisted a position.
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 14ce21c786..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # We assume that so long as `get_next` does correctly advance the
         # `persisted_upto_position` in this case, then it will be correct in the
         # other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+    """
+
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                instance_name=instance_name,
+                table="foobar",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="foobar_seq",
+                positive=False,
+            )
+
+        return self.get_success(self.db_pool.runWithConnection(_create))
+
+    def _insert_row(self, instance_name: str, stream_id: int):
+        """Insert one row as the given instance with given stream_id.
+        """
+
+        def _insert(txn):
+            txn.execute(
+                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+            )
+
+        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+    def test_single_instance(self):
+        """Test that reads and writes from a single process are handled
+        correctly.
+        """
+        id_gen = self._create_id_generator()
+
+        with self.get_success(id_gen.get_next()) as stream_id:
+            self._insert_row("master", stream_id)
+
+        self.assertEqual(id_gen.get_positions(), {"master": -1})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+        self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+        with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+            for stream_id in stream_ids:
+                self._insert_row("master", stream_id)
+
+        self.assertEqual(id_gen.get_positions(), {"master": -4})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+        self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+        # Test loading from DB by creating a second ID gen
+        second_id_gen = self._create_id_generator()
+
+        self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+        self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+    def test_multiple_instance(self):
+        """Tests that having multiple instances that get advanced over
+        federation works corretly.
+        """
+        id_gen_1 = self._create_id_generator("first")
+        id_gen_2 = self._create_id_generator("second")
+
+        with self.get_success(id_gen_1.get_next()) as stream_id:
+            self._insert_row("first", stream_id)
+            id_gen_2.advance("first", stream_id)
+
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+        with self.get_success(id_gen_2.get_next()) as stream_id:
+            self._insert_row("second", stream_id)
+            id_gen_1.advance("second", stream_id)
+
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)