diff options
Diffstat (limited to 'synapse/state.py')
-rw-r--r-- | synapse/state.py | 194 |
1 files changed, 124 insertions, 70 deletions
diff --git a/synapse/state.py b/synapse/state.py index daec983dc9..cd792afed1 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext +from synapse.util.async import Linearizer from collections import namedtuple @@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 +_NEXT_STATE_ID = 1 + + +def _gen_state_id(): + global _NEXT_STATE_ID + s = "X%d" % (_NEXT_STATE_ID,) + _NEXT_STATE_ID += 1 + return s + + class _StateCacheEntry(object): - def __init__(self, state, state_group, ts): + __slots__ = ["state", "state_group", "state_id"] + + def __init__(self, state, state_group): self.state = state self.state_group = state_group + # The `state_id` is a unique ID we generate that can be used as ID for + # this collection of state. Usually this would be the same as the + # state group, but on worker instances we can't generate a new state + # group each time we resolve state, so we generate a separate one that + # isn't persisted and is used solely for caches. + # `state_id` is either a state_group (and so an int) or a string. This + # ensures we don't accidentally persist a state_id as a stateg_group + if state_group: + self.state_id = state_group + else: + self.state_id = _gen_state_id() + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -60,6 +85,7 @@ class StateHandler(object): # dict of set of event_ids -> _StateCacheEntry. self._state_cache = None + self.resolve_linearizer = Linearizer() def start_caching(self): logger.debug("start_caching") @@ -93,7 +119,8 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state if event_type: event_id = state.get((event_type, state_key)) @@ -116,7 +143,8 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state if event_type: defer.returnValue(state.get((event_type, state_key))) @@ -127,9 +155,9 @@ class StateHandler(object): @defer.inlineCallbacks def get_current_user_in_room(self, room_id): latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_context( - room_id, group, state_ids + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_users = yield self.store.get_joined_users_from_state( + room_id, entry.state_id, entry.state ) defer.returnValue(joined_users) @@ -154,52 +182,73 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } + if event.is_state(): + context.current_state_events = dict(context.prev_state_ids) + key = (event.type, event.state_key) + context.current_state_events[key] = event.event_id + else: + context.current_state_events = context.prev_state_ids else: context.current_state_ids = {} + context.prev_state_ids = {} context.prev_state_events = [] - context.state_group = None + context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: - context.current_state_ids = { + context.prev_state_ids = { (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = None + context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] if replaces != event.event_id: # Paranoia check event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) if event.is_state(): - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], ) - group, curr_state = ret + curr_state = entry.state - context.current_state_ids = curr_state - context.state_group = group if not event.is_state() else None + context.prev_state_ids = curr_state + if event.is_state(): + context.state_group = self.store.get_next_state_group() + else: + if entry.state_group is None: + entry.state_group = self.store.get_next_state_group() + entry.state_id = entry.state_group + context.state_group = entry.state_group if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state_ids: - replaces = context.current_state_ids[key] + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) @@ -231,70 +280,75 @@ class StateHandler(object): if len(group_names) == 1: name, state_list = state_groups_ids.items().pop() - defer.returnValue((name, state_list,)) - - if self._state_cache is not None: - cache = self._state_cache.get(group_names, None) - if cache: - cache.ts = self.clock.time_msec() + defer.returnValue(_StateCacheEntry( + state=state_list, + state_group=name, + )) - defer.returnValue( - (cache.state_group, cache.state,) - ) - - logger.info( - "Resolving state for %s with %d groups", room_id, len(state_groups_ids) - ) + with (yield self.resolve_linearizer.queue(group_names)): + if self._state_cache is not None: + cache = self._state_cache.get(group_names, None) + if cache: + defer.returnValue(cache) - state = {} - for st in state_groups_ids.values(): - for key, e_id in st.items(): - state.setdefault(key, set()).add(e_id) + logger.info( + "Resolving state for %s with %d groups", room_id, len(state_groups_ids) + ) - conflicted_state = { - k: list(v) - for k, v in state.items() - if len(v) > 1 - } + state = {} + for st in state_groups_ids.values(): + for key, e_id in st.items(): + state.setdefault(key, set()).add(e_id) - if conflicted_state: - logger.info("Resolving conflicted state for %r", room_id) - state_map = yield self.store.get_events( - [e_id for st in state_groups_ids.values() for e_id in st.values()], - get_prev_content=False - ) - state_sets = [ - [state_map[e_id] for key, e_id in st.items() if e_id in state_map] - for st in state_groups_ids.values() - ] - new_state, _ = self._resolve_events( - state_sets, event_type, state_key - ) - new_state = { - key: e.event_id for key, e in new_state.items() - } - else: - new_state = { - key: e_ids.pop() for key, e_ids in state.items() + conflicted_state = { + k: list(v) + for k, v in state.items() + if len(v) > 1 } - state_group = None - new_state_event_ids = frozenset(new_state.values()) - for sg, events in state_groups_ids.items(): - if new_state_event_ids == frozenset(e_id for e_id in events): - state_group = sg - break + if conflicted_state: + logger.info("Resolving conflicted state for %r", room_id) + state_map = yield self.store.get_events( + [e_id for st in state_groups_ids.values() for e_id in st.values()], + get_prev_content=False + ) + state_sets = [ + [state_map[e_id] for key, e_id in st.items() if e_id in state_map] + for st in state_groups_ids.values() + ] + new_state, _ = self._resolve_events( + state_sets, event_type, state_key + ) + new_state = { + key: e.event_id for key, e in new_state.items() + } + else: + new_state = { + key: e_ids.pop() for key, e_ids in state.items() + } + + state_group = None + new_state_event_ids = frozenset(new_state.values()) + for sg, events in state_groups_ids.items(): + if new_state_event_ids == frozenset(e_id for e_id in events): + state_group = sg + break + if state_group is None: + # Worker instances don't have access to this method, but we want + # to set the state_group on the main instance to increase cache + # hits. + if hasattr(self.store, "get_next_state_group"): + state_group = self.store.get_next_state_group() - if self._state_cache is not None: cache = _StateCacheEntry( state=new_state, state_group=state_group, - ts=self.clock.time_msec() ) - self._state_cache[group_names] = cache + if self._state_cache is not None: + self._state_cache[group_names] = cache - defer.returnValue((state_group, new_state,)) + defer.returnValue(cache) def resolve_events(self, state_sets, event): logger.info( |