summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11608.misc1
-rw-r--r--synapse/storage/databases/state/store.py16
-rw-r--r--tests/storage/databases/test_state_store.py69
3 files changed, 86 insertions, 0 deletions
diff --git a/changelog.d/11608.misc b/changelog.d/11608.misc
new file mode 100644
index 0000000000..3af049b969
--- /dev/null
+++ b/changelog.d/11608.misc
@@ -0,0 +1 @@
+Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 3af69a2076..b8016f679a 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 MAX_STATE_DELTA_HOPS = 100
+MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -258,6 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Attempts to gather in-flight requests and re-use them to retrieve state
         for the given state group, filtered with the given state filter.
 
+        If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
+        and there *still* isn't enough information to complete the request by solely
+        reusing others, a full state filter will be requested to ensure that subsequent
+        requests can reuse this request.
+
         Used as part of _get_state_for_group_using_inflight_cache.
 
         Returns:
@@ -288,6 +294,16 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 # to cover our StateFilter and give us the state we need.
                 break
 
+        if (
+            state_filter_left_over != StateFilter.none()
+            and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
+        ):
+            # There are too many requests for this group.
+            # To prevent even more from building up, we request the whole
+            # state filter to guarantee that we can be reused by any subsequent
+            # requests for this state group.
+            return (), StateFilter.all()
+
         return reusable_requests, state_filter_left_over
 
     async def _get_state_for_group_fire_request(
diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py
index 3a4a4a3a29..076b660809 100644
--- a/tests/storage/databases/test_state_store.py
+++ b/tests/storage/databases/test_state_store.py
@@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
+from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util import Clock
@@ -281,3 +282,71 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase):
 
         self.assertEqual(self.get_success(req1), FAKE_STATE)
         self.assertEqual(self.get_success(req2), FAKE_STATE)
+
+    def test_inflight_requests_capped(self) -> None:
+        """
+        Tests that the number of in-flight requests is capped to 5.
+
+        - requests several pieces of state separately
+          (5 to hit the limit, 1 to 'shunt out', another that comes after the
+          group has been 'shunted out')
+        - checks to see that the torrent of requests is shunted out by
+          rewriting one of the filters as the 'all' state filter
+        - requests after that one do not cause any additional queries
+        """
+        # 5 at the time of writing.
+        CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP
+
+        reqs = []
+
+        # Request 7 different keys (1 to 7) of the `some.state` type.
+        for req_id in range(CAP_COUNT + 2):
+            reqs.append(
+                ensureDeferred(
+                    self.state_datastore._get_state_for_group_using_inflight_cache(
+                        42,
+                        StateFilter.freeze(
+                            {"some.state": {str(req_id + 1)}}, include_others=False
+                        ),
+                    )
+                )
+            )
+        self.pump(by=0.1)
+
+        # There should only be 6 calls to the database, not 7.
+        self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1)
+
+        # Assert that the first 5 are exact requests for the individual pieces
+        # wanted
+        for req_id in range(CAP_COUNT):
+            groups, sf, d = self.get_state_group_calls[req_id]
+            self.assertEqual(
+                sf,
+                StateFilter.freeze(
+                    {"some.state": {str(req_id + 1)}}, include_others=False
+                ),
+            )
+
+        # The 6th request should be the 'all' state filter
+        groups, sf, d = self.get_state_group_calls[CAP_COUNT]
+        self.assertEqual(sf, StateFilter.all())
+
+        # Complete the queries and check which requests complete as a result
+        for req_id in range(CAP_COUNT):
+            # This request should not have been completed yet
+            self.assertFalse(reqs[req_id].called)
+
+            groups, sf, d = self.get_state_group_calls[req_id]
+            self._complete_request_fake(groups, sf, d)
+
+            # This should have only completed this one request
+            self.assertTrue(reqs[req_id].called)
+
+        # Now complete the final query; the last 2 requests should complete
+        # as a result
+        self.assertFalse(reqs[CAP_COUNT].called)
+        self.assertFalse(reqs[CAP_COUNT + 1].called)
+        groups, sf, d = self.get_state_group_calls[CAP_COUNT]
+        self._complete_request_fake(groups, sf, d)
+        self.assertTrue(reqs[CAP_COUNT].called)
+        self.assertTrue(reqs[CAP_COUNT + 1].called)