diff --git a/synapse/state.py b/synapse/state.py
index daec983dc9..b31bbcdbd2 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -43,11 +43,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
+# Module-level counter read and incremented by `_gen_state_id`.
+_NEXT_STATE_ID = 1
+
+
+def _gen_state_id():
+ # Build a string ID of the form "X<n>" from the current counter value,
+ # then bump the counter so the next call yields a different ID.
+ global _NEXT_STATE_ID
+ s = "X%d" % (_NEXT_STATE_ID,)
+ _NEXT_STATE_ID += 1
+ return s
+
+
class _StateCacheEntry(object):
- def __init__(self, state, state_group, ts):
+ # Holds `state` and `state_group` exactly as passed in, plus a `state_id`
+ # that is always populated by `__init__`. `__slots__` restricts instances
+ # to these three attributes.
+ __slots__ = ["state", "state_group", "state_id"]
+
+ def __init__(self, state, state_group):
self.state = state
self.state_group = state_group
+ # The `state_id` is a unique ID we generate that can be used as ID for
+ # this collection of state. Usually this would be the same as the
+ # state group, but on worker instances we can't generate a new state
+ # group each time we resolve state, so we generate a separate one that
+ # isn't persisted and is used solely for caches.
+ # `state_id` is either a state_group (and so an int) or a string. This
+ # ensures we don't accidentally persist a state_id as a state_group.
+ # Note the check below is a truthiness test, so any falsy `state_group`
+ # (e.g. None) gets a generated string ID from `_gen_state_id`.
+ if state_group:
+ self.state_id = state_group
+ else:
+ self.state_id = _gen_state_id()
+
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@@ -93,7 +117,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
@@ -116,7 +141,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
@@ -127,9 +153,9 @@ class StateHandler(object):
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_context(
- room_id, group, state_ids
+ # Resolve the state at the room's latest events, then fetch `joined_users`
+ # from the store, passing the resolved entry's `state_id` together with
+ # its `state` mapping.
+ entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ joined_users = yield self.store.get_joined_users_from_state(
+ room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users)
@@ -154,52 +180,73 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
- context.current_state_ids = {
+ context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
+ # `current_state_ids` must be set on this path as well, matching the
+ # `else` branch below: for a state event it is a copy of the previous
+ # state with this event's own (type, state_key) entry added; otherwise
+ # it is the previous state mapping itself.
+ if event.is_state():
+ context.current_state_ids = dict(context.prev_state_ids)
+ key = (event.type, event.state_key)
+ context.current_state_ids[key] = event.event_id
+ else:
+ context.current_state_ids = context.prev_state_ids
else:
context.current_state_ids = {}
+ context.prev_state_ids = {}
context.prev_state_events = []
- context.state_group = None
+ context.state_group = self.store.get_next_state_group()
defer.returnValue(context)
if old_state:
- context.current_state_ids = {
+ context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
- context.state_group = None
+ context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state_ids:
- replaces = context.current_state_ids[key]
+ if key in context.prev_state_ids:
+ replaces = context.prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
+ context.current_state_ids = dict(context.prev_state_ids)
+ context.current_state_ids[key] = event.event_id
+ else:
+ context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
if event.is_state():
- ret = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
- ret = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
)
- group, curr_state = ret
+ curr_state = entry.state
- context.current_state_ids = curr_state
- context.state_group = group if not event.is_state() else None
+ context.prev_state_ids = curr_state
+ if event.is_state():
+ context.state_group = self.store.get_next_state_group()
+ else:
+ if entry.state_group is None:
+ entry.state_group = self.store.get_next_state_group()
+ entry.state_id = entry.state_group
+ context.state_group = entry.state_group
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state_ids:
- replaces = context.current_state_ids[key]
+ if key in context.prev_state_ids:
+ replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
+ context.current_state_ids = dict(context.prev_state_ids)
+ context.current_state_ids[key] = event.event_id
+ else:
+ context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
@@ -231,16 +278,15 @@ class StateHandler(object):
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop()
- defer.returnValue((name, state_list,))
+ defer.returnValue(_StateCacheEntry(
+ state=state_list,
+ state_group=name,
+ ))
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
- cache.ts = self.clock.time_msec()
-
- defer.returnValue(
- (cache.state_group, cache.state,)
- )
+ defer.returnValue(cache)
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@@ -284,17 +330,22 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
+ if state_group is None:
+ # Worker instances don't have access to this method, but we want
+ # to set the state_group on the main instance to increase cache
+ # hits.
+ if hasattr(self.store, "get_next_state_group"):
+ state_group = self.store.get_next_state_group()
+
+ cache = _StateCacheEntry(
+ state=new_state,
+ state_group=state_group,
+ )
if self._state_cache is not None:
- cache = _StateCacheEntry(
- state=new_state,
- state_group=state_group,
- ts=self.clock.time_msec()
- )
-
self._state_cache[group_names] = cache
- defer.returnValue((state_group, new_state,))
+ defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
|