diff --git a/synapse/state.py b/synapse/state.py
index 2249b7fffb..2a01887a67 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -107,6 +107,20 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
+ def get_current_state_ids(self, room_id, event_type=None, state_key="",
+ latest_event_ids=None):
+ if not latest_event_ids:
+ latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+
+ _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+
+ if event_type:
+ defer.returnValue(state.get((event_type, state_key)))
+ return
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
@@ -127,27 +141,27 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
- context.current_state = {
- (s.type, s.state_key): s for s in old_state
+ context.current_state_ids = {
+ (s.type, s.state_key): s.event_id for s in old_state
}
else:
- context.current_state = {}
+ context.current_state_ids = {}
context.prev_state_events = []
context.state_group = None
defer.returnValue(context)
if old_state:
- context.current_state = {
- (s.type, s.state_key): s for s in old_state
+ context.current_state_ids = {
+ (s.type, s.state_key): s.event_id for s in old_state
}
context.state_group = None
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state:
- replaces = context.current_state[key]
- if replaces.event_id != event.event_id: # Paranoia check
- event.unsigned["replaces_state"] = replaces.event_id
+ if key in context.current_state_ids:
+ replaces = context.current_state_ids[key]
+ if replaces != event.event_id: # Paranoia check
+ event.unsigned["replaces_state"] = replaces
context.prev_state_events = []
defer.returnValue(context)
@@ -165,22 +179,14 @@ class StateHandler(object):
group, curr_state = ret
- state_map = yield self.store.get_events(
- curr_state.values(),
- get_prev_content=False
- )
- curr_state = {
- key: state_map[e_id] for key, e_id in curr_state.items() if e_id in state_map
- }
-
- context.current_state = curr_state
+ context.current_state_ids = curr_state
context.state_group = group if not event.is_state() else None
if event.is_state():
key = (event.type, event.state_key)
- if key in context.current_state:
- replaces = context.current_state[key]
- event.unsigned["replaces_state"] = replaces.event_id
+ if key in context.current_state_ids:
+ replaces = context.current_state_ids[key]
+ event.unsigned["replaces_state"] = replaces
context.prev_state_events = []
defer.returnValue(context)
|