diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 128fc3e046..b8ab4ee54b 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -14,6 +14,8 @@
from typing import Any, List, Optional
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
@@ -21,6 +23,8 @@ from synapse.events import EventBase
from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
+ _MAX_STATE_UPDATES_PER_ROOM,
+ EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_update_function_huge_state_change(self) -> None:
+ @parameterized.expand(
+ [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
+ )
+ def test_update_function_huge_state_change(
+ self, num_state_changes: int, collapse_state_changes: bool
+ ) -> None:
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.
+
+ Args:
+ num_state_changes: The number of state changes to create.
+ collapse_state_changes: Whether the state changes are expected to be
+ collapsed or not.
"""
# we want to generate lots of state changes at a single stream ID.
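For readers unfamiliar with `parameterized.expand`: it generates one test method per argument tuple, so the test above now runs twice, once with `_STREAM_UPDATE_TARGET_ROW_COUNT` state changes (no collapsing expected) and once with `_MAX_STATE_UPDATES_PER_ROOM` (collapsing expected). A minimal, self-contained sketch of the pattern; the threshold and values below are made up for illustration:

```python
from unittest import TestCase

from parameterized import parameterized

# Made-up threshold for illustration only.
COLLAPSE_THRESHOLD = 500


class ExpandDemo(TestCase):
    # Each tuple becomes its own generated test case, run independently.
    @parameterized.expand([(100, False), (500, True)])
    def test_collapse(self, num_changes: int, expect_collapsed: bool) -> None:
        # Stand-in for "rows get collapsed once the limit is reached".
        self.assertEqual(num_changes >= COLLAPSE_THRESHOLD, expect_collapsed)
```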
@@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
events = [
self._inject_state_event(sender=OTHER_USER)
- for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
+ for _ in range(num_state_changes)
]
self.replicate()
@@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
- # first check the first two rows, which should be state1
-
+ # first check the first two rows, which should be the state1 event.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)
- # now the last two rows, which should be state2
+ # now the last two rows, which should be the state2 event.
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -231,34 +244,54 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id)
- # that should leave us with the rows for the PL event
- self.assertEqual(len(received_rows), len(events) + 2)
+        # Based on the number of state changes, the state rows are either
+        # replicated individually or collapsed into a single row.
+ if collapse_state_changes:
+            # that should leave us with the rows for the PL event; the state
+            # changes get collapsed into a single row.
+ self.assertEqual(len(received_rows), 2)
- stream_name, token, row = received_rows.pop(0)
- self.assertEqual("events", stream_name)
- self.assertIsInstance(row, EventsStreamRow)
- self.assertEqual(row.type, "ev")
- self.assertIsInstance(row.data, EventsStreamEventRow)
- self.assertEqual(row.data.event_id, pl_event.event_id)
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_event.event_id)
- # the state rows are unsorted
- state_rows: List[EventsStreamCurrentStateRow] = []
- for stream_name, _, row in received_rows:
+ stream_name, token, row = received_rows.pop(0)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state-all")
+ self.assertIsInstance(row.data, EventsStreamAllStateRow)
+ self.assertEqual(row.data.room_id, state2.room_id)
+
+ else:
+ # that should leave us with the rows for the PL event
+ self.assertEqual(len(received_rows), len(events) + 2)
+
+ stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
- self.assertEqual(row.type, "state")
- self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
- state_rows.append(row.data)
-
- state_rows.sort(key=lambda r: r.state_key)
-
- sr = state_rows.pop(0)
- self.assertEqual(sr.type, EventTypes.PowerLevels)
- self.assertEqual(sr.event_id, pl_event.event_id)
- for sr in state_rows:
- self.assertEqual(sr.type, "test_state_event")
- # "None" indicates the state has been deleted
- self.assertIsNone(sr.event_id)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_event.event_id)
+
+ # the state rows are unsorted
+ state_rows: List[EventsStreamCurrentStateRow] = []
+ for stream_name, _, row in received_rows:
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ state_rows.append(row.data)
+
+ state_rows.sort(key=lambda r: r.state_key)
+
+ sr = state_rows.pop(0)
+ self.assertEqual(sr.type, EventTypes.PowerLevels)
+ self.assertEqual(sr.event_id, pl_event.event_id)
+ for sr in state_rows:
+ self.assertEqual(sr.type, "test_state_event")
+ # "None" indicates the state has been deleted
+ self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 6b9692c486..0c054a598f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -24,7 +24,10 @@ import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
from synapse.server import HomeServer
-from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.databases.main.client_ips import (
+ LAST_SEEN_GRANULARITY,
+ DeviceLastConnectionInfo,
+)
from synapse.types import UserID
from synapse.util import Clock
@@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 12345678000,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=12345678000,
+ ),
+ r,
)
def test_insert_new_client_ip_none_device_id(self) -> None:
@@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
result,
{
- (user_id, device_id): {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 12345678000,
- },
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=12345678000,
+ ),
},
)
@@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
result,
{
- (user_id, device_id_1): {
- "user_id": user_id,
- "device_id": device_id_1,
- "ip": "ip_1",
- "user_agent": "user_agent_1",
- "last_seen": 12345678000,
- },
- (user_id, device_id_2): {
- "user_id": user_id,
- "device_id": device_id_2,
- "ip": "ip_2",
- "user_agent": "user_agent_3",
- "last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
- },
+ (user_id, device_id_1): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id_1,
+ ip="ip_1",
+ user_agent="user_agent_1",
+ last_seen=12345678000,
+ ),
+ (user_id, device_id_2): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id_2,
+ ip="ip_2",
+ user_agent="user_agent_3",
+ last_seen=12345688000 + LAST_SEEN_GRANULARITY,
+ ),
},
)
@@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": None,
- "user_agent": None,
- "last_seen": None,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=None,
+ user_agent=None,
+ last_seen=None,
+ ),
+ r,
)
# Register the background update to run again.
@@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 0,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=0,
+ ),
+ r,
)
def test_old_user_ips_pruned(self) -> None:
@@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
r = result2[(user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": "ip",
- "user_agent": "user_agent",
- "last_seen": 0,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip="ip",
+ user_agent="user_agent",
+ last_seen=0,
+ ),
+ r,
)
def test_invalid_user_agents_are_ignored(self) -> None:
@@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
self.store.get_last_client_ip_by_device(self.user_id, device_id)
)
r = result[(self.user_id, device_id)]
- self.assertLessEqual(
- {
- "user_id": self.user_id,
- "device_id": device_id,
- "ip": expected_ip,
- "user_agent": "Mozzila pizza",
- "last_seen": 123456100,
- }.items(),
- r.items(),
+ self.assertEqual(
+ DeviceLastConnectionInfo(
+ user_id=self.user_id,
+ device_id=device_id,
+ ip=expected_ip,
+ user_agent="Mozzila pizza",
+ last_seen=123456100,
+ ),
+ r,
)
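The dict-subset assertions above (`assertLessEqual` over `.items()`) are replaced with exact equality against `DeviceLastConnectionInfo`, which requires the type to have value-based equality. A minimal sketch of such a container, assuming a frozen dataclass shape; the real Synapse class may be defined differently, e.g. via attrs:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DeviceLastConnectionInfo:
    """Last-seen connection metadata for one (user, device) pair."""

    user_id: str
    device_id: str
    # These remain None until a client IP has been recorded for the device.
    ip: Optional[str]
    user_agent: Optional[str]
    last_seen: Optional[int]  # ms since the epoch


# Frozen dataclasses compare by field values, so assertEqual works directly:
a = DeviceLastConnectionInfo("@u:test", "dev", "ip", "user_agent", 12345678000)
b = DeviceLastConnectionInfo("@u:test", "dev", "ip", "user_agent", 12345678000)
assert a == b
```

Exact equality also catches unexpected extra fields, which the old subset check would silently have ignored.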
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 4bcd17a6fc..ad88b24566 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import mock
+
+from synapse.notifier import Notifier
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from tests.unittest import HomeserverTestCase
@@ -109,6 +113,77 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
+ def test_notifier_replication(self) -> None:
+ """Ensure the notifier/replication client is called only when expected."""
+ store = self.hs.get_datastores().main
+
+ notifier = mock.Mock(spec=Notifier)
+ replication_client = mock.Mock(spec=ReplicationCommandHandler)
+
+ limiter = self.get_success(
+ get_retry_limiter(
+ "test_dest",
+ self.clock,
+ store,
+ notifier=notifier,
+ replication_client=replication_client,
+ )
+ )
+
+ # The server is already up, nothing should occur.
+ self.pump(1)
+ with limiter:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ self.assertIsNone(new_timings)
+ notifier.notify_remote_server_up.assert_not_called()
+ replication_client.send_remote_server_up.assert_not_called()
+
+        # Attempt again, but raise an error. This will cause new retry timings, but
+ # should not trigger server up notifications.
+ self.pump(1)
+ try:
+ with limiter:
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ # The exact retry timings are tested separately.
+ self.assertIsNotNone(new_timings)
+ notifier.notify_remote_server_up.assert_not_called()
+ replication_client.send_remote_server_up.assert_not_called()
+
+        # A second failing request should be treated the same as the first failure.
+ self.pump(1)
+ try:
+ with limiter:
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ # The exact retry timings are tested separately.
+ self.assertIsNotNone(new_timings)
+ notifier.notify_remote_server_up.assert_not_called()
+ replication_client.send_remote_server_up.assert_not_called()
+
+ # A final successful attempt should generate a server up notification.
+ self.pump(1)
+ with limiter:
+ pass
+ self.pump()
+
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ # The exact retry timings are tested separately.
+ self.assertIsNone(new_timings)
+ notifier.notify_remote_server_up.assert_called_once_with("test_dest")
+ replication_client.send_remote_server_up.assert_called_once_with("test_dest")
+
def test_max_retry_interval(self) -> None:
"""Test that `destination_max_retry_interval` setting works as expected"""
store = self.hs.get_datastores().main
|