diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 114 |
1 files changed, 72 insertions, 42 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ffa4246031..89a05c4618 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,17 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple import logging +from collections import namedtuple + +from six import iteritems, itervalues +from six.moves import range from twisted.internet import defer from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine -from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR +from synapse.util.caches import get_cache_factor_for, intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii + from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -54,7 +58,7 @@ class StateGroupWorkerStore(SQLBaseStore): super(StateGroupWorkerStore, self).__init__(db_conn, hs) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") ) @cached(max_entries=100000, iterable=True) @@ -134,7 +138,7 @@ class StateGroupWorkerStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups) defer.returnValue(group_to_state) @@ -166,18 +170,18 @@ class StateGroupWorkerStore(SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in group_to_ids.itervalues() - for ev_id in group_ids.itervalues() + ev_id for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) ], get_prev_content=False ) defer.returnValue({ group: [ - state_event_map[v] for v in event_id_map.itervalues() + state_event_map[v] for v in itervalues(event_id_map) if v in state_event_map ] - for group, event_id_map in group_to_ids.iteritems() + for group, event_id_map in iteritems(group_to_ids) }) @defer.inlineCallbacks @@ -186,7 +190,7 @@ class StateGroupWorkerStore(SQLBaseStore): """ results = {} - chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] + chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", @@ -269,7 +273,7 @@ class StateGroupWorkerStore(SQLBaseStore): for typ in types: if typ[1] is None: where_clauses.append("(type = ?)") - where_args.extend(typ[0]) + where_args.append(typ[0]) wildcard_types = True else: where_clauses.append("(type = ? AND state_key = ?)") @@ -347,21 +351,21 @@ class StateGroupWorkerStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, types) state_event_map = yield self.get_events( - [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], get_prev_content=False ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in group_to_state[group].iteritems() + for k, v in iteritems(group_to_state[group]) if v in state_event_map } - for event_id, group in event_to_groups.iteritems() + for event_id, group in iteritems(event_to_groups) } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -384,12 +388,12 @@ class StateGroupWorkerStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.itervalues()) + groups = set(itervalues(event_to_groups)) group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_to_groups.iteritems() + for event_id, group in iteritems(event_to_groups) } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -503,7 +507,7 @@ class StateGroupWorkerStore(SQLBaseStore): got_all = is_all or not missing_types return { - k: v for k, v in state_dict_ids.iteritems() + k: v for k, v in iteritems(state_dict_ids) if include(k[0], k[1]) }, missing_types, got_all @@ -523,10 +527,23 @@ class StateGroupWorkerStore(SQLBaseStore): @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): - """Given list of groups returns dict of group -> list of state events - with matching types. `types` is a list of `(type, state_key)`, where - a `state_key` of None matches all state_keys. If `types` is None then - all events are returned. + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. """ if types: types = frozenset(types) @@ -535,7 +552,7 @@ class StateGroupWorkerStore(SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, _, got_all = self._get_some_state_from_cache( - group, types + group, types, ) results[group] = state_dict_ids @@ -556,26 +573,40 @@ class StateGroupWorkerStore(SQLBaseStore): # Okay, so we have some missing_types, lets fetch them. cache_seq_num = self._state_group_cache.sequence + # the DictionaryCache knows if it has *all* the state, but + # does not know if it has all of the keys of a particular type, + # which makes wildcard lookups expensive unless we have a complete + # cache. Hence, if we are doing a wildcard lookup, populate the + # cache fully so that we can do an efficient lookup next time. + + if types and any(k is None for (t, k) in types): + types_to_fetch = None + else: + types_to_fetch = types + group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types + missing_groups, types_to_fetch, ) - # Now we want to update the cache with all the things we fetched - # from the database. - for group, group_state_dict in group_to_state_dict.iteritems(): + for group, group_state_dict in iteritems(group_to_state_dict): state_dict = results[group] - state_dict.update( - ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) - for k, v in group_state_dict.iteritems() - ) - + # update the result, filtering by `types`. + if types: + for k, v in iteritems(group_state_dict): + (typ, _) = k + if k in types or (typ, None) in types: + state_dict[k] = v + else: + state_dict.update(group_state_dict) + + # update the cache with all the things we fetched from the + # database. self._state_group_cache.update( cache_seq_num, key=group, - value=state_dict, - full=(types is None), - known_absent=types, + value=group_state_dict, + fetched_keys=types_to_fetch, ) defer.returnValue(results) @@ -654,7 +685,7 @@ class StateGroupWorkerStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_ids.iteritems() + for key, state_id in iteritems(delta_ids) ], ) else: @@ -669,7 +700,7 @@ class StateGroupWorkerStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in current_state_ids.iteritems() + for key, state_id in iteritems(current_state_ids) ], ) @@ -682,7 +713,6 @@ class StateGroupWorkerStore(SQLBaseStore): self._state_group_cache.sequence, key=state_group, value=dict(current_state_ids), - full=True, ) return state_group @@ -794,11 +824,11 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): "state_group": state_group_id, "event_id": event_id, } - for event_id, state_group_id in state_groups.iteritems() + for event_id, state_group_id in iteritems(state_groups) ], ) - for event_id, state_group_id in state_groups.iteritems(): + for event_id, state_group_id in iteritems(state_groups): txn.call_after( self._get_state_group_for_event.prefill, (event_id,), state_group_id @@ -826,7 +856,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): def reindex_txn(txn): new_last_state_group = last_state_group - for count in xrange(batch_size): + for count in range(batch_size): txn.execute( "SELECT id, room_id FROM state_groups" " WHERE ? < id AND id <= ?" @@ -884,7 +914,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): # of keys delta_state = { - key: value for key, value in curr_state.iteritems() + key: value for key, value in iteritems(curr_state) if prev_state.get(key, None) != value } @@ -924,7 +954,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_state.iteritems() + for key, state_id in iteritems(delta_state) ], ) |