diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 82 |
1 files changed, 64 insertions, 18 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 02cefdff26..5b743db67a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -64,12 +64,12 @@ class StateStore(SQLBaseStore): for group, state_map in group_to_state.items() }) - def _store_state_groups_txn(self, txn, event, context): - return self._store_mult_state_groups_txn(txn, [(event, context)]) - def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + if context.current_state is None: continue @@ -82,7 +82,8 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next() + state_group = context.new_state_group_id + self._simple_insert_txn( txn, table="state_groups", @@ -114,11 +115,10 @@ class StateStore(SQLBaseStore): table="event_to_state_groups", values=[ { - "state_group": state_groups[event.event_id], - "event_id": event.event_id, + "state_group": state_group_id, + "event_id": event_id, } - for event, context in events_and_contexts - if context.current_state is not None + for event_id, state_group_id in state_groups.items() ], ) @@ -174,6 +174,12 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) + @cached(num_args=2, lru=True, max_entries=1000) + def _get_state_group_from_group(self, group, types): + raise NotImplementedError() + + @cachedList(cached_method_name="_get_state_group_from_group", + list_name="groups", num_args=2, inlineCallbacks=True) def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ @@ -201,18 +207,23 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) rows = self.cursor_to_dict(txn) - results = {} + results = {group: {} for group in groups} for row in rows: key = (row["type"], row["state_key"]) - results.setdefault(row["state_group"], {})[key] = row["event_id"] + results[row["state_group"]][key] = row["event_id"] return results + results = {} + chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] for chunk in chunks: - return self.runInteraction( + res = yield self.runInteraction( "_get_state_groups_from_groups", f, chunk ) + results.update(res) + + defer.returnValue(results) @defer.inlineCallbacks def get_state_for_events(self, event_ids, types): @@ -249,11 +260,14 @@ class StateStore(SQLBaseStore): """ Get the state dict corresponding to a particular event - :param str event_id: event whose state should be returned - :param list[(str, str)]|None types: List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key - :return: a deferred dict from (type, state_key) -> state_event + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) @@ -270,8 +284,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ @@ -356,6 +370,8 @@ class StateStore(SQLBaseStore): a `state_key` of None matches all state_keys. If `types` is None then all events are returned. """ + if types: + types = frozenset(types) results = {} missing_groups = [] if types is not None: @@ -429,3 +445,33 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) + + def get_all_new_state_groups(self, last_id, current_id, limit): + def get_all_new_state_groups_txn(txn): + sql = ( + "SELECT id, room_id, event_id FROM state_groups" + " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + groups = txn.fetchall() + + if not groups: + return ([], []) + + lower_bound = groups[0][0] + upper_bound = groups[-1][0] + sql = ( + "SELECT state_group, type, state_key, event_id" + " FROM state_groups_state" + " WHERE ? <= state_group AND state_group <= ?" + ) + + txn.execute(sql, (lower_bound, upper_bound)) + state_group_state = txn.fetchall() + return (groups, state_group_state) + return self.runInteraction( + "get_all_new_state_groups", get_all_new_state_groups_txn + ) + + def get_state_stream_token(self): + return self._state_groups_id_gen.get_current_token() |