diff options
-rw-r--r-- | synapse/handlers/sync.py | 79 | ||||
-rw-r--r-- | synapse/storage/state.py | 46 |
2 files changed, 106 insertions, 19 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0f713ce038..c754cfdeeb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -399,7 +399,7 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def get_state_after_event(self, event): + def get_state_after_event(self, event, types=None): """ Get the room state after the given event @@ -409,14 +409,14 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event(event.event_id, types) if event.is_state(): state_ids = state_ids.copy() state_ids[(event.type, event.state_key)] = event.event_id defer.returnValue(state_ids) @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position): + def get_state_at(self, room_id, stream_position, types=None): """ Get the room state at a particular stream position Args: @@ -432,7 +432,7 @@ class SyncHandler(object): if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event(last_event) + state = yield self.get_state_after_event(last_event, types) else: # no events in this room - so presumably no state @@ -441,7 +441,7 @@ class SyncHandler(object): @defer.inlineCallbacks def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, - full_state): + full_state, filter_members): """ Works out the differnce in state between the start of the timeline and the previous sync. @@ -454,6 +454,8 @@ class SyncHandler(object): be None. now_token(str): Token of the end of the current batch. full_state(bool): Whether to force returning the full state. + filter_members(bool): Whether to only return state for members + referenced in this timeline segment Returns: A deferred new event dictionary @@ -464,22 +466,54 @@ class SyncHandler(object): # TODO(mjark) Check for new redactions in the state events. with Measure(self.clock, "compute_state_delta"): + + types = None + member_state_ids = {} + + if filter_members: + # We only request state for the members needed to display the + # timeline: + types = [ + (EventTypes.Member, state_key) + for state_key in set( + event.sender # FIXME: we also care about targets etc. + for event in batch.events + ) + ] + types.append((None, None)) # don't just filter to room members + + # TODO: we should opportunistically deduplicate these members too + # within the same sync series (based on an in-memory cache) + if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id + batch.events[-1].event_id, types=types ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id + batch.events[0].event_id, types=types ) + + if filter_members: + member_state_ids = { + t: state_ids[t] + for t in state_ids if t[0] == EventTypes.member + } + else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token + room_id, stream_position=now_token, types=types ) state_ids = current_state_ids + if filter_members: + member_state_ids = { + t: state_ids[t] + for t in state_ids if t[0] == EventTypes.member + } + timeline_state = { (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() @@ -488,22 +522,29 @@ class SyncHandler(object): state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_ids, + timeline_start_members=member_state_ids, previous={}, current=current_state_ids, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token + room_id, stream_position=since_token, types=types ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id + batch.events[-1].event_id, types=types ) state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id + batch.events[0].event_id, types=types ) + if filter_members: + member_state_ids = { + t: state_at_timeline_start[t] + for t in state_ids if t[0] == EventTypes.member + } + timeline_state = { (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() @@ -512,6 +553,7 @@ class SyncHandler(object): state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, + timeline_start_members=member_state_ids, previous=state_at_previous_sync, current=current_state_ids, ) @@ -1325,7 +1367,7 @@ class SyncHandler(object): state = yield self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, - full_state=full_state + full_state=full_state, filter_members=True ) if room_builder.rtype == "joined": @@ -1421,12 +1463,16 @@ def _action_has_highlight(actions): return False -def _calculate_state(timeline_contains, timeline_start, previous, current): +def _calculate_state(timeline_contains, timeline_start, timeline_start_members, + previous, current): """Works out what state to include in a sync response. Args: timeline_contains (dict): state in the timeline timeline_start (dict): state at the start of the timeline + timeline_start_members (dict): state at the start of the timeline + for room members who participate in this chunk of timeline. + Should always be a subset of timeline_start. previous (dict): state at the end of the previous sync (or empty dict if this is an initial sync) current (dict): state at the end of the timeline @@ -1445,11 +1491,12 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): } c_ids = set(e for e in current.values()) - tc_ids = set(e for e in timeline_contains.values()) - p_ids = set(e for e in previous.values()) ts_ids = set(e for e in timeline_start.values()) + tsm_ids = set(e for e in timeline_start_members.values()) + p_ids = set(e for e in previous.values()) + tc_ids = set(e for e in timeline_contains.values()) - state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids + state_ids = (((c_ids | ts_ids) - p_ids) - tc_ids) | tsm_ids return { event_id_to_key[e]: e for e in state_ids diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ffa4246031..4ab16e18b2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -198,8 +198,15 @@ class StateGroupWorkerStore(SQLBaseStore): def _get_state_groups_from_groups_txn(self, txn, groups, types=None): results = {group: {} for group in groups} + + include_other_types = False + if types is not None: - types = list(set(types)) # deduplicate types list + type_set = set(types) + if (None, None) in type_set: + include_other_types = True + type_set.remove((None, None)) + types = list(type_set) # deduplicate types list if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -246,6 +253,16 @@ class StateGroupWorkerStore(SQLBaseStore): ) for etype, state_key in types ] + + if include_other_types: + # XXX: check whether this slows postgres down like a list of + # ORs does too? + clause_to_args.append( + ( + "AND type <> ? " * len(types), + [t for (t, _) in types] + ) + ) else: # If types is None we fetch all the state, and so just use an # empty where clause with no extra args. @@ -274,6 +291,13 @@ class StateGroupWorkerStore(SQLBaseStore): else: where_clauses.append("(type = ? AND state_key = ?)") where_args.extend([typ[0], typ[1]]) + + if include_other_types: + where_clauses.append( + "(" + " AND ".join(["type <> ?"] * len(types)) + ")" + ) + where_args.extend(t for (t, _) in types) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -469,17 +493,27 @@ class StateGroupWorkerStore(SQLBaseStore): group: The state group to lookup types (list): List of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all state_keys for the - `type`. + `type`. Presence of type of `None` indicates that types not + in the list should not be filtered out. """ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + include_other_types = False + for typ, state_key in types: key = (typ, state_key) + + if typ is None: + include_other_types = True + next + if state_key is None: type_to_key[typ] = None + # XXX: why do we mark the type as missing from our cache just + # because we weren't filtering on a specific value of state_key? missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: @@ -493,7 +527,7 @@ class StateGroupWorkerStore(SQLBaseStore): def include(typ, state_key): valid_state_keys = type_to_key.get(typ, sentinel) if valid_state_keys is sentinel: - return False + return include_other_types if valid_state_keys is None: return True if state_key in valid_state_keys: @@ -527,6 +561,12 @@ class StateGroupWorkerStore(SQLBaseStore): with matching types. `types` is a list of `(type, state_key)`, where a `state_key` of None matches all state_keys. If `types` is None then all events are returned. + + XXX: is it really true that `state_key` of None in `types` matches all + state_keys? it looks like _get-some_state_from_cache does the right thing, + but _get_state_groups_from_groups_txn treats ths None is turned into + 'AND state_key = NULL' or similar (at least until i just fixed it) --Matthew + I've filed this as https://github.com/matrix-org/synapse/issues/2969 """ if types: types = frozenset(types) |