diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 867dbd6001..c888d1ff01 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -156,6 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
result = self.successResultOf(
defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
)
+ assert result is not None
self.mock_as_api.query_alias.assert_called_once_with(
interested_service, room_alias_str
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 36472e57a8..d524c183f8 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -335,7 +335,7 @@ class Deferred__next__Patch:
self._request_number = request_number
self._seen_awaits = seen_awaits
- self._original_Deferred___next__ = Deferred.__next__
+ self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore]
# The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index ab94f3f67a..bf1d287699 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -70,7 +70,7 @@ class FederationClientTests(HomeserverTestCase):
"""
@defer.inlineCallbacks
- def do_request() -> Generator["Deferred[object]", object, object]:
+ def do_request() -> Generator["Deferred[Any]", object, object]:
with LoggingContext("one") as context:
fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar")
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9174fb0964..fd53b0644c 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
- # The table is empty so we expect an empty map for positions
- self.assertEqual(id_gen.get_positions(), {})
+ # The table is empty so we expect the map for positions to have a dummy
+ # minimum value.
+ self.assertEqual(id_gen.get_positions(), {"master": 1})
def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
@@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- # The first ID gen will notice that it can advance its token to 7 as it
- # has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
- # ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
@@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+ def test_multi_instance_empty_row(self) -> None:
+ """Test that reads and writes from multiple processes are handled
+ correctly, when one of the writers starts without any rows.
+ """
+ # Insert some rows for two out of three of the ID gens.
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+ third_id_gen = self._create_id_generator(
+ "third", writers=["first", "second", "third"]
+ )
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async() -> None:
+ async with third_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+ )
+
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
@@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
with self.assertRaises(IncorrectDatabaseSetup):
self._create_id_generator("first")
+ def test_minimal_local_token(self) -> None:
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
+
+ def test_current_token_gap(self) -> None:
+ """Test that getting the current token for a writer returns the maximal
+ token when there are no writes.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ # Check that the first ID gen advancing causes the second ID gen to
+ # advance (as the second ID gen has nothing in flight).
+
+ async def _get_next_async() -> None:
+ async with first_id_gen.get_next_mult(2):
+ pass
+
+ self.get_success(_get_next_async())
+ second_id_gen.advance("first", 9)
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ # Check that the first ID gen advancing doesn't advance the second ID
+ # gen when the second ID gen has stuff in flight.
+ self.get_success(_get_next_async())
+
+ ctxmgr = second_id_gen.get_next()
+ self.get_success(ctxmgr.__aenter__())
+
+ second_id_gen.advance("first", 11)
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ self.get_success(ctxmgr.__aexit__(None, None, None))
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1})
- self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
@@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
- self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/unittest.py b/tests/unittest.py
index 99ad02eb06..79c47fc3cc 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from typing import (
Generic,
Iterable,
List,
+ Mapping,
NoReturn,
Optional,
Tuple,
@@ -251,7 +252,7 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e
- def assert_dict(self, required: dict, actual: dict) -> None:
+ def assert_dict(self, required: Mapping, actual: Mapping) -> None:
"""Does a partial assert of a dict.
Args:
|