diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 327 |
1 files changed, 241 insertions, 86 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7ce51b9bdc..185f88fd7c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer @@ -44,59 +47,25 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res - - states = yield self.runInteraction( - "get_state_groups", - f, + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() - ], - consumeErrors=True, - ) + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) - defer.returnValue(dict(state_list)) - - def _fetch_events_for_group(self, key, events): - return self._get_events( - events, get_prev_content=False - ).addCallback( - lambda evs: (key, evs) - ) + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state.items() + }) def _store_state_groups_txn(self, txn, event, context): return self._store_mult_state_groups_txn(txn, [(event, context)]) @@ -204,64 +173,250 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + def _get_state_groups_from_groups(self, groups_and_types): + """Returns dictionary state_group -> state event ids + + Args: + groups_and_types (list): list of 2-tuple (`group`, `types`) + """ + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [r[0] for r in txn.fetchall()] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @cached(num_args=2, lru=True, max_entries=100000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", + num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): + """Returns mapping event_id -> state_group + """ def f(txn): - groups = set() - event_to_group = {} + results = {} for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( + results[event_id] = self._simple_select_one_onecol_txn( txn, table="event_to_state_groups", - keyvalues={"event_id": event_id}, + keyvalues={ + "event_id": event_id, + }, retcol="state_group", allow_none=True, ) - if group: - event_to_group[event_id] = group - groups.add(group) - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", + return results + + return self.runInteraction("_get_state_group_for_events", f) + + def _get_some_state_from_cache(self, group, types): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). + `missing_types` is the list of types that aren't in the cache for that + group. `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + + Args: + group: The state group to lookup + types (list): List of 2-tuples of the form (`type`, `state_key`), + where a `state_key` of `None` matches all state_keys for the + `type`. + """ + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + sentinel = object() + + def include(typ, state_key): + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + got_all = not (missing_types or types is None) + + return { + k: v for k, v in state_dict.items() + if include(k[0], k[1]) + }, missing_types, got_all + + def _get_all_state_from_cache(self, group): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool + indicating if we successfully retrieved all requests state from the + cache, if False we need to query the DB for the missing state. + + Args: + group: The state group to lookup + """ + is_all, state_dict = self._state_group_cache.get(group) + return state_dict, is_all + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + """Given list of groups returns dict of group -> list of state events + with matching types. `types` is a list of `(type, state_key)`, where + a `state_key` of None matches all state_keys. If `types` is None then + all events are returned. + """ + results = {} + missing_groups_and_types = [] + if types is not None: + for group in set(groups): + state_dict, missing_types, got_all = self._get_some_state_from_cache( + group, types + ) + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append((group, missing_types)) + else: + for group in set(groups): + state_dict, got_all = self._get_all_state_from_cache( + group ) + results[group] = state_dict - group_to_state_ids[group] = state_ids + if not got_all: + missing_groups_and_types.append((group, None)) - return event_to_group, group_to_state_ids + if not missing_groups_and_types: + defer.returnValue({ + group: { + type_tuple: event + for type_tuple, event in state.items() + if event + } + for group, state in results.items() + }) - res = yield self.runInteraction( - "annotate_events_with_state_groups", - f, - ) + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence - event_to_group, group_to_state_ids = res + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types + ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() - ], - consumeErrors=True, + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False ) - state_dict = { - group: { - (ev.type, ev.state_key): ev - for ev in state - } - for group, state in state_list - } + state_events = {e.event_id: e for e in state_events} + + # Now we want to update the cache with all the things we fetched + # from the database. + for group, state_ids in group_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = {key: None for key in types} + state_dict.update(results[group]) + else: + state_dict = results[group] + + for event_id in state_ids: + state_event = state_events[event_id] + state_dict[(state_event.type, state_event.state_key)] = state_event + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + results[group].update({ + key: value for key, value in state_dict.items() if value + }) - defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) - for event in event_ids - ]) + defer.returnValue(results) def _make_group_id(clock): |