diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 357 |
1 files changed, 296 insertions, 61 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f2b17f29ea..9630efcfcc 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer @@ -44,72 +47,44 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, event_ids): + def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids The return value is a dict mapping group names to lists of events. """ + if not event_ids: + defer.returnValue({}) - def f(txn): - groups = set() - for event_id in event_ids: - group = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - ) - if group: - groups.add(group) - - res = {} - for group in groups: - state_ids = self._simple_select_onecol_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": group}, - retcol="event_id", - ) - - res[group] = state_ids - - return res - - states = yield self.runInteraction( - "get_state_groups", - f, + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, ) - state_list = yield defer.gatherResults( - [ - self._fetch_events_for_group(group, vals) - for group, vals in states.items() - ], - consumeErrors=True, - ) + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) - defer.returnValue(dict(state_list)) - - @cached(num_args=1) - def _fetch_events_for_group(self, state_group, events): - return self._get_events( - events, get_prev_content=False - ).addCallback( - lambda evs: (state_group, evs) - ) + defer.returnValue({ + group: state_map.values() + for group, state_map in group_to_state.items() + }) def _store_state_groups_txn(self, txn, event, context): - if context.current_state is None: - return + return self._store_mult_state_groups_txn(txn, [(event, context)]) + + def _store_mult_state_groups_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if context.current_state is None: + continue + + if context.state_group is not None: + state_groups[event.event_id] = context.state_group + continue - state_events = dict(context.current_state) + state_events = dict(context.current_state) - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if event.is_state(): + state_events[(event.type, event.state_key)] = event - state_group = context.state_group - if not state_group: state_group = self._state_groups_id_gen.get_next_txn(txn) self._simple_insert_txn( txn, @@ -135,14 +110,19 @@ class StateStore(SQLBaseStore): for state in state_events.values() ], ) + state_groups[event.event_id] = state_group - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="event_to_state_groups", - values={ - "state_group": state_group, - "event_id": event.event_id, - }, + values=[ + { + "state_group": state_groups[event.event_id], + "event_id": event.event_id, + } + for event, context in events_and_contexts + if context.current_state is not None + ], ) @defer.inlineCallbacks @@ -177,8 +157,7 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cached(num_args=3) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=3) def get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( @@ -194,6 +173,262 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + def _get_state_groups_from_groups(self, groups_and_types): + """Returns dictionary state_group -> state event ids + + Args: + groups_and_types (list): list of 2-tuple (`group`, `types`) + """ + def f(txn): + results = {} + for group, types in groups_and_types: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + sql = ( + "SELECT event_id FROM state_groups_state WHERE" + " state_group = ? %s" + ) % (where_clause,) + + args = [group] + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + + results[group] = [r[0] for r in txn.fetchall()] + + return results + + return self.runInteraction( + "_get_state_groups_from_groups", + f, + ) + + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids, types): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. The state dicts will only have the type/state_keys + that are in the `types` list. + + Args: + room_id (str) + event_ids (list) + types (list): List of (type, state_key) tuples which are used to + filter the state fetched. `state_key` may be None, which matches + any `state_key` + + Returns: + deferred: A list of dicts corresponding to the event_ids given. + The dicts are mappings from (type, state_key) -> state_events + """ + event_to_groups = yield self._get_state_group_for_events( + room_id, event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @cached(num_args=2, lru=True, max_entries=10000) + def _get_state_group_for_event(self, room_id, event_id): + return self._simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", + num_args=2) + def _get_state_group_for_events(self, room_id, event_ids): + """Returns mapping event_id -> state_group + """ + def f(txn): + results = {} + for event_id in event_ids: + results[event_id] = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={ + "event_id": event_id, + }, + retcol="state_group", + allow_none=True, + ) + + return results + + return self.runInteraction("_get_state_group_for_events", f) + + def _get_some_state_from_cache(self, group, types): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). + `missing_types` is the list of types that aren't in the cache for that + group. `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + + Args: + group: The state group to lookup + types (list): List of 2-tuples of the form (`type`, `state_key`), + where a `state_key` of `None` matches all state_keys for the + `type`. + """ + is_all, state_dict = self._state_group_cache.get(group) + + type_to_key = {} + missing_types = set() + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) + + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) + + sentinel = object() + + def include(typ, state_key): + valid_state_keys = type_to_key.get(typ, sentinel) + if valid_state_keys is sentinel: + return False + if valid_state_keys is None: + return True + if state_key in valid_state_keys: + return True + return False + + got_all = not (missing_types or types is None) + + return { + k: v for k, v in state_dict.items() + if include(k[0], k[1]) + }, missing_types, got_all + + def _get_all_state_from_cache(self, group): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool + indicating if we successfully retrieved all requests state from the + cache, if False we need to query the DB for the missing state. + + Args: + group: The state group to lookup + """ + is_all, state_dict = self._state_group_cache.get(group) + return state_dict, is_all + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, types=None): + """Given list of groups returns dict of group -> list of state events + with matching types. `types` is a list of `(type, state_key)`, where + a `state_key` of None matches all state_keys. If `types` is None then + all events are returned. + """ + results = {} + missing_groups_and_types = [] + if types is not None: + for group in set(groups): + state_dict, missing_types, got_all = self._get_some_state_from_cache( + group, types + ) + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append((group, missing_types)) + else: + for group in set(groups): + state_dict, got_all = self._get_all_state_from_cache( + group + ) + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append((group, None)) + + if not missing_groups_and_types: + defer.returnValue({ + group: { + type_tuple: event + for type_tuple, event in state.items() + if event + } + for group, state in results.items() + }) + + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence + + group_state_dict = yield self._get_state_groups_from_groups( + missing_groups_and_types + ) + + state_events = yield self._get_events( + [e_id for l in group_state_dict.values() for e_id in l], + get_prev_content=False + ) + + state_events = {e.event_id: e for e in state_events} + + # Now we want to update the cache with all the things we fetched + # from the database. + for group, state_ids in group_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = {key: None for key in types} + state_dict.update(results[group]) + results[group] = state_dict + else: + state_dict = results[group] + + for event_id in state_ids: + try: + state_event = state_events[event_id] + state_dict[(state_event.type, state_event.state_key)] = state_event + except KeyError: + # Hmm. So we do don't have that state event? Interesting. + logger.warn( + "Can't find state event %r for state group %r", + event_id, group, + ) + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) + + # Remove all the entries with None values. The None values were just + # used for bookkeeping in the cache. + for group, state_dict in results.items(): + results[group] = { + key: event for key, event in state_dict.items() if event + } + + defer.returnValue(results) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) |