diff --git a/synapse/state.py b/synapse/state.py
index daec983dc9..cd792afed1 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
+from synapse.util.async import Linearizer
from collections import namedtuple
@@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
+_NEXT_STATE_ID = 1
+
+
+def _gen_state_id():
+ global _NEXT_STATE_ID
+ s = "X%d" % (_NEXT_STATE_ID,)
+ _NEXT_STATE_ID += 1
+ return s
+
+
class _StateCacheEntry(object):
- def __init__(self, state, state_group, ts):
+ __slots__ = ["state", "state_group", "state_id"]
+
+ def __init__(self, state, state_group):
self.state = state
self.state_group = state_group
+ # The `state_id` is a unique ID we generate that can be used as ID for
+ # this collection of state. Usually this would be the same as the
+ # state group, but on worker instances we can't generate a new state
+ # group each time we resolve state, so we generate a separate one that
+ # isn't persisted and is used solely for caches.
+ # `state_id` is either a state_group (and so an int) or a string. This
+ # ensures we don't accidentally persist a state_id as a stateg_group
+ if state_group:
+ self.state_id = state_group
+ else:
+ self.state_id = _gen_state_id()
+
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@@ -60,6 +85,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
+ self.resolve_linearizer = Linearizer()
def start_caching(self):
logger.debug("start_caching")
@@ -93,7 +119,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
@@ -116,7 +143,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
@@ -127,9 +155,9 @@ class StateHandler(object):
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_context(
- room_id, group, state_ids
+ entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ joined_users = yield self.store.get_joined_users_from_state(
+ room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users)
@@ -154,52 +182,73 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
- context.current_state_ids = {
+ context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
+ if event.is_state():
+ context.current_state_events = dict(context.prev_state_ids)
+ key = (event.type, event.state_key)
+ context.current_state_events[key] = event.event_id
+ else:
+ context.current_state_events = context.prev_state_ids
else:
context.current_state_ids = {}
+ context.prev_state_ids = {}
context.prev_state_events = []
- context.state_group = None
+ context.state_group = self.store.get_next_state_group()
defer.returnValue(context)
if old_state:
- context.current_state_ids = {
+ context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
- context.state_group = None
+ context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state_ids:
- replaces = context.current_state_ids[key]
+ if key in context.prev_state_ids:
+ replaces = context.prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
+ context.current_state_ids = dict(context.prev_state_ids)
+ context.current_state_ids[key] = event.event_id
+ else:
+ context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
if event.is_state():
- ret = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
- ret = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
)
- group, curr_state = ret
+ curr_state = entry.state
- context.current_state_ids = curr_state
- context.state_group = group if not event.is_state() else None
+ context.prev_state_ids = curr_state
+ if event.is_state():
+ context.state_group = self.store.get_next_state_group()
+ else:
+ if entry.state_group is None:
+ entry.state_group = self.store.get_next_state_group()
+ entry.state_id = entry.state_group
+ context.state_group = entry.state_group
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state_ids:
- replaces = context.current_state_ids[key]
+ if key in context.prev_state_ids:
+ replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
+ context.current_state_ids = dict(context.prev_state_ids)
+ context.current_state_ids[key] = event.event_id
+ else:
+ context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
@@ -231,70 +280,75 @@ class StateHandler(object):
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop()
- defer.returnValue((name, state_list,))
-
- if self._state_cache is not None:
- cache = self._state_cache.get(group_names, None)
- if cache:
- cache.ts = self.clock.time_msec()
+ defer.returnValue(_StateCacheEntry(
+ state=state_list,
+ state_group=name,
+ ))
- defer.returnValue(
- (cache.state_group, cache.state,)
- )
-
- logger.info(
- "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
- )
+ with (yield self.resolve_linearizer.queue(group_names)):
+ if self._state_cache is not None:
+ cache = self._state_cache.get(group_names, None)
+ if cache:
+ defer.returnValue(cache)
- state = {}
- for st in state_groups_ids.values():
- for key, e_id in st.items():
- state.setdefault(key, set()).add(e_id)
+ logger.info(
+ "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+ )
- conflicted_state = {
- k: list(v)
- for k, v in state.items()
- if len(v) > 1
- }
+ state = {}
+ for st in state_groups_ids.values():
+ for key, e_id in st.items():
+ state.setdefault(key, set()).add(e_id)
- if conflicted_state:
- logger.info("Resolving conflicted state for %r", room_id)
- state_map = yield self.store.get_events(
- [e_id for st in state_groups_ids.values() for e_id in st.values()],
- get_prev_content=False
- )
- state_sets = [
- [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
- for st in state_groups_ids.values()
- ]
- new_state, _ = self._resolve_events(
- state_sets, event_type, state_key
- )
- new_state = {
- key: e.event_id for key, e in new_state.items()
- }
- else:
- new_state = {
- key: e_ids.pop() for key, e_ids in state.items()
+ conflicted_state = {
+ k: list(v)
+ for k, v in state.items()
+ if len(v) > 1
}
- state_group = None
- new_state_event_ids = frozenset(new_state.values())
- for sg, events in state_groups_ids.items():
- if new_state_event_ids == frozenset(e_id for e_id in events):
- state_group = sg
- break
+ if conflicted_state:
+ logger.info("Resolving conflicted state for %r", room_id)
+ state_map = yield self.store.get_events(
+ [e_id for st in state_groups_ids.values() for e_id in st.values()],
+ get_prev_content=False
+ )
+ state_sets = [
+ [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
+ for st in state_groups_ids.values()
+ ]
+ new_state, _ = self._resolve_events(
+ state_sets, event_type, state_key
+ )
+ new_state = {
+ key: e.event_id for key, e in new_state.items()
+ }
+ else:
+ new_state = {
+ key: e_ids.pop() for key, e_ids in state.items()
+ }
+
+ state_group = None
+ new_state_event_ids = frozenset(new_state.values())
+ for sg, events in state_groups_ids.items():
+ if new_state_event_ids == frozenset(e_id for e_id in events):
+ state_group = sg
+ break
+ if state_group is None:
+ # Worker instances don't have access to this method, but we want
+ # to set the state_group on the main instance to increase cache
+ # hits.
+ if hasattr(self.store, "get_next_state_group"):
+ state_group = self.store.get_next_state_group()
- if self._state_cache is not None:
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
- ts=self.clock.time_msec()
)
- self._state_cache[group_names] = cache
+ if self._state_cache is not None:
+ self._state_cache[group_names] = cache
- defer.returnValue((state_group, new_state,))
+ defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
|