summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_appservice.py1
-rw-r--r--tests/http/server/_base.py2
-rw-r--r--tests/http/test_matrixfederationclient.py2
-rw-r--r--tests/storage/test_id_generators.py136
-rw-r--r--tests/unittest.py3
5 files changed, 129 insertions, 15 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 867dbd6001..c888d1ff01 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -156,6 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         result = self.successResultOf(
             defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
         )
+        assert result is not None
 
         self.mock_as_api.query_alias.assert_called_once_with(
             interested_service, room_alias_str
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 36472e57a8..d524c183f8 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -335,7 +335,7 @@ class Deferred__next__Patch:
         self._request_number = request_number
         self._seen_awaits = seen_awaits
 
-        self._original_Deferred___next__ = Deferred.__next__
+        self._original_Deferred___next__ = Deferred.__next__  # type: ignore[misc,unused-ignore]
 
         # The number of `await`s on `Deferred`s we have seen so far.
         self.awaits_seen = 0
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index ab94f3f67a..bf1d287699 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -70,7 +70,7 @@ class FederationClientTests(HomeserverTestCase):
         """
 
         @defer.inlineCallbacks
-        def do_request() -> Generator["Deferred[object]", object, object]:
+        def do_request() -> Generator["Deferred[Any]", object, object]:
             with LoggingContext("one") as context:
                 fetch_d = defer.ensureDeferred(
                     self.cl.get_json("testserv:8008", "foo/bar")
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9174fb0964..fd53b0644c 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator()
 
-        # The table is empty so we expect an empty map for positions
-        self.assertEqual(id_gen.get_positions(), {})
+        # The table is empty so we expect the map for positions to have a dummy
+        # minimum value.
+        self.assertEqual(id_gen.get_positions(), {"master": 1})
 
     def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
@@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        # The first ID gen will notice that it can advance its token to 7 as it
-        # has no in progress writes...
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
-        # ... but the second ID gen doesn't know that.
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
@@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen.advance("first", 8)
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
+    def test_multi_instance_empty_row(self) -> None:
+        """Test that reads and writes from multiple processes are handled
+        correctly, when one of the writers starts without any rows.
+        """
+        # Insert some rows for two out of three of the ID gens.
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+        third_id_gen = self._create_id_generator(
+            "third", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(
+            first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+        self.assertEqual(
+            second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async() -> None:
+            async with third_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+                )
+                self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(
+            third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+        )
+
     def test_get_next_txn(self) -> None:
         """Test that the `get_next_txn` function works correctly."""
 
@@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         with self.assertRaises(IncorrectDatabaseSetup):
             self._create_id_generator("first")
 
+    def test_minimal_local_token(self) -> None:
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
+
+    def test_current_token_gap(self) -> None:
+        """Test that getting the current token for a writer returns the maximal
+        token when there are no writes.
+        """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        # Check that the first ID gen advancing causes the second ID gen to
+        # advance (as the second ID gen has nothing in flight).
+
+        async def _get_next_async() -> None:
+            async with first_id_gen.get_next_mult(2):
+                pass
+
+        self.get_success(_get_next_async())
+        second_id_gen.advance("first", 9)
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        # Check that the first ID gen advancing doesn't advance the second ID
+        # gen when the second ID gen has stuff in flight.
+        self.get_success(_get_next_async())
+
+        ctxmgr = second_id_gen.get_next()
+        self.get_success(ctxmgr.__aenter__())
+
+        second_id_gen.advance("first", 11)
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        self.get_success(ctxmgr.__aexit__(None, None, None))
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
-        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
@@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/unittest.py b/tests/unittest.py
index 99ad02eb06..79c47fc3cc 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from typing import (
     Generic,
     Iterable,
     List,
+    Mapping,
     NoReturn,
     Optional,
     Tuple,
@@ -251,7 +252,7 @@ class TestCase(unittest.TestCase):
             except AssertionError as e:
                 raise (type(e))(f"Assert error for '.{key}':") from e
 
-    def assert_dict(self, required: dict, actual: dict) -> None:
+    def assert_dict(self, required: Mapping, actual: Mapping) -> None:
         """Does a partial assert of a dict.
 
         Args: