diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 500c9ccfbc..e0eda545b9 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
current_state = self.get_success(
self.store.get_events_as_list(
- (self.get_success(self.store.get_current_state_ids(room_id))).values()
+ (
+ self.get_success(self.store.get_partial_current_state_ids(room_id))
+ ).values()
)
)
@@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
self.get_success(d)
# sanity-check: the room should show that the new user is a member
- r = self.get_success(self.store.get_current_state_ids(room_id))
+ r = self.get_success(self.store.get_partial_current_state_ids(room_id))
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1d5b2492c0..1a36c25c41 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
)
- initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
auth_event_ids = [
initial_state_map[("m.room.create", "")],
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 057256cecd..14a0ee4922 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -146,7 +146,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
+ self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index a21cbe9fa8..98c1039d33 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
new_space_id = channel.json_body["replacement_room"]
- state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
+ state_ids = self.get_success(
+ self.store.get_partial_current_state_ids(new_space_id)
+ )
# Ensure the new room is still a space.
create_event = self.get_success(
@@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
new_room_id = channel.json_body["replacement_room"]
- state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
+ state_ids = self.get_success(
+ self.store.get_partial_current_state_ids(new_room_id)
+ )
# Ensure the new room is the same type as the old room.
create_event = self.get_success(
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 303e190b6c..cae14151c0 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -17,8 +17,12 @@ from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+ PartialCurrentStateTracker,
+ PartialStateEventsTracker,
+)
+from tests.test_utils import make_awaitable
from tests.unittest import TestCase
@@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d2)
+
+
+class PartialCurrentStateTrackerTestCase(TestCase):
+ def setUp(self) -> None:
+ self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+
+ self.tracker = PartialCurrentStateTracker(self.mock_store)
+
+ def test_does_not_block_for_full_state_rooms(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_blocks_for_partial_room_state(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d = ensureDeferred(self.tracker.await_full_state("room_id"))
+
+ # there should be no result yet
+ self.assertNoResult(d)
+
+ # notifying that the room has been de-partial-stated should unblock
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d)
+
+ def test_un_partial_state_race(self):
+ # We should correctly handle race between awaiting the state and us
+ # un-partialling the state
+ async def is_partial_state_room(events):
+ self.tracker.notify_un_partial_stated("room_id")
+ return True
+
+ self.mock_store.is_partial_state_room.side_effect = is_partial_state_room
+
+ self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+ def test_cancellation(self):
+ self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+ d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d1)
+
+ d2 = ensureDeferred(self.tracker.await_full_state("room_id"))
+ self.assertNoResult(d2)
+
+ d1.cancel()
+ self.assertFailure(d1, CancelledError)
+
+ # d2 should still be waiting!
+ self.assertNoResult(d2)
+
+ self.tracker.notify_un_partial_stated("room_id")
+ self.successResultOf(d2)
|