diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 99 |
1 files changed, 85 insertions, 14 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f2b17f29ea..47bec65497 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -92,24 +92,31 @@ class StateStore(SQLBaseStore): defer.returnValue(dict(state_list)) @cached(num_args=1) - def _fetch_events_for_group(self, state_group, events): + def _fetch_events_for_group(self, key, events): return self._get_events( events, get_prev_content=False ).addCallback( - lambda evs: (state_group, evs) + lambda evs: (key, evs) ) def _store_state_groups_txn(self, txn, event, context): - if context.current_state is None: - return + return self._store_mult_state_groups_txn(txn, [(event, context)]) - state_events = dict(context.current_state) + def _store_mult_state_groups_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if context.current_state is None: + continue - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if context.state_group is not None: + state_groups[event.event_id] = context.state_group + continue + + state_events = dict(context.current_state) + + if event.is_state(): + state_events[(event.type, event.state_key)] = event - state_group = context.state_group - if not state_group: state_group = self._state_groups_id_gen.get_next_txn(txn) self._simple_insert_txn( txn, @@ -135,14 +142,19 @@ class StateStore(SQLBaseStore): for state in state_events.values() ], ) + state_groups[event.event_id] = state_group - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="event_to_state_groups", - values={ - "state_group": state_group, - "event_id": event.event_id, - }, + values=[ + { + "state_group": state_groups[event.event_id], + "event_id": event.event_id, + } + for event, context in events_and_contexts + if context.current_state is not None + ], ) @defer.inlineCallbacks @@ -194,6 +206,65 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) + @defer.inlineCallbacks + def get_state_for_events(self, room_id, event_ids): + def f(txn): + groups = set() + event_to_group = {} + for event_id in event_ids: + # TODO: Remove this loop. + group = self._simple_select_one_onecol_txn( + txn, + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + ) + if group: + event_to_group[event_id] = group + groups.add(group) + + group_to_state_ids = {} + for group in groups: + state_ids = self._simple_select_onecol_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": group}, + retcol="event_id", + ) + + group_to_state_ids[group] = state_ids + + return event_to_group, group_to_state_ids + + res = yield self.runInteraction( + "annotate_events_with_state_groups", + f, + ) + + event_to_group, group_to_state_ids = res + + state_list = yield defer.gatherResults( + [ + self._fetch_events_for_group(group, vals) + for group, vals in group_to_state_ids.items() + ], + consumeErrors=True, + ) + + state_dict = { + group: { + (ev.type, ev.state_key): ev + for ev in state + } + for group, state in state_list + } + + defer.returnValue([ + state_dict.get(event_to_group.get(event, None), None) + for event in event_ids + ]) + def _make_group_id(clock): return str(int(clock.time_msec())) + random_string(5) |