summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_federation.py6
-rw-r--r--tests/handlers/test_federation_event.py4
-rw-r--r--tests/handlers/test_typing.py2
-rw-r--r--tests/rest/client/test_upgrade_room.py8
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py59
5 files changed, 72 insertions, 7 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 500c9ccfbc..e0eda545b9 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         )
         current_state = self.get_success(
             self.store.get_events_as_list(
-                (self.get_success(self.store.get_current_state_ids(room_id))).values()
+                (
+                    self.get_success(self.store.get_partial_current_state_ids(room_id))
+                ).values()
             )
         )
 
@@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         self.get_success(d)
 
         # sanity-check: the room should show that the new user is a member
-        r = self.get_success(self.store.get_current_state_ids(room_id))
+        r = self.get_success(self.store.get_partial_current_state_ids(room_id))
         self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
 
         return join_event
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1d5b2492c0..1a36c25c41 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
         )
 
-        initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
+        initial_state_map = self.get_success(
+            main_store.get_partial_current_state_ids(room_id)
+        )
 
         auth_event_ids = [
             initial_state_map[("m.room.create", "")],
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 057256cecd..14a0ee4922 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -146,7 +146,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
+        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
 
         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = (
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index a21cbe9fa8..98c1039d33 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
 
         new_space_id = channel.json_body["replacement_room"]
 
-        state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
+        state_ids = self.get_success(
+            self.store.get_partial_current_state_ids(new_space_id)
+        )
 
         # Ensure the new room is still a space.
         create_event = self.get_success(
@@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
 
         new_room_id = channel.json_body["replacement_room"]
 
-        state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
+        state_ids = self.get_success(
+            self.store.get_partial_current_state_ids(new_room_id)
+        )
 
         # Ensure the new room is the same type as the old room.
         create_event = self.get_success(
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 303e190b6c..cae14151c0 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -17,8 +17,12 @@ from unittest import mock
 
 from twisted.internet.defer import CancelledError, ensureDeferred
 
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.storage.util.partial_state_events_tracker import (
+    PartialCurrentStateTracker,
+    PartialStateEventsTracker,
+)
 
+from tests.test_utils import make_awaitable
 from tests.unittest import TestCase
 
 
@@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase):
 
         self.tracker.notify_un_partial_stated("event1")
         self.successResultOf(d2)
+
+
+class PartialCurrentStateTrackerTestCase(TestCase):
+    def setUp(self) -> None:
+        self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+
+        self.tracker = PartialCurrentStateTracker(self.mock_store)
+
+    def test_does_not_block_for_full_state_rooms(self):
+        self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+
+        self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+    def test_blocks_for_partial_room_state(self):
+        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+        d = ensureDeferred(self.tracker.await_full_state("room_id"))
+
+        # there should be no result yet
+        self.assertNoResult(d)
+
+        # notifying that the room has been de-partial-stated should unblock
+        self.tracker.notify_un_partial_stated("room_id")
+        self.successResultOf(d)
+
+    def test_un_partial_state_race(self):
+        # We should correctly handle race between awaiting the state and us
+        # un-partialling the state
+        async def is_partial_state_room(events):
+            self.tracker.notify_un_partial_stated("room_id")
+            return True
+
+        self.mock_store.is_partial_state_room.side_effect = is_partial_state_room
+
+        self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
+
+    def test_cancellation(self):
+        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+
+        d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
+        self.assertNoResult(d1)
+
+        d2 = ensureDeferred(self.tracker.await_full_state("room_id"))
+        self.assertNoResult(d2)
+
+        d1.cancel()
+        self.assertFailure(d1, CancelledError)
+
+        # d2 should still be waiting!
+        self.assertNoResult(d2)
+
+        self.tracker.notify_un_partial_stated("room_id")
+        self.successResultOf(d2)