diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5f86ed03fa..da99a4b449 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -112,7 +112,7 @@ class FederationHandler(BaseHandler):
is_new_state = yield self.state_handler.annotate_state_groups(
event,
- state=state
+ old_state=state
)
logger.debug("Event: %s", event)
@@ -240,7 +240,7 @@ class FederationHandler(BaseHandler):
is_new_state = yield self.state_handler.annotate_state_groups(
event,
- state=state
+ old_state=state
)
logger.debug("do_invite_join event: %s", event)
@@ -279,7 +279,10 @@ class FederationHandler(BaseHandler):
del self.room_queues[room_id]
for p in room_queue:
- yield self.on_receive_pdu(p, backfilled=False)
+ try:
+ yield self.on_receive_pdu(p, backfilled=False)
+ except:
+ pass
defer.returnValue(True)
@@ -355,15 +358,30 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def get_state_for_pdu(self, pdu_id, pdu_origin):
+ event_id = encode_event_id(pdu_id, pdu_origin)
+
state_groups = yield self.store.get_state_groups(
- [encode_event_id(pdu_id, pdu_origin)]
+ [event_id]
)
if state_groups:
+ results = {
+ (e.type, e.state_key): e for e in state_groups[0].state
+ }
+
+ event = yield self.store.get_event(event_id)
+ if hasattr(event, "state_key"):
+ # Get previous state
+ if hasattr(event, "prev_state") and event.prev_state:
+ prev_event = yield self.store.get_event(event.prev_state)
+ results[(event.type, event.state_key)] = prev_event
+ else:
+ del results[(event.type, event.state_key)]
+
defer.returnValue(
[
self.pdu_codec.pdu_from_event(s)
- for s in state_groups[0].state
+ for s in results.values()
]
)
else:
diff --git a/synapse/state.py b/synapse/state.py
index 993c4f18d3..a59688e3b4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -128,11 +128,15 @@ class StateHandler(object):
@defer.inlineCallbacks
@log_function
- def annotate_state_groups(self, event, state=None):
- if state:
+ def annotate_state_groups(self, event, old_state=None):
+ if old_state:
event.state_group = None
- event.old_state_events = None
- event.state_events = {(s.type, s.state_key): s for s in state}
+ event.old_state_events = old_state
+ event.state_events = {(s.type, s.state_key): s for s in old_state}
+
+ if hasattr(event, "state_key"):
+ event.state_events[(event.type, event.state_key)] = event
+
defer.returnValue(False)
return
@@ -163,7 +167,7 @@ class StateHandler(object):
event_ids = [
e_id
- for e_id, _ in events
+ for e_id, _, _ in events
]
res = yield self.resolve_state_groups(event_ids)
|