diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 128 |
1 files changed, 61 insertions, 67 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 8ed8a21b0a..02cefdff26 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,9 +14,8 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import ( - cached, cachedInlineCallbacks, cachedList -) +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches import intern_string from twisted.internet import defer @@ -155,8 +154,14 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cachedInlineCallbacks(num_args=3) + @defer.inlineCallbacks def get_current_state_for_key(self, room_id, event_type, state_key): + event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) + events = yield self._get_events(event_ids, get_prev_content=False) + defer.returnValue(events) + + @cached(num_args=3) + def _get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( "SELECT event_id FROM current_state_events" @@ -167,12 +172,10 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) results = txn.fetchall() return [r[0] for r in results] - event_ids = yield self.runInteraction("get_current_state_for_key", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) + return self.runInteraction("get_current_state_for_key", f) def _get_state_groups_from_groups(self, groups, types): - """Returns dictionary state_group -> state event ids + """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ def f(txn, groups): if types is not None: @@ -183,7 +186,8 @@ class StateStore(SQLBaseStore): where_clause = "" sql = ( - "SELECT state_group, event_id FROM state_groups_state WHERE" + "SELECT state_group, event_id, type, state_key" + " FROM state_groups_state WHERE" " state_group IN (%s) %s" % ( ",".join("?" for _ in groups), where_clause, @@ -199,7 +203,8 @@ class StateStore(SQLBaseStore): results = {} for row in rows: - results.setdefault(row["state_group"], []).append(row["event_id"]) + key = (row["type"], row["state_key"]) + results.setdefault(row["state_group"], {})[key] = row["event_id"] return results chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] @@ -296,7 +301,7 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict = self._state_group_cache.get(group) + is_all, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() @@ -308,7 +313,7 @@ class StateStore(SQLBaseStore): if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict: + if (typ, state_key) not in state_dict_ids: missing_types.add((typ, state_key)) sentinel = object() @@ -326,7 +331,7 @@ class StateStore(SQLBaseStore): got_all = not (missing_types or types is None) return { - k: v for k, v in state_dict.items() + k: v for k, v in state_dict_ids.items() if include(k[0], k[1]) }, missing_types, got_all @@ -340,8 +345,9 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict = self._state_group_cache.get(group) - return state_dict, is_all + is_all, state_dict_ids = self._state_group_cache.get(group) + + return state_dict_ids, is_all @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): @@ -354,84 +360,72 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( group, types ) - results[group] = state_dict + results[group] = state_dict_ids if not got_all: missing_groups.append(group) else: for group in set(groups): - state_dict, got_all = self._get_all_state_from_cache( + state_dict_ids, got_all = self._get_all_state_from_cache( group ) - results[group] = state_dict + + results[group] = state_dict_ids if not got_all: missing_groups.append(group) - if not missing_groups: - defer.returnValue({ - group: { - type_tuple: event - for type_tuple, event in state.items() - if event - } - for group, state in results.items() - }) + if missing_groups: + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + group_to_state_dict = yield self._get_state_groups_from_groups( + missing_groups, types + ) - group_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types - ) + # Now we want to update the cache with all the things we fetched + # from the database. + for group, group_state_dict in group_to_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = { + (intern_string(etype), intern_string(state_key)): None + for (etype, state_key) in types + } + state_dict.update(results[group]) + results[group] = state_dict + else: + state_dict = results[group] + + state_dict.update(group_state_dict) + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) state_events = yield self._get_events( - [e_id for l in group_state_dict.values() for e_id in l], + [ev_id for sd in results.values() for ev_id in sd.values()], get_prev_content=False ) state_events = {e.event_id: e for e in state_events} - # Now we want to update the cache with all the things we fetched - # from the database. - for group, state_ids in group_state_dict.items(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = {key: None for key in types} - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] - - for event_id in state_ids: - try: - state_event = state_events[event_id] - state_dict[(state_event.type, state_event.state_key)] = state_event - except KeyError: - # Hmm. So we do don't have that state event? Interesting. - logger.warn( - "Can't find state event %r for state group %r", - event_id, group, - ) - - self._state_group_cache.update( - cache_seq_num, - key=group, - value=state_dict, - full=(types is None), - ) - # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. for group, state_dict in results.items(): results[group] = { - key: event for key, event in state_dict.items() if event + key: state_events[event_id] + for key, event_id in state_dict.items() + if event_id and event_id in state_events } defer.returnValue(results) |