diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 151223219d..d1e679719b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
+from collections import namedtuple
import logging
@@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the itrable flag when caching
+ """
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@@ -98,6 +109,7 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ @cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -117,7 +129,7 @@ class StateStore(SQLBaseStore):
)
if not prev_group:
- return None, None
+ return _GetStateGroupDelta(None, None)
delta_ids = self._simple_select_list_txn(
txn,
@@ -128,10 +140,10 @@ class StateStore(SQLBaseStore):
retcols=("type", "state_key", "event_id",)
)
- return prev_group, {
+ return _GetStateGroupDelta(prev_group, {
(row["type"], row["state_key"]): row["event_id"]
for row in delta_ids
- }
+ })
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
|