diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index b8016f679a..dadf3d1e3a 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -25,6 +25,7 @@ from typing import (
)
import attr
+from sortedcontainers import SortedDict
from twisted.internet import defer
@@ -72,6 +73,24 @@ class _GetStateGroupDelta:
return len(self.delta_ids) if self.delta_ids else 0
+def state_filter_rough_priority_comparator(
+ state_filter: StateFilter,
+) -> Tuple[int, int]:
+ """
+ Returns a comparable value that roughly indicates the relative size of this
+ state filter compared to others.
+ 'Larger' state filters should sort first when using ascending order, so
+ this is essentially the opposite of 'size'.
+ It should be treated as a rough guide only and should not be interpreted to
+ have any particular meaning. The representation may also change
+
+ The current implementation returns a tuple of the form:
+ * -1 for include_others, 0 otherwise
+ * -(number of entries in state_filter.types)
+ """
+ return -int(state_filter.include_others), -len(state_filter.types)
+
+
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""
@@ -127,7 +146,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
- int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
+ int, SortedDict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}
def get_max_state_group_txn(txn: Cursor) -> int:
@@ -279,7 +298,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# The list of ongoing requests which will help narrow the current request.
reusable_requests = []
- for (request_state_filter, request_deferred) in inflight_requests.items():
+
+ # Iterate over existing requests in roughly biggest-first order.
+ for request_state_filter in inflight_requests:
+ request_deferred = inflight_requests[request_state_filter]
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
@@ -358,7 +380,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
# Insert the ObservableDeferred into the cache
- group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
+ group_request_dict = self._state_group_inflight_requests.setdefault(
+ group, SortedDict(state_filter_rough_priority_comparator)
+ )
group_request_dict[db_state_filter] = observable_deferred
return await make_deferred_yieldable(observable_deferred.observe())
|