diff options
Diffstat (limited to 'synapse/state.py')
-rw-r--r-- | synapse/state.py | 52 |
1 files changed, 41 insertions, 11 deletions
diff --git a/synapse/state.py b/synapse/state.py index cd792afed1..b4eca0e5d5 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer from collections import namedtuple +from frozendict import frozendict import logging import hashlib @@ -55,12 +56,15 @@ def _gen_state_id(): class _StateCacheEntry(object): - __slots__ = ["state", "state_group", "state_id"] + __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group): - self.state = state + def __init__(self, state, state_group, prev_group=None, delta_ids=None): + self.state = frozendict(state) self.state_group = state_group + self.prev_group = prev_group + self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None + # The `state_id` is a unique ID we generate that can be used as ID for # this collection of state. Usually this would be the same as the # state group, but on worker instances we can't generate a new state @@ -153,8 +157,9 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def get_current_user_in_room(self, room_id): - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + def get_current_user_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) entry = yield self.resolve_state_groups(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state( room_id, entry.state_id, entry.state @@ -234,21 +239,29 @@ class StateHandler(object): context.prev_state_ids = curr_state if event.is_state(): context.state_group = self.store.get_next_state_group() - else: - if entry.state_group is None: - entry.state_group = self.store.get_next_state_group() - entry.state_id = entry.state_group - context.state_group = entry.state_group - if event.is_state(): key = (event.type, event.state_key) if key in context.prev_state_ids: replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids[key] = event.event_id + + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids + if context.delta_ids is not None: + context.delta_ids = dict(context.delta_ids) + context.delta_ids[key] = event.event_id else: + if entry.state_group is None: + entry.state_group = self.store.get_next_state_group() + entry.state_id = entry.state_group + + context.state_group = entry.state_group context.current_state_ids = context.prev_state_ids + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids context.prev_state_events = [] defer.returnValue(context) @@ -283,6 +296,8 @@ class StateHandler(object): defer.returnValue(_StateCacheEntry( state=state_list, state_group=name, + prev_group=name, + delta_ids={}, )) with (yield self.resolve_linearizer.queue(group_names)): @@ -340,9 +355,24 @@ class StateHandler(object): if hasattr(self.store, "get_next_state_group"): state_group = self.store.get_next_state_group() + prev_group = None + delta_ids = None + for old_group, old_ids in state_groups_ids.items(): + if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): + n_delta_ids = { + k: v + for k, v in new_state.items() + if old_ids.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids + cache = _StateCacheEntry( state=new_state, state_group=state_group, + prev_group=prev_group, + delta_ids=delta_ids, ) if self._state_cache is not None: |