diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 116 |
1 files changed, 87 insertions, 29 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a16afa8df5..d1e679719b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii from synapse.storage.engines import PostgresEngine from twisted.internet import defer +from collections import namedtuple import logging @@ -29,6 +30,16 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + class StateStore(SQLBaseStore): """ Keeps track of the state at a given event. @@ -98,6 +109,46 @@ class StateStore(SQLBaseStore): _get_current_state_ids_txn, ) + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return _GetStateGroupDelta(prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + }) + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, + ) + @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): if not event_ids: @@ -184,6 +235,19 @@ class StateStore(SQLBaseStore): # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if context.prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": context.prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (context.prev_group,) + ) + potential_hops = self._count_state_group_hops_txn( txn, context.prev_group ) @@ -227,6 +291,18 @@ class StateStore(SQLBaseStore): ], ) + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=context.state_group, + value=dict(context.current_state_ids), + full=True, + ) + self._simple_insert_many_txn( txn, table="event_to_state_groups", @@ -551,20 +627,22 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + for typ, state_key in types: + key = (typ, state_key) if state_key is None: type_to_key[typ] = None - missing_types.add((typ, state_key)) + missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict_ids: - missing_types.add((typ, state_key)) + if key not in state_dict_ids and key not in known_absent: + missing_types.add(key) sentinel = object() @@ -578,7 +656,7 @@ class StateStore(SQLBaseStore): return True return False - got_all = not (missing_types or types is None) + got_all = is_all or not missing_types return { k: v for k, v in state_dict_ids.iteritems() @@ -595,7 +673,7 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = self._state_group_cache.get(group) return state_dict_ids, is_all @@ -612,7 +690,7 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, _, got_all = self._get_some_state_from_cache( group, types ) results[group] = state_dict_ids @@ -641,19 +719,7 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. for group, group_state_dict in group_to_state_dict.iteritems(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = { - (intern_string(etype), intern_string(state_key)): None - for (etype, state_key) in types - } - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] + state_dict = results[group] state_dict.update( ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) @@ -665,17 +731,9 @@ class StateStore(SQLBaseStore): key=group, value=state_dict, full=(types is None), + known_absent=types, ) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.iteritems(): - results[group] = { - key: event_id - for key, event_id in state_dict.iteritems() - if event_id - } - defer.returnValue(results) def get_next_state_group(self): |