diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 349 |
1 files changed, 270 insertions, 79 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7ce51b9bdc..478b382867 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList from twisted.internet import defer @@ -44,52 +44,26 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res - - states = yield self.runInteraction( - "get_state_groups", - f, + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() - ], - consumeErrors=True, - ) + groups = set(event_to_groups.values()) - defer.returnValue(dict(state_list)) + group_to_state = yield self._get_state_for_groups(groups) + + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state.items() + }) def _fetch_events_for_group(self, key, events): return self._get_events( @@ -204,64 +178,281 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + def _get_state_groups_from_group(self, group, types): + def f(txn): + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + return [r[0] for r in txn.fetchall()] + + return self.runInteraction( + "_get_state_groups_from_group", + f, + ) + + def _get_state_groups_from_groups(self, groups_and_types): + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [ + r[0] + for r in txn.fetchall() + ] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + + @cached(num_args=3, lru=True, max_entries=10000) + def _get_state_for_event_id(self, room_id, event_id, types): + def f(txn): + type_and_state_sql = " OR ".join([ + "(type = ? AND state_key = ?)" + if typ[1] is not None + else "type = ?" + for typ in types + ]) + + sql = ( + "SELECT e.event_id, sg.state_group, sg.event_id" + " FROM state_groups_state as sg" + " INNER JOIN event_to_state_groups as e" + " ON e.state_group = sg.state_group" + " WHERE e.event_id = ? AND (%s)" + ) % (type_and_state_sql,) + + args = [event_id] + for typ, state_key in types: + args.extend( + [typ, state_key] if state_key is not None else [typ] + ) + txn.execute(sql, args) + + return event_id, [ + r[0] + for r in txn.fetchall() + ] + + return self.runInteraction( + "_get_state_for_event_id", + f, + ) + @defer.inlineCallbacks - def get_state_for_events(self, room_id, event_ids): + def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) + + groups = set(event_to_groups.values()) + + group_to_state = yield self._get_state_for_groups( + groups, types + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + defer.returnValue([ + event_to_state[event] + for event in event_ids + ]) + + @cached(num_args=2, lru=True, max_entries=100000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): def f(txn): - groups = set() - event_to_group = {} + results = {} for event_id in event_ids: - # TODO: Remove this loop. - group = self._simple_select_one_onecol_txn( + results[event_id] = self._simple_select_one_onecol_txn( txn, table="event_to_state_groups", - keyvalues={"event_id": event_id}, + keyvalues={ + "event_id": event_id, + }, retcol="state_group", allow_none=True, ) - if group: - event_to_group[event_id] = group - groups.add(group) - - group_to_state_ids = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - group_to_state_ids[group] = state_ids - return event_to_group, group_to_state_ids + return results - res = yield self.runInteraction( - "annotate_events_with_state_groups", + return self.runInteraction( + "_get_state_group_for_events", f, ) - event_to_group, group_to_state_ids = res + def _get_state_for_group_from_cache(self, group, types=None): + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + if types is not None: + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + if is_all and types is None: + return state_dict, missing_types + + if is_all or (types is not None and not missing_types): + sentinel = object() + + def include(typ, state_key): + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + return { + k: v + for k, v in state_dict.items() + if v and include(k[0], k[1]) + }, missing_types + + return {}, missing_types - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in group_to_state_ids.items() - ], - consumeErrors=True, + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + results = {} + missing_groups_and_types = [] + for group in groups: + state_dict, missing_types = self._get_state_for_group_from_cache( + group, types + ) + + if types is not None and not missing_types: + results[group] = { + key: value + for key, value in state_dict.items() + if value + } + else: + missing_groups_and_types.append(( + group, + missing_types if types else None + )) + + if not missing_groups_and_types: + defer.returnValue(results) + + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence + + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types ) - state_dict = { - group: { - (ev.type, ev.state_key): ev - for ev in state - } - for group, state in state_list + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False + ) + + state_events = { + e.event_id: e + for e in state_events } - defer.returnValue([ - state_dict.get(event_to_group.get(event, None), None) - for event in event_ids - ]) + for group, state_ids in group_state_dict.items(): + state_dict = { + key: None + for key in missing_types + } + evs = [state_events[e_id] for e_id in state_ids] + state_dict.update({ + (e.type, e.state_key): e + for e in evs + }) + + # Update the cache + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + results[group] = { + key: value + for key, value in state_dict.items() + if value + } + + defer.returnValue(results) def _make_group_id(clock): |