diff --git a/synapse/state.py b/synapse/state.py
index cd792afed1..b4eca0e5d5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
from collections import namedtuple
+from frozendict import frozendict
import logging
import hashlib
@@ -55,12 +56,15 @@ def _gen_state_id():
class _StateCacheEntry(object):
- __slots__ = ["state", "state_group", "state_id"]
+ __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
- def __init__(self, state, state_group):
- self.state = state
+ def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+ self.state = frozendict(state)
self.state_group = state_group
+ self.prev_group = prev_group
+ self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
+
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
@@ -153,8 +157,9 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
- def get_current_user_in_room(self, room_id):
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ def get_current_user_in_room(self, room_id, latest_event_ids=None):
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
@@ -234,21 +239,29 @@ class StateHandler(object):
context.prev_state_ids = curr_state
if event.is_state():
context.state_group = self.store.get_next_state_group()
- else:
- if entry.state_group is None:
- entry.state_group = self.store.get_next_state_group()
- entry.state_id = entry.state_group
- context.state_group = entry.state_group
- if event.is_state():
key = (event.type, event.state_key)
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
+
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
+
+ context.prev_group = entry.prev_group
+ context.delta_ids = entry.delta_ids
+ if context.delta_ids is not None:
+ context.delta_ids = dict(context.delta_ids)
+ context.delta_ids[key] = event.event_id
else:
+ if entry.state_group is None:
+ entry.state_group = self.store.get_next_state_group()
+ entry.state_id = entry.state_group
+
+ context.state_group = entry.state_group
context.current_state_ids = context.prev_state_ids
+ context.prev_group = entry.prev_group
+ context.delta_ids = entry.delta_ids
context.prev_state_events = []
defer.returnValue(context)
@@ -283,6 +296,8 @@ class StateHandler(object):
defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
+ prev_group=name,
+ delta_ids={},
))
with (yield self.resolve_linearizer.queue(group_names)):
@@ -340,9 +355,24 @@ class StateHandler(object):
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
+ prev_group = None
+ delta_ids = None
+ for old_group, old_ids in state_groups_ids.items():
+ if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+ n_delta_ids = {
+ k: v
+ for k, v in new_state.items()
+ if old_ids.get(k) != v
+ }
+ if not delta_ids or len(n_delta_ids) < len(delta_ids):
+ prev_group = old_group
+ delta_ids = n_delta_ids
+
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
)
if self._state_cache is not None:
|