diff options
-rw-r--r-- | changelog.d/11608.misc | 1 | ||||
-rw-r--r-- | synapse/storage/databases/state/store.py | 16 | ||||
-rw-r--r-- | tests/storage/databases/test_state_store.py | 69 |
3 files changed, 86 insertions, 0 deletions
diff --git a/changelog.d/11608.misc b/changelog.d/11608.misc new file mode 100644 index 0000000000..3af049b969 --- /dev/null +++ b/changelog.d/11608.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 3af69a2076..b8016f679a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +MAX_INFLIGHT_REQUESTS_PER_GROUP = 5 @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -258,6 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Attempts to gather in-flight requests and re-use them to retrieve state for the given state group, filtered with the given state filter. + If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests, + and there *still* isn't enough information to complete the request by solely + reusing others, a full state filter will be requested to ensure that subsequent + requests can reuse this request. + Used as part of _get_state_for_group_using_inflight_cache. Returns: @@ -288,6 +294,16 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # to cover our StateFilter and give us the state we need. break + if ( + state_filter_left_over != StateFilter.none() + and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP + ): + # There are too many requests for this group. + # To prevent even more from building up, we request the whole + # state filter to guarantee that we can be reused by any subsequent + # requests for this state group. + return (), StateFilter.all() + return reusable_requests, state_filter_left_over async def _get_state_for_group_fire_request( diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py index 3a4a4a3a29..076b660809 100644 --- a/tests/storage/databases/test_state_store.py +++ b/tests/storage/databases/test_state_store.py @@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes +from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util import Clock @@ -281,3 +282,71 @@ class StateGroupInflightCachingTestCase(HomeserverTestCase): self.assertEqual(self.get_success(req1), FAKE_STATE) self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_inflight_requests_capped(self) -> None: + """ + Tests that the number of in-flight requests is capped to 5. + + - requests several pieces of state separately + (5 to hit the limit, 1 to 'shunt out', another that comes after the + group has been 'shunted out') + - checks to see that the torrent of requests is shunted out by + rewriting one of the filters as the 'all' state filter + - requests after that one do not cause any additional queries + """ + # 5 at the time of writing. + CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP + + reqs = [] + + # Request 7 different keys (1 to 7) of the `some.state` type. + for req_id in range(CAP_COUNT + 2): + reqs.append( + ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + StateFilter.freeze( + {"some.state": {str(req_id + 1)}}, include_others=False + ), + ) + ) + ) + self.pump(by=0.1) + + # There should only be 6 calls to the database, not 7. + self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1) + + # Assert that the first 5 are exact requests for the individual pieces + # wanted + for req_id in range(CAP_COUNT): + groups, sf, d = self.get_state_group_calls[req_id] + self.assertEqual( + sf, + StateFilter.freeze( + {"some.state": {str(req_id + 1)}}, include_others=False + ), + ) + + # The 6th request should be the 'all' state filter + groups, sf, d = self.get_state_group_calls[CAP_COUNT] + self.assertEqual(sf, StateFilter.all()) + + # Complete the queries and check which requests complete as a result + for req_id in range(CAP_COUNT): + # This request should not have been completed yet + self.assertFalse(reqs[req_id].called) + + groups, sf, d = self.get_state_group_calls[req_id] + self._complete_request_fake(groups, sf, d) + + # This should have only completed this one request + self.assertTrue(reqs[req_id].called) + + # Now complete the final query; the last 2 requests should complete + # as a result + self.assertFalse(reqs[CAP_COUNT].called) + self.assertFalse(reqs[CAP_COUNT + 1].called) + groups, sf, d = self.get_state_group_calls[CAP_COUNT] + self._complete_request_fake(groups, sf, d) + self.assertTrue(reqs[CAP_COUNT].called) + self.assertTrue(reqs[CAP_COUNT + 1].called) |