summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/replication/tcp/streams/test_events.py91
-rw-r--r--tests/storage/test_client_ips.py137
-rw-r--r--tests/util/test_retryutils.py75
3 files changed, 207 insertions, 96 deletions
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py

index 128fc3e046..b8ab4ee54b 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py
@@ -14,6 +14,8 @@ from typing import Any, List, Optional +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership @@ -21,6 +23,8 @@ from synapse.events import EventBase from synapse.replication.tcp.commands import RdataCommand from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams.events import ( + _MAX_STATE_UPDATES_PER_ROOM, + EventsStreamAllStateRow, EventsStreamCurrentStateRow, EventsStreamEventRow, EventsStreamRow, @@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) - def test_update_function_huge_state_change(self) -> None: + @parameterized.expand( + [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)] + ) + def test_update_function_huge_state_change( + self, num_state_changes: int, collapse_state_changes: bool + ) -> None: """Test replication with many state events Ensures that all events are correctly replicated when there are lots of state change rows to be replicated. + + Args: + num_state_changes: The number of state changes to create. + collapse_state_changes: Whether the state changes are expected to be + collapsed or not. """ # we want to generate lots of state changes at a single stream ID. @@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase): events = [ self._inject_state_event(sender=OTHER_USER) - for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + for _ in range(num_state_changes) ] self.replicate() @@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase): row for row in self.test_handler.received_rdata_rows if row[0] == "events" ] - # first check the first two rows, which should be state1 - + # first check the first two rows, which should be the state1 event. stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertEqual(row.data.event_id, state1.event_id) - # now the last two rows, which should be state2 + # now the last two rows, which should be the state2 event. stream_name, token, row = received_rows.pop(-2) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -231,34 +244,54 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertIsInstance(row.data, EventsStreamCurrentStateRow) self.assertEqual(row.data.event_id, state2.event_id) - # that should leave us with the rows for the PL event - self.assertEqual(len(received_rows), len(events) + 2) + # Based on the number of + if collapse_state_changes: + # that should leave us with the rows for the PL event, the state changes + # get collapsed into a single row. + self.assertEqual(len(received_rows), 2) - stream_name, token, row = received_rows.pop(0) - self.assertEqual("events", stream_name) - self.assertIsInstance(row, EventsStreamRow) - self.assertEqual(row.type, "ev") - self.assertIsInstance(row.data, EventsStreamEventRow) - self.assertEqual(row.data.event_id, pl_event.event_id) + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) - # the state rows are unsorted - state_rows: List[EventsStreamCurrentStateRow] = [] - for stream_name, _, row in received_rows: + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state-all") + self.assertIsInstance(row.data, EventsStreamAllStateRow) + self.assertEqual(row.data.room_id, state2.room_id) + + else: + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) - self.assertEqual(row.type, "state") - self.assertIsInstance(row.data, EventsStreamCurrentStateRow) - state_rows.append(row.data) - - state_rows.sort(key=lambda r: r.state_key) - - sr = state_rows.pop(0) - self.assertEqual(sr.type, EventTypes.PowerLevels) - self.assertEqual(sr.event_id, pl_event.event_id) - for sr in state_rows: - self.assertEqual(sr.type, "test_state_event") - # "None" indicates the state has been deleted - self.assertIsNone(sr.event_id) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows: List[EventsStreamCurrentStateRow] = [] + for stream_name, _, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) def test_update_function_state_row_limit(self) -> None: """Test replication with many state events over several stream ids.""" diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 6b9692c486..0c054a598f 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -24,7 +24,10 @@ import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client import login from synapse.server import HomeServer -from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.databases.main.client_ips import ( + LAST_SEEN_GRANULARITY, + DeviceLastConnectionInfo, +) from synapse.types import UserID from synapse.util import Clock @@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 12345678000, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=12345678000, + ), + r, ) def test_insert_new_client_ip_none_device_id(self) -> None: @@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( result, { - (user_id, device_id): { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 12345678000, - }, + (user_id, device_id): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=12345678000, + ), }, ) @@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( result, { - (user_id, device_id_1): { - "user_id": user_id, - "device_id": device_id_1, - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - (user_id, device_id_2): { - "user_id": user_id, - "device_id": device_id_2, - "ip": "ip_2", - "user_agent": "user_agent_3", - "last_seen": 12345688000 + LAST_SEEN_GRANULARITY, - }, + (user_id, device_id_1): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id_1, + ip="ip_1", + user_agent="user_agent_1", + last_seen=12345678000, + ), + (user_id, device_id_2): DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id_2, + ip="ip_2", + user_agent="user_agent_3", + last_seen=12345688000 + LAST_SEEN_GRANULARITY, + ), }, ) @@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": None, - "user_agent": None, - "last_seen": None, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip=None, + user_agent=None, + last_seen=None, + ), + r, ) # Register the background update to run again. @@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 0, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=0, + ), + r, ) def test_old_user_ips_pruned(self) -> None: @@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result2[(user_id, device_id)] - self.assertLessEqual( - { - "user_id": user_id, - "device_id": device_id, - "ip": "ip", - "user_agent": "user_agent", - "last_seen": 0, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=user_id, + device_id=device_id, + ip="ip", + user_agent="user_agent", + last_seen=0, + ), + r, ) def test_invalid_user_agents_are_ignored(self) -> None: @@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): self.store.get_last_client_ip_by_device(self.user_id, device_id) ) r = result[(self.user_id, device_id)] - self.assertLessEqual( - { - "user_id": self.user_id, - "device_id": device_id, - "ip": expected_ip, - "user_agent": "Mozzila pizza", - "last_seen": 123456100, - }.items(), - r.items(), + self.assertEqual( + DeviceLastConnectionInfo( + user_id=self.user_id, + device_id=device_id, + ip=expected_ip, + user_agent="Mozzila pizza", + last_seen=123456100, + ), + r, ) diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 4bcd17a6fc..ad88b24566 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py
@@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + +from synapse.notifier import Notifier +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from tests.unittest import HomeserverTestCase @@ -109,6 +113,77 @@ class RetryLimiterTestCase(HomeserverTestCase): new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) + def test_notifier_replication(self) -> None: + """Ensure the notifier/replication client is called only when expected.""" + store = self.hs.get_datastores().main + + notifier = mock.Mock(spec=Notifier) + replication_client = mock.Mock(spec=ReplicationCommandHandler) + + limiter = self.get_success( + get_retry_limiter( + "test_dest", + self.clock, + store, + notifier=notifier, + replication_client=replication_client, + ) + ) + + # The server is already up, nothing should occur. + self.pump(1) + with limiter: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + self.assertIsNone(new_timings) + notifier.notify_remote_server_up.assert_not_called() + replication_client.send_remote_server_up.assert_not_called() + + # Attempt again, but return an error. This will cause new retry timings, but + # should not trigger server up notifications. + self.pump(1) + try: + with limiter: + raise AssertionError("argh") + except AssertionError: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + # The exact retry timings are tested separately. + self.assertIsNotNone(new_timings) + notifier.notify_remote_server_up.assert_not_called() + replication_client.send_remote_server_up.assert_not_called() + + # A second failing request should be treated as the above. + self.pump(1) + try: + with limiter: + raise AssertionError("argh") + except AssertionError: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + # The exact retry timings are tested separately. + self.assertIsNotNone(new_timings) + notifier.notify_remote_server_up.assert_not_called() + replication_client.send_remote_server_up.assert_not_called() + + # A final successful attempt should generate a server up notification. + self.pump(1) + with limiter: + pass + self.pump() + + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + # The exact retry timings are tested separately. + self.assertIsNone(new_timings) + notifier.notify_remote_server_up.assert_called_once_with("test_dest") + replication_client.send_remote_server_up.assert_called_once_with("test_dest") + def test_max_retry_interval(self) -> None: """Test that `destination_max_retry_interval` setting works as expected""" store = self.hs.get_datastores().main