diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 2d8d1f860f..d6a2b8d274 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -16,15 +16,157 @@ from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import IncorrectDatabaseSetup
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
+class StreamIdGeneratorTestCase(HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn: LoggingTransaction) -> None:
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+ txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
+
+ def _create_id_generator(self) -> StreamIdGenerator:
+ def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
+ return StreamIdGenerator(
+ db_conn=conn,
+ table="foobar",
+ column="stream_id",
+ )
+
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+ def test_initial_value(self) -> None:
+ """Check that we read the current token from the DB."""
+ id_gen = self._create_id_generator()
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ def test_single_gen_next(self) -> None:
+ """Check that we correctly increment the current token from the DB."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ async with id_gen.get_next() as next_id:
+ # We haven't persisted `next_id` yet; current token is still 123
+ self.assertEqual(id_gen.get_current_token(), 123)
+ # But we did learn what the next value is
+ self.assertEqual(next_id, 124)
+
+ # Once the context manager closes we assume that the `next_id` has been
+ # written to the DB.
+ self.assertEqual(id_gen.get_current_token(), 124)
+
+ self.get_success(test_gen_next())
+
+ def test_multiple_gen_nexts(self) -> None:
+ """Check that we handle overlapping calls to gen_next sensibly."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request three new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ # None are persisted: current token unchanged.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Persist each in turn.
+ await ctx1.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 124)
+ await ctx2.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 125)
+ await ctx3.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 126)
+
+ self.get_success(test_gen_next())
+
+ def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
+ """Check that we handle overlapping calls to gen_next, even when their IDs
+ created and persisted in different orders."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request three new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ # None are persisted: current token unchanged.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Persist them in a different order, starting with 126 from ctx3.
+ await ctx3.__aexit__(None, None, None)
+ # We haven't persisted 124 from ctx1 yet---current token is still 123.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Now persist 124 from ctx1.
+ await ctx1.__aexit__(None, None, None)
+ # Current token is then 124, waiting for 125 to be persisted.
+ self.assertEqual(id_gen.get_current_token(), 124)
+
+ # Finally persist 125 from ctx2.
+ await ctx2.__aexit__(None, None, None)
+ # Current token is then 126 (skipping over 125).
+ self.assertEqual(id_gen.get_current_token(), 126)
+
+ self.get_success(test_gen_next())
+
+ def test_gen_next_while_still_waiting_for_persistence(self) -> None:
+ """Check that we handle overlapping calls to gen_next."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request two new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+
+ # Persist ctx2 first.
+ await ctx2.__aexit__(None, None, None)
+ # Still waiting on ctx1's ID to be persisted.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Now request a third stream ID. It should be 126 (the smallest ID that
+ # we've not yet handed out.)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ self.get_success(test_gen_next())
+
+
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
@@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers: Optional[List[str]] = None
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
- def _create(conn):
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
@@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("master", 3)
# Now we add a row *without* updating the stream ID
- def _insert(txn):
+ def _insert(txn: Cursor) -> None:
txn.execute("INSERT INTO foobar VALUES (26, 'master')")
self.get_success(self.db_pool.runInteraction("_insert", _insert))
@@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers: Optional[List[str]] = None
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
- def _create(conn):
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
@@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers: Optional[List[str]] = None
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
- def _create(conn):
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
@@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name: str,
number: int,
update_stream_table: bool = True,
- ):
+ ) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
|