summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-09-11 11:07:22 +0100
committerErik Johnston <erik@matrix.org>2015-09-17 10:24:51 +0100
commit4678055173636f9940e77f1af35b888f99506030 (patch)
tree2d04c5fe182ddf36d90b0696463302b2880a07e0 /synapse/handlers
parentMerge branch 'master' of github.com:matrix-org/synapse into develop (diff)
downloadsynapse-4678055173636f9940e77f1af35b888f99506030.tar.xz
Refactor do_invite_join
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py84
1 files changed, 56 insertions, 28 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4ff20599d6..30b9982e25 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -649,35 +649,10 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            ev_infos = []
-            for e in itertools.chain(state, auth_chain):
-                if e.event_id == event.event_id:
-                    continue
-
-                e.internal_metadata.outlier = True
-                auth_ids = [e_id for e_id, _ in e.auth_events]
-                ev_infos.append({
-                    "event": e,
-                    "auth_events": {
-                        (e.type, e.state_key): e for e in auth_chain
-                        if e.event_id in auth_ids
-                    }
-                })
-
-            yield self._handle_new_events(origin, ev_infos, outliers=True)
+            self._check_auth_tree(auth_chain, event)
 
-            auth_ids = [e_id for e_id, _ in event.auth_events]
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_chain
-                if e.event_id in auth_ids
-            }
-
-            _, event_stream_id, max_stream_id = yield self._handle_new_event(
-                origin,
-                new_event,
-                state=state,
-                current_state=state,
-                auth_events=auth_events,
+            event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                auth_chain, state, event
             )
 
             with PreserveLoggingContext():
@@ -1026,6 +1001,59 @@ class FederationHandler(BaseHandler):
             is_new_state=(not outliers and not backfilled),
         )
 
+    def _check_auth_tree(self, auth_events, event):
+        event_map = {
+            e.event_id: e
+            for e in auth_events
+        }
+
+        create_event = None
+        for e in auth_events:
+            if (e.type, e.state_key) == (EventTypes.Create, ""):
+                create_event = e
+                break
+
+        for e in auth_events + [event]:
+            a = {
+                (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
+                for e_id, _ in e.auth_events
+            }
+            if create_event:
+                a[(EventTypes.Create, "")] = create_event
+
+            self.auth.check(e, auth_events=a)
+
+    @defer.inlineCallbacks
+    def _persist_auth_tree(self, auth_events, state, event):
+        events_to_context = {}
+        for e in auth_events:
+            ctx = yield self.state_handler.compute_event_context(
+                e, outlier=True,
+            )
+            events_to_context[e.event_id] = ctx
+            e.internal_metadata.outlier = True
+
+        yield self.store.persist_events(
+            [
+                (e, events_to_context[e.event_id])
+                for e in auth_events
+            ],
+            is_new_state=False,
+        )
+
+        new_event_context = yield self.state_handler.compute_event_context(
+            event, old_state=state, outlier=False,
+        )
+
+        event_stream_id, max_stream_id = yield self.store.persist_event(
+            event, new_event_context,
+            backfilled=False,
+            is_new_state=True,
+            current_state=state,
+        )
+
+        defer.returnValue((event_stream_id, max_stream_id))
+
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, backfilled=False,
                     current_state=None, auth_events=None):