summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_server.py72
-rw-r--r--synapse/handlers/federation.py9
-rw-r--r--synapse/storage/event_federation.py63
3 files changed, 135 insertions, 9 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 22b9663831..34bc397e8a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -305,6 +305,78 @@ class FederationServer(FederationBase):
             (200, send_content)
         )
 
+    @defer.inlineCallbacks
+    def get_missing_events(self, origin, room_id, earliest_events,
+                           latest_events, limit, min_depth):
+        limit = max(limit, 50)
+        min_depth = max(min_depth, 0)
+
+        missing_events = yield self.store.get_missing_events(
+            room_id=room_id,
+            earliest_events=earliest_events,
+            latest_events=latest_events,
+            limit=limit,
+            min_depth=min_depth,
+        )
+
+        known_ids = {e.event_id for e in missing_events} | {earliest_events}
+
+        back_edges = {
+            e for e in missing_events
+            if {i for i, h in e.prev_events.items()} <= known_ids
+        }
+
+        decoded_auth_events = set()
+        state = {}
+        auth_events = set()
+        auth_and_state = {}
+        for event in back_edges:
+            state_pdus = yield self.handler.get_state_for_pdu(
+                origin, room_id, event.event_id,
+                do_auth=False,
+            )
+
+            state[event.event_id] = [s.event_id for s in state_pdus]
+
+            auth_and_state.update({
+                s.event_id: s for s in state_pdus
+            })
+
+            state_ids = {pdu.event_id for pdu in state_pdus}
+            prev_ids = {i for i, h in event.prev_events.items()}
+            partial_auth_chain = yield self.store.get_auth_chain(
+                state_ids | prev_ids, have_ids=decoded_auth_events.keys()
+            )
+
+            for p in partial_auth_chain:
+                p.signatures.update(
+                    compute_event_signature(
+                        p,
+                        self.hs.hostname,
+                        self.hs.config.signing_key[0]
+                    )
+                )
+
+            auth_events.update(
+                a.event_id for a in partial_auth_chain
+            )
+
+            auth_and_state.update({
+                a.event_id: a for a in partial_auth_chain
+            })
+
+        time_now = self._clock.time_msec()
+
+        defer.returnValue({
+            "events": [ev.get_pdu_json(time_now) for ev in missing_events],
+            "state_for_events": state,
+            "auth_events": auth_events,
+            "event_map": {
+                k: ev.get_pdu_json(time_now)
+                for k, ev in auth_and_state.items()
+            },
+        })
+
     @log_function
     def _get_persisted_pdu(self, origin, event_id, do_auth=True):
         """ Get a PDU from the database with given origin and id.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0eb2ff95ca..26bdc6d1a7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -581,12 +581,13 @@ class FederationHandler(BaseHandler):
         defer.returnValue(event)
 
     @defer.inlineCallbacks
-    def get_state_for_pdu(self, origin, room_id, event_id):
+    def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
         yield run_on_reactor()
 
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
+        if do_auth:
+            in_room = yield self.auth.check_host_in_room(room_id, origin)
+            if not in_room:
+                raise AuthError(403, "Host not in room.")
 
         state_groups = yield self.store.get_state_groups(
             [event_id]
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3fbc090224..22bf7ad832 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
     and backfilling from another server respectively.
     """
 
-    def get_auth_chain(self, event_ids):
+    def get_auth_chain(self, event_ids, have_ids=set()):
         return self.runInteraction(
             "get_auth_chain",
             self._get_auth_chain_txn,
-            event_ids
+            event_ids, have_ids
         )
 
-    def _get_auth_chain_txn(self, txn, event_ids):
-        results = self._get_auth_chain_ids_txn(txn, event_ids)
+    def _get_auth_chain_txn(self, txn, event_ids, have_ids):
+        results = self._get_auth_chain_ids_txn(txn, event_ids, have_ids)
 
         return self._get_events_txn(txn, results)
 
@@ -51,8 +51,9 @@ class EventFederationStore(SQLBaseStore):
             event_ids
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids):
+    def _get_auth_chain_ids_txn(self, txn, event_ids, have_ids):
         results = set()
+        have_ids = set(have_ids)
 
         base_sql = (
             "SELECT auth_id FROM event_auth WHERE event_id = ?"
@@ -64,6 +65,10 @@ class EventFederationStore(SQLBaseStore):
             for f in front:
                 txn.execute(base_sql, (f,))
                 new_front.update([r[0] for r in txn.fetchall()])
+
+            new_front -= results
+            new_front -= have_ids
+
             front = new_front
             results.update(front)
 
@@ -378,3 +383,51 @@ class EventFederationStore(SQLBaseStore):
             event_results += new_front
 
         return self._get_events_txn(txn, event_results)
+
+    def get_missing_events(self, room_id, earliest_events, latest_events,
+                           limit, min_depth):
+        return self.runInteraction(
+            "get_missing_events",
+            self._get_missing_events,
+            room_id, earliest_events, latest_events, limit, min_depth
+        )
+
+    def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
+                            limit, min_depth):
+
+        earliest_events = set(earliest_events)
+        front = set(latest_events) - earliest_events
+
+        event_results = set()
+
+        query = (
+            "SELECT prev_event_id FROM event_edges "
+            "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+            "LIMIT ?"
+        )
+
+        while front and len(event_results) < limit:
+            new_front = set()
+            for event_id in front:
+                txn.execute(
+                    query,
+                    (room_id, event_id, limit - len(event_results))
+                )
+
+                for e_id, in txn.fetchall():
+                    new_front.add(e_id)
+
+            new_front -= earliest_events
+            new_front -= event_results
+
+            front = new_front
+            event_results |= new_front
+
+        events = self._get_events_txn(txn, event_results)
+
+        events = sorted(
+            [ev for ev in events if ev.depth >= min_depth],
+            key=lambda e: e.depth,
+        )
+
+        return events[:limit]