diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 22b9663831..34bc397e8a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -305,6 +305,78 @@ class FederationServer(FederationBase):
(200, send_content)
)
+ @defer.inlineCallbacks
+ def get_missing_events(self, origin, room_id, earliest_events,
+ latest_events, limit, min_depth):
+ limit = max(limit, 50)
+ min_depth = max(min_depth, 0)
+
+ missing_events = yield self.store.get_missing_events(
+ room_id=room_id,
+ earliest_events=earliest_events,
+ latest_events=latest_events,
+ limit=limit,
+ min_depth=min_depth,
+ )
+
+ known_ids = {e.event_id for e in missing_events} | {earliest_events}
+
+ back_edges = {
+ e for e in missing_events
+ if {i for i, h in e.prev_events.items()} <= known_ids
+ }
+
+ decoded_auth_events = set()
+ state = {}
+ auth_events = set()
+ auth_and_state = {}
+ for event in back_edges:
+ state_pdus = yield self.handler.get_state_for_pdu(
+ origin, room_id, event.event_id,
+ do_auth=False,
+ )
+
+ state[event.event_id] = [s.event_id for s in state_pdus]
+
+ auth_and_state.update({
+ s.event_id: s for s in state_pdus
+ })
+
+ state_ids = {pdu.event_id for pdu in state_pdus}
+ prev_ids = {i for i, h in event.prev_events.items()}
+ partial_auth_chain = yield self.store.get_auth_chain(
+ state_ids | prev_ids, have_ids=decoded_auth_events.keys()
+ )
+
+ for p in partial_auth_chain:
+ p.signatures.update(
+ compute_event_signature(
+ p,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
+ )
+
+ auth_events.update(
+ a.event_id for a in partial_auth_chain
+ )
+
+ auth_and_state.update({
+ a.event_id: a for a in partial_auth_chain
+ })
+
+ time_now = self._clock.time_msec()
+
+ defer.returnValue({
+ "events": [ev.get_pdu_json(time_now) for ev in missing_events],
+ "state_for_events": state,
+ "auth_events": auth_events,
+ "event_map": {
+ k: ev.get_pdu_json(time_now)
+ for k, ev in auth_and_state.items()
+ },
+ })
+
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0eb2ff95ca..26bdc6d1a7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -581,12 +581,13 @@ class FederationHandler(BaseHandler):
defer.returnValue(event)
@defer.inlineCallbacks
- def get_state_for_pdu(self, origin, room_id, event_id):
+ def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor()
- in_room = yield self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
+ if do_auth:
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
[event_id]
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3fbc090224..22bf7ad832 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
- def get_auth_chain(self, event_ids):
+ def get_auth_chain(self, event_ids, have_ids=set()):
return self.runInteraction(
"get_auth_chain",
self._get_auth_chain_txn,
- event_ids
+ event_ids, have_ids
)
- def _get_auth_chain_txn(self, txn, event_ids):
- results = self._get_auth_chain_ids_txn(txn, event_ids)
+ def _get_auth_chain_txn(self, txn, event_ids, have_ids):
+ results = self._get_auth_chain_ids_txn(txn, event_ids, have_ids)
return self._get_events_txn(txn, results)
@@ -51,8 +51,9 @@ class EventFederationStore(SQLBaseStore):
event_ids
)
- def _get_auth_chain_ids_txn(self, txn, event_ids):
+ def _get_auth_chain_ids_txn(self, txn, event_ids, have_ids):
results = set()
+ have_ids = set(have_ids)
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id = ?"
@@ -64,6 +65,10 @@ class EventFederationStore(SQLBaseStore):
for f in front:
txn.execute(base_sql, (f,))
new_front.update([r[0] for r in txn.fetchall()])
+
+ new_front -= results
+ new_front -= have_ids
+
front = new_front
results.update(front)
@@ -378,3 +383,51 @@ class EventFederationStore(SQLBaseStore):
event_results += new_front
return self._get_events_txn(txn, event_results)
+
+ def get_missing_events(self, room_id, earliest_events, latest_events,
+ limit, min_depth):
+ return self.runInteraction(
+ "get_missing_events",
+ self._get_missing_events,
+ room_id, earliest_events, latest_events, limit, min_depth
+ )
+
+ def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
+ limit, min_depth):
+
+ earliest_events = set(earliest_events)
+ front = set(latest_events) - earliest_events
+
+ event_results = set()
+
+ query = (
+ "SELECT prev_event_id FROM event_edges "
+ "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+ "LIMIT ?"
+ )
+
+ while front and len(event_results) < limit:
+ new_front = set()
+ for event_id in front:
+ txn.execute(
+ query,
+ (room_id, event_id, limit - len(event_results))
+ )
+
+ for e_id, in txn.fetchall():
+ new_front.add(e_id)
+
+ new_front -= earliest_events
+ new_front -= event_results
+
+ front = new_front
+ event_results |= new_front
+
+ events = self._get_events_txn(txn, event_results)
+
+ events = sorted(
+ [ev for ev in events if ev.depth >= min_depth],
+ key=lambda e: e.depth,
+ )
+
+ return events[:limit]
|