summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/replication.py5
-rw-r--r--synapse/federation/transport.py26
-rw-r--r--synapse/handlers/federation.py5
-rw-r--r--synapse/storage/event_federation.py26
4 files changed, 55 insertions, 7 deletions
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index e358de942e..719bfcc42c 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -427,6 +427,11 @@ class ReplicationLayer(object):
         }))
 
     @defer.inlineCallbacks
+    def on_event_auth(self, origin, context, event_id):
+        auth_pdus = yield self.handler.on_event_auth(event_id)
+        defer.returnValue((200, [a.get_dict() for a in auth_pdus]))
+
+    @defer.inlineCallbacks
     def make_join(self, destination, context, user_id):
         pdu_dict = yield self.transport_layer.make_join(
             destination=destination,
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index b9f7d54c71..babe8447eb 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -257,6 +257,21 @@ class TransportLayer(object):
         defer.returnValue(json.loads(content))
 
     @defer.inlineCallbacks
+    @log_function
+    def get_event_auth(self, destination, context, event_id):
+        path = PREFIX + "/event_auth/%s/%s" % (
+            context,
+            event_id,
+        )
+
+        response = yield self.client.get_json(
+            destination=destination,
+            path=path,
+        )
+
+        defer.returnValue(response)
+
+    @defer.inlineCallbacks
     def _authenticate_request(self, request):
         json_request = {
             "method": request.method,
@@ -427,6 +442,17 @@ class TransportLayer(object):
         )
 
         self.server.register_path(
+            "GET",
+            re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                handler.on_event_auth(
+                    origin, context, event_id,
+                )
+            )
+        )
+
+        self.server.register_path(
             "PUT",
             re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
             self._with_authentication(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e6afd95a58..ce65bbcd62 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -224,6 +224,11 @@ class FederationHandler(BaseHandler):
 
         defer.returnValue(self.pdu_codec.event_from_pdu(pdu))
 
+    @defer.inlineCallbacks
+    def on_event_auth(self, event_id):
+        auth = yield self.store.get_auth_chain(event_id)
+        defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth])
+
     @log_function
     @defer.inlineCallbacks
     def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index d66a49e9f2..06e32d592d 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,6 +32,24 @@ class EventFederationStore(SQLBaseStore):
         )
 
     def _get_auth_chain_txn(self, txn, event_id):
+        results = self._get_auth_chain_ids_txn(txn, event_id)
+
+        sql = "SELECT * FROM events WHERE event_id = ?"
+        rows = []
+        for ev_id in results:
+            c = txn.execute(sql, (ev_id,))
+            rows.extend(self.cursor_to_dict(c))
+
+        return self._parse_events_txn(txn, rows)
+
+    def get_auth_chain_ids(self, event_id):
+        return self.runInteraction(
+            "get_auth_chain_ids",
+            self._get_auth_chain_ids_txn,
+            event_id
+        )
+
+    def _get_auth_chain_ids_txn(self, txn, event_id):
         results = set()
 
         base_sql = (
@@ -48,13 +66,7 @@ class EventFederationStore(SQLBaseStore):
             front = [r[0] for r in txn.fetchall()]
             results.update(front)
 
-        sql = "SELECT * FROM events WHERE event_id = ?"
-        rows = []
-        for ev_id in results:
-            c = txn.execute(sql, (ev_id,))
-            rows.extend(self.cursor_to_dict(c))
-
-        return self._parse_events_txn(txn, rows)
+        return list(results)
 
     def get_oldest_events_in_room(self, room_id):
         return self.runInteraction(