summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py223
1 files changed, 88 insertions, 135 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3882ba79ed..8d99101619 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -39,7 +39,7 @@ from twisted.internet import defer
 
 import itertools
 import logging
-
+from synapse.util.thirdpartyinvites import ThirdPartyInvites
 
 logger = logging.getLogger(__name__)
 
@@ -125,72 +125,60 @@ class FederationHandler(BaseHandler):
         )
         if not is_in_room and not event.internal_metadata.is_outlier():
             logger.debug("Got event for room we're not in.")
+            current_state = state
 
-            try:
-                event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                    auth_chain, state, event
-                )
-            except AuthError as e:
-                raise FederationError(
-                    "ERROR",
-                    e.code,
-                    e.msg,
-                    affected=event.event_id,
-                )
+        event_ids = set()
+        if state:
+            event_ids |= {e.event_id for e in state}
+        if auth_chain:
+            event_ids |= {e.event_id for e in auth_chain}
 
-        else:
-            event_ids = set()
-            if state:
-                event_ids |= {e.event_id for e in state}
-            if auth_chain:
-                event_ids |= {e.event_id for e in auth_chain}
-
-            seen_ids = set(
-                (yield self.store.have_events(event_ids)).keys()
-            )
+        seen_ids = set(
+            (yield self.store.have_events(event_ids)).keys()
+        )
 
-            if state and auth_chain is not None:
-                # If we have any state or auth_chain given to us by the replication
-                # layer, then we should handle them (if we haven't before.)
+        if state and auth_chain is not None:
+            # If we have any state or auth_chain given to us by the replication
+            # layer, then we should handle them (if we haven't before.)
 
-                event_infos = []
+            event_infos = []
 
-                for e in itertools.chain(auth_chain, state):
-                    if e.event_id in seen_ids:
-                        continue
-                    e.internal_metadata.outlier = True
-                    auth_ids = [e_id for e_id, _ in e.auth_events]
-                    auth = {
-                        (e.type, e.state_key): e for e in auth_chain
-                        if e.event_id in auth_ids or e.type == EventTypes.Create
-                    }
-                    event_infos.append({
-                        "event": e,
-                        "auth_events": auth,
-                    })
-                    seen_ids.add(e.event_id)
+            for e in itertools.chain(auth_chain, state):
+                if e.event_id in seen_ids:
+                    continue
+                e.internal_metadata.outlier = True
+                auth_ids = [e_id for e_id, _ in e.auth_events]
+                auth = {
+                    (e.type, e.state_key): e for e in auth_chain
+                    if e.event_id in auth_ids
+                }
+                event_infos.append({
+                    "event": e,
+                    "auth_events": auth,
+                })
+                seen_ids.add(e.event_id)
 
-                yield self._handle_new_events(
-                    origin,
-                    event_infos,
-                    outliers=True
-                )
+            yield self._handle_new_events(
+                origin,
+                event_infos,
+                outliers=True
+            )
 
-            try:
-                _, event_stream_id, max_stream_id = yield self._handle_new_event(
-                    origin,
-                    event,
-                    state=state,
-                    backfilled=backfilled,
-                    current_state=current_state,
-                )
-            except AuthError as e:
-                raise FederationError(
-                    "ERROR",
-                    e.code,
-                    e.msg,
-                    affected=event.event_id,
-                )
+        try:
+            _, event_stream_id, max_stream_id = yield self._handle_new_event(
+                origin,
+                event,
+                state=state,
+                backfilled=backfilled,
+                current_state=current_state,
+            )
+        except AuthError as e:
+            raise FederationError(
+                "ERROR",
+                e.code,
+                e.msg,
+                affected=event.event_id,
+            )
 
         # if we're receiving valid events from an origin,
         # it's probably a good idea to mark it as not in retry-state
@@ -584,7 +572,8 @@ class FederationHandler(BaseHandler):
         origin, pdu = yield self.replication_layer.make_join(
             target_hosts,
             room_id,
-            joinee
+            joinee,
+            content
         )
 
         logger.debug("Got response to make_join: %s", pdu)
@@ -661,8 +650,35 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                auth_chain, state, event
+            ev_infos = []
+            for e in itertools.chain(state, auth_chain):
+                if e.event_id == event.event_id:
+                    continue
+
+                e.internal_metadata.outlier = True
+                auth_ids = [e_id for e_id, _ in e.auth_events]
+                ev_infos.append({
+                    "event": e,
+                    "auth_events": {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                })
+
+            yield self._handle_new_events(origin, ev_infos, outliers=True)
+
+            auth_ids = [e_id for e_id, _ in event.auth_events]
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_chain
+                if e.event_id in auth_ids
+            }
+
+            _, event_stream_id, max_stream_id = yield self._handle_new_event(
+                origin,
+                new_event,
+                state=state,
+                current_state=state,
+                auth_events=auth_events,
             )
 
             with PreserveLoggingContext():
@@ -697,14 +713,18 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def on_make_join_request(self, room_id, user_id):
+    def on_make_join_request(self, room_id, user_id, query):
         """ We've received a /make_join/ request, so we create a partial
         join event for the room and return that. We don *not* persist or
         process it until the other server has signed it and sent it back.
         """
+        event_content = {"membership": Membership.JOIN}
+        if ThirdPartyInvites.has_join_keys(query):
+            ThirdPartyInvites.copy_join_keys(query, event_content)
+
         builder = self.event_builder_factory.new({
             "type": EventTypes.Member,
-            "content": {"membership": Membership.JOIN},
+            "content": event_content,
             "room_id": room_id,
             "sender": user_id,
             "state_key": user_id,
@@ -716,6 +736,9 @@ class FederationHandler(BaseHandler):
 
         self.auth.check(event, auth_events=context.current_state)
 
+        if ThirdPartyInvites.has_join_keys(event.content):
+            ThirdPartyInvites.check_key_valid(self.hs.get_simple_http_client(), event)
+
         defer.returnValue(event)
 
     @defer.inlineCallbacks
@@ -1012,76 +1035,6 @@ class FederationHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def _persist_auth_tree(self, auth_events, state, event):
-        """Checks the auth chain is valid (and passes auth checks) for the
-        state and event. Then persists the auth chain and state atomically.
-        Persists the event seperately.
-
-        Returns:
-            2-tuple of (event_stream_id, max_stream_id) from the persist_event
-            call for `event`
-        """
-        events_to_context = {}
-        for e in itertools.chain(auth_events, state):
-            ctx = yield self.state_handler.compute_event_context(
-                e, outlier=True,
-            )
-            events_to_context[e.event_id] = ctx
-            e.internal_metadata.outlier = True
-
-        event_map = {
-            e.event_id: e
-            for e in auth_events
-        }
-
-        create_event = None
-        for e in auth_events:
-            if (e.type, e.state_key) == (EventTypes.Create, ""):
-                create_event = e
-                break
-
-        for e in itertools.chain(auth_events, state, [event]):
-            auth_for_e = {
-                (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
-                for e_id, _ in e.auth_events
-            }
-            if create_event:
-                auth_for_e[(EventTypes.Create, "")] = create_event
-
-            try:
-                self.auth.check(e, auth_events=auth_for_e)
-            except AuthError as err:
-                logger.warn(
-                    "Rejecting %s because %s",
-                    e.event_id, err.msg
-                )
-
-                if e == event:
-                    raise
-                events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
-
-        yield self.store.persist_events(
-            [
-                (e, events_to_context[e.event_id])
-                for e in itertools.chain(auth_events, state)
-            ],
-            is_new_state=False,
-        )
-
-        new_event_context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=False,
-        )
-
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event, new_event_context,
-            backfilled=False,
-            is_new_state=True,
-            current_state=state,
-        )
-
-        defer.returnValue((event_stream_id, max_stream_id))
-
-    @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, backfilled=False,
                     current_state=None, auth_events=None):
         outlier = event.internal_metadata.is_outlier()