diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c8dfd02e7b..5cd009a1c8 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -355,11 +355,11 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- state = yield self.store.get_state_for_event(event.event_id)
+ state_ids = yield self.store.get_state_ids_for_event(event.event_id)
if event.is_state():
- state = state.copy()
- state[(event.type, event.state_key)] = event
- defer.returnValue(state)
+ state_ids = state_ids.copy()
+ state_ids[(event.type, event.state_key)] = event.event_id
+ defer.returnValue(state_ids)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
@@ -412,57 +412,61 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"):
if full_state:
if batch:
- current_state = yield self.store.get_state_for_event(
+ current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id
)
- state = yield self.store.get_state_for_event(
+ state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id
)
else:
- current_state = yield self.get_state_at(
+ current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token
)
- state = current_state
+ state_ids = current_state_ids
timeline_state = {
- (event.type, event.state_key): event
+ (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
- state = _calculate_state(
+ state_ids = _calculate_state(
timeline_contains=timeline_state,
- timeline_start=state,
+ timeline_start=state_ids,
previous={},
- current=current_state,
+ current=current_state_ids,
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
- current_state = yield self.store.get_state_for_event(
+ current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id
)
- state_at_timeline_start = yield self.store.get_state_for_event(
+ state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id
)
timeline_state = {
- (event.type, event.state_key): event
+ (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state()
}
- state = _calculate_state(
+ state_ids = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
- current=current_state,
+ current=current_state_ids,
)
else:
- state = {}
+ state_ids = {}
+
+ state = {}
+ if state_ids:
+ state = yield self.store.get_events(state_ids.values())
defer.returnValue({
(e.type, e.state_key): e
@@ -766,8 +770,13 @@ class SyncHandler(object):
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join:
- old_state = yield self.get_state_at(room_id, since_token)
- old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+ old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
+ old_mem_ev = None
+ if old_mem_ev_id:
+ old_mem_ev = yield self.store.get_event(
+ old_mem_ev_id, allow_none=True
+ )
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
@@ -1059,27 +1068,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
Returns:
dict
"""
- event_id_to_state = {
- e.event_id: e
- for e in itertools.chain(
- timeline_contains.values(),
- previous.values(),
- timeline_start.values(),
- current.values(),
+ event_id_to_key = {
+ e: key
+ for key, e in itertools.chain(
+ timeline_contains.items(),
+ previous.items(),
+ timeline_start.items(),
+ current.items(),
)
}
- c_ids = set(e.event_id for e in current.values())
- tc_ids = set(e.event_id for e in timeline_contains.values())
- p_ids = set(e.event_id for e in previous.values())
- ts_ids = set(e.event_id for e in timeline_start.values())
+ c_ids = set(e for e in current.values())
+ tc_ids = set(e for e in timeline_contains.values())
+ p_ids = set(e for e in previous.values())
+ ts_ids = set(e for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
- evs = (event_id_to_state[e] for e in state_ids)
return {
- (e.type, e.state_key): e
- for e in evs
+ event_id_to_key[e]: e for e in state_ids
}
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 22f7fb1aa1..b1d461fef5 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -284,6 +284,22 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
+ def get_state_ids_for_events(self, event_ids, types):
+ event_to_groups = yield self._get_state_group_for_events(
+ event_ids,
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = yield self._get_state_for_groups(groups, types)
+
+ event_to_state = {
+ event_id: group_to_state[group]
+ for event_id, group in event_to_groups.items()
+ }
+
+ defer.returnValue({event: event_to_state[event] for event in event_ids})
+
+ @defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
@@ -300,6 +316,23 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
+ @defer.inlineCallbacks
+ def get_state_ids_for_event(self, event_id, types=None):
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ types(list[(str, str)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. May be None, which
+ matches any key
+
+ Returns:
+ A deferred dict from (type, state_key) -> state_event
+ """
+ state_map = yield self.get_state_ids_for_events([event_id], types)
+ defer.returnValue(state_map[event_id])
+
@cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
|