summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12033.misc1
-rw-r--r--tests/storage/databases/test_state_store.py192
2 files changed, 172 insertions, 21 deletions
diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc
new file mode 100644
index 0000000000..3af049b969
--- /dev/null
+++ b/changelog.d/12033.misc
@@ -0,0 +1 @@
+Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py
index cf126ee62d..3a4a4a3a29 100644
--- a/tests/storage/databases/test_state_store.py
+++ b/tests/storage/databases/test_state_store.py
@@ -18,8 +18,9 @@ from unittest.mock import patch
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.api.constants import EventTypes
 from synapse.storage.state import StateFilter
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import StateMap
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
@@ -27,6 +28,21 @@ from tests.unittest import HomeserverTestCase
 if typing.TYPE_CHECKING:
     from synapse.server import HomeServer
 
+# StateFilter for ALL non-m.room.member state events
+ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze(
+    types={EventTypes.Member: set()},
+    include_others=True,
+)
+
+FAKE_STATE = {
+    (EventTypes.Member, "@alice:test"): "join",
+    (EventTypes.Member, "@bob:test"): "leave",
+    (EventTypes.Member, "@charlie:test"): "invite",
+    ("test.type", "a"): "AAA",
+    ("test.type", "b"): "BBB",
+    ("other.event.type", "state.key"): "123",
+}
+
 
 class StateGroupInflightCachingTestCase(HomeserverTestCase):
     def prepare(
@@ -65,24 +81,8 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase):
         Assemble a fake database response and complete the database request.
         """
 
-        result: Dict[int, StateMap[str]] = {}
-
-        for group in groups:
-            group_result: MutableStateMap[str] = {}
-            result[group] = group_result
-
-            for state_type, state_keys in state_filter.types.items():
-                if state_keys is None:
-                    group_result[(state_type, "a")] = "xyz"
-                    group_result[(state_type, "b")] = "xyz"
-                else:
-                    for state_key in state_keys:
-                        group_result[(state_type, state_key)] = "abc"
-
-            if state_filter.include_others:
-                group_result[("other.event.type", "state.key")] = "123"
-
-        d.callback(result)
+        # Return a filtered copy of the fake state
+        d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups})
 
     def test_duplicate_requests_deduplicated(self) -> None:
         """
@@ -125,9 +125,159 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase):
         # Now we can complete the request
         self._complete_request_fake(groups, sf, d)
 
+        self.assertEqual(self.get_success(req1), FAKE_STATE)
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_smaller_request_deduplicated(self) -> None:
+        """
+        Tests that duplicate requests for state are deduplicated.
+
+        This test:
+        - requests some state (state group 42, 'all' state filter)
+        - requests a subset of that state, before the first request finishes
+        - checks to see that only one database query was made
+        - completes the database query
+        - checks that both requests see the correct retrieved state
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", None),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", "b"),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # No more calls should have gone to the database, because the second
+        # request was already in the in-flight cache!
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        # The state filter is expanded internally for increased cache hit rate,
+        # so we the database sees a wider state filter than requested.
+        self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+
+        # Now we can complete the request
+        self._complete_request_fake(groups, sf, d)
+
+        self.assertEqual(
+            self.get_success(req1),
+            {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
+        )
+        self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"})
+
+    def test_partially_overlapping_request_deduplicated(self) -> None:
+        """
+        Tests that partially-overlapping requests are partially deduplicated.
+
+        This test:
+        - requests a single type of wildcard state
+          (This is internally expanded to be all non-member state)
+        - requests the entire state in parallel
+        - checks to see that two database queries were made, but that the second
+          one is only for member state.
+        - completes the database queries
+        - checks that both requests have the correct result.
+        """
+
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.from_types((("test.type", None),))
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # Because it only partially overlaps, this also went to the database
+        self.assertEqual(len(self.get_state_group_calls), 2)
+        self.assertFalse(req1.called)
+        self.assertFalse(req2.called)
+
+        # First request:
+        groups, sf, d = self.get_state_group_calls[0]
+        self.assertEqual(groups, (42,))
+        # The state filter is expanded internally for increased cache hit rate,
+        # so we the database sees a wider state filter than requested.
+        self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER)
+        self._complete_request_fake(groups, sf, d)
+
+        # Second request:
+        groups, sf, d = self.get_state_group_calls[1]
+        self.assertEqual(groups, (42,))
+        # The state filter is narrowed to only request membership state, because
+        # the remainder of the state is already being queried in the first request!
         self.assertEqual(
-            self.get_success(req1), {("other.event.type", "state.key"): "123"}
+            sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False)
         )
+        self._complete_request_fake(groups, sf, d)
+
+        # Check the results are correct
         self.assertEqual(
-            self.get_success(req2), {("other.event.type", "state.key"): "123"}
+            self.get_success(req1),
+            {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"},
         )
+        self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_in_flight_requests_stop_being_in_flight(self) -> None:
+        """
+        Tests that in-flight request deduplication doesn't somehow 'hold on'
+        to completed requests: once they're done, they're taken out of the
+        in-flight cache.
+        """
+        req1 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # This should have gone to the database
+        self.assertEqual(len(self.get_state_group_calls), 1)
+        self.assertFalse(req1.called)
+
+        # Complete the request right away.
+        self._complete_request_fake(*self.get_state_group_calls[0])
+        self.assertTrue(req1.called)
+
+        # Send off another request
+        req2 = ensureDeferred(
+            self.state_datastore._get_state_for_group_using_inflight_cache(
+                42, StateFilter.all()
+            )
+        )
+        self.pump(by=0.1)
+
+        # It should have gone to the database again, because the previous request
+        # isn't in-flight and therefore isn't available for deduplication.
+        self.assertEqual(len(self.get_state_group_calls), 2)
+        self.assertFalse(req2.called)
+
+        # Complete the request right away.
+        self._complete_request_fake(*self.get_state_group_calls[1])
+        self.assertTrue(req2.called)
+        groups, sf, d = self.get_state_group_calls[0]
+
+        self.assertEqual(self.get_success(req1), FAKE_STATE)
+        self.assertEqual(self.get_success(req2), FAKE_STATE)