diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85acf2ad1e..5673e4aa96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
+from collections import namedtuple
import logging
@@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the itrable flag when caching
+ """
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@@ -98,6 +109,46 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+ def _get_state_group_delta_txn(txn):
+ prev_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self._simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcols=("type", "state_key", "event_id",)
+ )
+
+ return _GetStateGroupDelta(prev_group, {
+ (row["type"], row["state_key"]): row["event_id"]
+ for row in delta_ids
+ })
+ return self.runInteraction(
+ "get_state_group_delta",
+ _get_state_group_delta_txn,
+ )
+
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@@ -184,6 +235,19 @@ class StateStore(SQLBaseStore):
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
+ is_in_db = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": context.prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (context.prev_group,)
+ )
+
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
@@ -251,6 +315,12 @@ class StateStore(SQLBaseStore):
],
)
+ for event_id, state_group_id in state_groups.iteritems():
+ txn.call_after(
+ self._get_state_group_for_event.prefill,
+ (event_id,), state_group_id
+ )
+
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
@@ -520,8 +590,8 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, max_entries=50000)
- def _get_state_group_for_event(self, room_id, event_id):
+ @cached(max_entries=50000)
+ def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
@@ -563,20 +633,22 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
+
for typ, state_key in types:
+ key = (typ, state_key)
if state_key is None:
type_to_key[typ] = None
- missing_types.add((typ, state_key))
+ missing_types.add(key)
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
- if (typ, state_key) not in state_dict_ids:
- missing_types.add((typ, state_key))
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types.add(key)
sentinel = object()
@@ -590,7 +662,7 @@ class StateStore(SQLBaseStore):
return True
return False
- got_all = not (missing_types or types is None)
+ got_all = is_all or not missing_types
return {
k: v for k, v in state_dict_ids.iteritems()
@@ -607,7 +679,7 @@ class StateStore(SQLBaseStore):
Args:
group: The state group to lookup
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = self._state_group_cache.get(group)
return state_dict_ids, is_all
@@ -624,7 +696,7 @@ class StateStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+ state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict_ids
@@ -653,19 +725,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in group_to_state_dict.iteritems():
- if types:
- # We delibrately put key -> None mappings into the cache to
- # cache absence of the key, on the assumption that if we've
- # explicitly asked for some types then we will probably ask
- # for them again.
- state_dict = {
- (intern_string(etype), intern_string(state_key)): None
- for (etype, state_key) in types
- }
- state_dict.update(results[group])
- results[group] = state_dict
- else:
- state_dict = results[group]
+ state_dict = results[group]
state_dict.update(
((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
@@ -677,17 +737,9 @@ class StateStore(SQLBaseStore):
key=group,
value=state_dict,
full=(types is None),
+ known_absent=types,
)
- # Remove all the entries with None values. The None values were just
- # used for bookkeeping in the cache.
- for group, state_dict in results.iteritems():
- results[group] = {
- key: event_id
- for key, event_id in state_dict.iteritems()
- if event_id
- }
-
defer.returnValue(results)
def get_next_state_group(self):
|