summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_appservice.py1
-rw-r--r--tests/http/server/_base.py2
-rw-r--r--tests/http/test_matrixfederationclient.py2
-rw-r--r--tests/http/test_proxyagent.py14
-rw-r--r--tests/replication/tcp/streams/test_events.py91
-rw-r--r--tests/storage/test_id_generators.py136
-rw-r--r--tests/unittest.py3
7 files changed, 205 insertions, 44 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 867dbd6001..c888d1ff01 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -156,6 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         result = self.successResultOf(
             defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
         )
+        assert result is not None
 
         self.mock_as_api.query_alias.assert_called_once_with(
             interested_service, room_alias_str
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 36472e57a8..d524c183f8 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -335,7 +335,7 @@ class Deferred__next__Patch:
         self._request_number = request_number
         self._seen_awaits = seen_awaits
 
-        self._original_Deferred___next__ = Deferred.__next__
+        self._original_Deferred___next__ = Deferred.__next__  # type: ignore[misc,unused-ignore]
 
         # The number of `await`s on `Deferred`s we have seen so far.
         self.awaits_seen = 0
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index ab94f3f67a..bf1d287699 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -70,7 +70,7 @@ class FederationClientTests(HomeserverTestCase):
         """
 
         @defer.inlineCallbacks
-        def do_request() -> Generator["Deferred[object]", object, object]:
+        def do_request() -> Generator["Deferred[Any]", object, object]:
             with LoggingContext("one") as context:
                 fetch_d = defer.ensureDeferred(
                     self.cl.get_json("testserv:8008", "foo/bar")
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 8164b0b78e..b48c2c293a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -217,6 +217,20 @@ class ProxyParserTests(TestCase):
         )
 
 
+class TestBasicProxyCredentials(TestCase):
+    def test_long_user_pass_string_encoded_without_newlines(self) -> None:
+        """Reproduces https://github.com/matrix-org/synapse/pull/16504."""
+        creds = BasicProxyCredentials(
+            b"looooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooonguser:pass@proxy.local:9988"
+        )
+        auth_value = creds.as_proxy_authorization_value()
+        self.assertNotIn(b"\n", auth_value)
+        self.assertEqual(
+            creds.as_proxy_authorization_value(),
+            b"Basic: bG9vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vbmd1c2VyOnBhc3M=",
+        )
+
+
 class MatrixFederationAgentTests(TestCase):
     def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 128fc3e046..b8ab4ee54b 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -14,6 +14,8 @@
 
 from typing import Any, List, Optional
 
+from parameterized import parameterized
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
@@ -21,6 +23,8 @@ from synapse.events import EventBase
 from synapse.replication.tcp.commands import RdataCommand
 from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
 from synapse.replication.tcp.streams.events import (
+    _MAX_STATE_UPDATES_PER_ROOM,
+    EventsStreamAllStateRow,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
     EventsStreamRow,
@@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self) -> None:
+    @parameterized.expand(
+        [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
+    )
+    def test_update_function_huge_state_change(
+        self, num_state_changes: int, collapse_state_changes: bool
+    ) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
         state change rows to be replicated.
+
+        Args:
+            num_state_changes: The number of state changes to create.
+            collapse_state_changes: Whether the state changes are expected to be
+                collapsed or not.
         """
 
         # we want to generate lots of state changes at a single stream ID.
@@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         events = [
             self._inject_state_event(sender=OTHER_USER)
-            for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
+            for _ in range(num_state_changes)
         ]
 
         self.replicate()
@@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             row for row in self.test_handler.received_rdata_rows if row[0] == "events"
         ]
 
-        # first check the first two rows, which should be state1
-
+        # first check the first two rows, which should be the state1 event.
         stream_name, token, row = received_rows.pop(0)
         self.assertEqual("events", stream_name)
         self.assertIsInstance(row, EventsStreamRow)
@@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
         self.assertEqual(row.data.event_id, state1.event_id)
 
-        # now the last two rows, which should be state2
+        # now the last two rows, which should be the state2 event.
         stream_name, token, row = received_rows.pop(-2)
         self.assertEqual("events", stream_name)
         self.assertIsInstance(row, EventsStreamRow)
@@ -231,34 +244,54 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
         self.assertEqual(row.data.event_id, state2.event_id)
 
-        # that should leave us with the rows for the PL event
-        self.assertEqual(len(received_rows), len(events) + 2)
+        # Based on the number of state changes, the rows may or may not have been collapsed.
+        if collapse_state_changes:
+            # that should leave us with the rows for the PL event, the state changes
+            # get collapsed into a single row.
+            self.assertEqual(len(received_rows), 2)
 
-        stream_name, token, row = received_rows.pop(0)
-        self.assertEqual("events", stream_name)
-        self.assertIsInstance(row, EventsStreamRow)
-        self.assertEqual(row.type, "ev")
-        self.assertIsInstance(row.data, EventsStreamEventRow)
-        self.assertEqual(row.data.event_id, pl_event.event_id)
+            stream_name, token, row = received_rows.pop(0)
+            self.assertEqual("events", stream_name)
+            self.assertIsInstance(row, EventsStreamRow)
+            self.assertEqual(row.type, "ev")
+            self.assertIsInstance(row.data, EventsStreamEventRow)
+            self.assertEqual(row.data.event_id, pl_event.event_id)
 
-        # the state rows are unsorted
-        state_rows: List[EventsStreamCurrentStateRow] = []
-        for stream_name, _, row in received_rows:
+            stream_name, token, row = received_rows.pop(0)
+            self.assertIsInstance(row, EventsStreamRow)
+            self.assertEqual(row.type, "state-all")
+            self.assertIsInstance(row.data, EventsStreamAllStateRow)
+            self.assertEqual(row.data.room_id, state2.room_id)
+
+        else:
+            # that should leave us with the rows for the PL event
+            self.assertEqual(len(received_rows), len(events) + 2)
+
+            stream_name, token, row = received_rows.pop(0)
             self.assertEqual("events", stream_name)
             self.assertIsInstance(row, EventsStreamRow)
-            self.assertEqual(row.type, "state")
-            self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
-            state_rows.append(row.data)
-
-        state_rows.sort(key=lambda r: r.state_key)
-
-        sr = state_rows.pop(0)
-        self.assertEqual(sr.type, EventTypes.PowerLevels)
-        self.assertEqual(sr.event_id, pl_event.event_id)
-        for sr in state_rows:
-            self.assertEqual(sr.type, "test_state_event")
-            # "None" indicates the state has been deleted
-            self.assertIsNone(sr.event_id)
+            self.assertEqual(row.type, "ev")
+            self.assertIsInstance(row.data, EventsStreamEventRow)
+            self.assertEqual(row.data.event_id, pl_event.event_id)
+
+            # the state rows are unsorted
+            state_rows: List[EventsStreamCurrentStateRow] = []
+            for stream_name, _, row in received_rows:
+                self.assertEqual("events", stream_name)
+                self.assertIsInstance(row, EventsStreamRow)
+                self.assertEqual(row.type, "state")
+                self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+                state_rows.append(row.data)
+
+            state_rows.sort(key=lambda r: r.state_key)
+
+            sr = state_rows.pop(0)
+            self.assertEqual(sr.type, EventTypes.PowerLevels)
+            self.assertEqual(sr.event_id, pl_event.event_id)
+            for sr in state_rows:
+                self.assertEqual(sr.type, "test_state_event")
+                # "None" indicates the state has been deleted
+                self.assertIsNone(sr.event_id)
 
     def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9174fb0964..fd53b0644c 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator()
 
-        # The table is empty so we expect an empty map for positions
-        self.assertEqual(id_gen.get_positions(), {})
+        # The table is empty so we expect the map for positions to have a dummy
+        # minimum value.
+        self.assertEqual(id_gen.get_positions(), {"master": 1})
 
     def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
@@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        # The first ID gen will notice that it can advance its token to 7 as it
-        # has no in progress writes...
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
-        # ... but the second ID gen doesn't know that.
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
@@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen.advance("first", 8)
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
+    def test_multi_instance_empty_row(self) -> None:
+        """Test that reads and writes from multiple processes are handled
+        correctly, when one of the writers starts without any rows.
+        """
+        # Insert some rows for two out of three of the ID gens.
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+        third_id_gen = self._create_id_generator(
+            "third", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(
+            first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+        self.assertEqual(
+            second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+        )
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async() -> None:
+            async with third_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+                )
+                self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(
+            third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+        )
+
     def test_get_next_txn(self) -> None:
         """Test that the `get_next_txn` function works correctly."""
 
@@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         with self.assertRaises(IncorrectDatabaseSetup):
             self._create_id_generator("first")
 
+    def test_minimal_local_token(self) -> None:
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
+
+    def test_current_token_gap(self) -> None:
+        """Test that getting the current token for a writer returns the maximal
+        token when there are no writes.
+        """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator(
+            "first", writers=["first", "second", "third"]
+        )
+        second_id_gen = self._create_id_generator(
+            "second", writers=["first", "second", "third"]
+        )
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        # Check that the first ID gen advancing causes the second ID gen to
+        # advance (as the second ID gen has nothing in flight).
+
+        async def _get_next_async() -> None:
+            async with first_id_gen.get_next_mult(2):
+                pass
+
+        self.get_success(_get_next_async())
+        second_id_gen.advance("first", 9)
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        # Check that the first ID gen advancing doesn't advance the second ID
+        # gen when the second ID gen has stuff in flight.
+        self.get_success(_get_next_async())
+
+        ctxmgr = second_id_gen.get_next()
+        self.get_success(ctxmgr.__aenter__())
+
+        second_id_gen.advance("first", 11)
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
+        self.get_success(ctxmgr.__aexit__(None, None, None))
+
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
+        self.assertEqual(second_id_gen.get_current_token(), 7)
+
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
-        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
@@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/unittest.py b/tests/unittest.py
index 99ad02eb06..79c47fc3cc 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from typing import (
     Generic,
     Iterable,
     List,
+    Mapping,
     NoReturn,
     Optional,
     Tuple,
@@ -251,7 +252,7 @@ class TestCase(unittest.TestCase):
             except AssertionError as e:
                 raise (type(e))(f"Assert error for '.{key}':") from e
 
-    def assert_dict(self, required: dict, actual: dict) -> None:
+    def assert_dict(self, required: Mapping, actual: Mapping) -> None:
         """Does a partial assert of a dict.
 
         Args: