summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_id_generators.py162
1 files changed, 152 insertions, 10 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 2d8d1f860f..d6a2b8d274 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -16,15 +16,157 @@ from typing import List, Optional
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.server import HomeServer
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import IncorrectDatabaseSetup
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
+class StreamIdGeneratorTestCase(HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn: LoggingTransaction) -> None:
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+        txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
+
+    def _create_id_generator(self) -> StreamIdGenerator:
+        def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
+            return StreamIdGenerator(
+                db_conn=conn,
+                table="foobar",
+                column="stream_id",
+            )
+
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+    def test_initial_value(self) -> None:
+        """Check that we read the current token from the DB."""
+        id_gen = self._create_id_generator()
+        self.assertEqual(id_gen.get_current_token(), 123)
+
+    def test_single_gen_next(self) -> None:
+        """Check that we correctly increment the current token from the DB."""
+        id_gen = self._create_id_generator()
+
+        async def test_gen_next() -> None:
+            async with id_gen.get_next() as next_id:
+                # We haven't persisted `next_id` yet; current token is still 123
+                self.assertEqual(id_gen.get_current_token(), 123)
+                # But we did learn what the next value is
+                self.assertEqual(next_id, 124)
+
+            # Once the context manager closes we assume that the `next_id` has been
+            # written to the DB.
+            self.assertEqual(id_gen.get_current_token(), 124)
+
+        self.get_success(test_gen_next())
+
+    def test_multiple_gen_nexts(self) -> None:
+        """Check that we handle overlapping calls to gen_next sensibly."""
+        id_gen = self._create_id_generator()
+
+        async def test_gen_next() -> None:
+            ctx1 = id_gen.get_next()
+            ctx2 = id_gen.get_next()
+            ctx3 = id_gen.get_next()
+
+            # Request three new stream IDs.
+            self.assertEqual(await ctx1.__aenter__(), 124)
+            self.assertEqual(await ctx2.__aenter__(), 125)
+            self.assertEqual(await ctx3.__aenter__(), 126)
+
+            # None are persisted: current token unchanged.
+            self.assertEqual(id_gen.get_current_token(), 123)
+
+            # Persist each in turn.
+            await ctx1.__aexit__(None, None, None)
+            self.assertEqual(id_gen.get_current_token(), 124)
+            await ctx2.__aexit__(None, None, None)
+            self.assertEqual(id_gen.get_current_token(), 125)
+            await ctx3.__aexit__(None, None, None)
+            self.assertEqual(id_gen.get_current_token(), 126)
+
+        self.get_success(test_gen_next())
+
+    def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
+        """Check that we handle overlapping calls to gen_next, even when their IDs
+        created and persisted in different orders."""
+        id_gen = self._create_id_generator()
+
+        async def test_gen_next() -> None:
+            ctx1 = id_gen.get_next()
+            ctx2 = id_gen.get_next()
+            ctx3 = id_gen.get_next()
+
+            # Request three new stream IDs.
+            self.assertEqual(await ctx1.__aenter__(), 124)
+            self.assertEqual(await ctx2.__aenter__(), 125)
+            self.assertEqual(await ctx3.__aenter__(), 126)
+
+            # None are persisted: current token unchanged.
+            self.assertEqual(id_gen.get_current_token(), 123)
+
+            # Persist them in a different order, starting with 126 from ctx3.
+            await ctx3.__aexit__(None, None, None)
+            # We haven't persisted 124 from ctx1 yet---current token is still 123.
+            self.assertEqual(id_gen.get_current_token(), 123)
+
+            # Now persist 124 from ctx1.
+            await ctx1.__aexit__(None, None, None)
+            # Current token is then 124, waiting for 125 to be persisted.
+            self.assertEqual(id_gen.get_current_token(), 124)
+
+            # Finally persist 125 from ctx2.
+            await ctx2.__aexit__(None, None, None)
+            # Current token is then 126 (skipping over 125).
+            self.assertEqual(id_gen.get_current_token(), 126)
+
+        self.get_success(test_gen_next())
+
+    def test_gen_next_while_still_waiting_for_persistence(self) -> None:
+        """Check that we handle overlapping calls to gen_next."""
+        id_gen = self._create_id_generator()
+
+        async def test_gen_next() -> None:
+            ctx1 = id_gen.get_next()
+            ctx2 = id_gen.get_next()
+            ctx3 = id_gen.get_next()
+
+            # Request two new stream IDs.
+            self.assertEqual(await ctx1.__aenter__(), 124)
+            self.assertEqual(await ctx2.__aenter__(), 125)
+
+            # Persist ctx2 first.
+            await ctx2.__aexit__(None, None, None)
+            # Still waiting on ctx1's ID to be persisted.
+            self.assertEqual(id_gen.get_current_token(), 123)
+
+            # Now request a third stream ID. It should be 126 (the smallest ID that
+            # we've not yet handed out.)
+            self.assertEqual(await ctx3.__aenter__(), 126)
+
+        self.get_success(test_gen_next())
+
+
 class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
@@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         )
 
     def _create_id_generator(
-        self, instance_name="master", writers: Optional[List[str]] = None
+        self, instance_name: str = "master", writers: Optional[List[str]] = None
     ) -> MultiWriterIdGenerator:
-        def _create(conn):
+        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
@@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("master", 3)
 
         # Now we add a row *without* updating the stream ID
-        def _insert(txn):
+        def _insert(txn: Cursor) -> None:
             txn.execute("INSERT INTO foobar VALUES (26, 'master')")
 
         self.get_success(self.db_pool.runInteraction("_insert", _insert))
@@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         )
 
     def _create_id_generator(
-        self, instance_name="master", writers: Optional[List[str]] = None
+        self, instance_name: str = "master", writers: Optional[List[str]] = None
     ) -> MultiWriterIdGenerator:
-        def _create(conn):
+        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
@@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         )
 
     def _create_id_generator(
-        self, instance_name="master", writers: Optional[List[str]] = None
+        self, instance_name: str = "master", writers: Optional[List[str]] = None
     ) -> MultiWriterIdGenerator:
-        def _create(conn):
+        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
@@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         instance_name: str,
         number: int,
         update_stream_table: bool = True,
-    ):
+    ) -> None:
         """Insert N rows as the given instance, inserting with stream IDs pulled
         from the postgres sequence.
         """