summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py256
-rw-r--r--synapse/handlers/room.py44
2 files changed, 147 insertions, 153 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4ff20599d6..3aff80bf59 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -125,60 +125,72 @@ class FederationHandler(BaseHandler):
         )
         if not is_in_room and not event.internal_metadata.is_outlier():
             logger.debug("Got event for room we're not in.")
-            current_state = state
 
-        event_ids = set()
-        if state:
-            event_ids |= {e.event_id for e in state}
-        if auth_chain:
-            event_ids |= {e.event_id for e in auth_chain}
+            try:
+                event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                    auth_chain, state, event
+                )
+            except AuthError as e:
+                raise FederationError(
+                    "ERROR",
+                    e.code,
+                    e.msg,
+                    affected=event.event_id,
+                )
 
-        seen_ids = set(
-            (yield self.store.have_events(event_ids)).keys()
-        )
+        else:
+            event_ids = set()
+            if state:
+                event_ids |= {e.event_id for e in state}
+            if auth_chain:
+                event_ids |= {e.event_id for e in auth_chain}
+
+            seen_ids = set(
+                (yield self.store.have_events(event_ids)).keys()
+            )
 
-        if state and auth_chain is not None:
-            # If we have any state or auth_chain given to us by the replication
-            # layer, then we should handle them (if we haven't before.)
+            if state and auth_chain is not None:
+                # If we have any state or auth_chain given to us by the replication
+                # layer, then we should handle them (if we haven't before.)
 
-            event_infos = []
+                event_infos = []
 
-            for e in itertools.chain(auth_chain, state):
-                if e.event_id in seen_ids:
-                    continue
-                e.internal_metadata.outlier = True
-                auth_ids = [e_id for e_id, _ in e.auth_events]
-                auth = {
-                    (e.type, e.state_key): e for e in auth_chain
-                    if e.event_id in auth_ids
-                }
-                event_infos.append({
-                    "event": e,
-                    "auth_events": auth,
-                })
-                seen_ids.add(e.event_id)
+                for e in itertools.chain(auth_chain, state):
+                    if e.event_id in seen_ids:
+                        continue
+                    e.internal_metadata.outlier = True
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    event_infos.append({
+                        "event": e,
+                        "auth_events": auth,
+                    })
+                    seen_ids.add(e.event_id)
 
-            yield self._handle_new_events(
-                origin,
-                event_infos,
-                outliers=True
-            )
+                yield self._handle_new_events(
+                    origin,
+                    event_infos,
+                    outliers=True
+                )
 
-        try:
-            _, event_stream_id, max_stream_id = yield self._handle_new_event(
-                origin,
-                event,
-                state=state,
-                backfilled=backfilled,
-                current_state=current_state,
-            )
-        except AuthError as e:
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
+            try:
+                _, event_stream_id, max_stream_id = yield self._handle_new_event(
+                    origin,
+                    event,
+                    state=state,
+                    backfilled=backfilled,
+                    current_state=current_state,
+                )
+            except AuthError as e:
+                raise FederationError(
+                    "ERROR",
+                    e.code,
+                    e.msg,
+                    affected=event.event_id,
+                )
 
         # if we're receiving valid events from an origin,
         # it's probably a good idea to mark it as not in retry-state
@@ -649,35 +661,8 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            ev_infos = []
-            for e in itertools.chain(state, auth_chain):
-                if e.event_id == event.event_id:
-                    continue
-
-                e.internal_metadata.outlier = True
-                auth_ids = [e_id for e_id, _ in e.auth_events]
-                ev_infos.append({
-                    "event": e,
-                    "auth_events": {
-                        (e.type, e.state_key): e for e in auth_chain
-                        if e.event_id in auth_ids
-                    }
-                })
-
-            yield self._handle_new_events(origin, ev_infos, outliers=True)
-
-            auth_ids = [e_id for e_id, _ in event.auth_events]
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_chain
-                if e.event_id in auth_ids
-            }
-
-            _, event_stream_id, max_stream_id = yield self._handle_new_event(
-                origin,
-                new_event,
-                state=state,
-                current_state=state,
-                auth_events=auth_events,
+            event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                auth_chain, state, event
             )
 
             with PreserveLoggingContext():
@@ -1027,6 +1012,76 @@ class FederationHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
+    def _persist_auth_tree(self, auth_events, state, event):
+        """Checks the auth chain is valid (and passes auth checks) for the
+        state and event. Then persists the auth chain and state atomically.
+        Persists the event seperately.
+
+        Returns:
+            2-tuple of (event_stream_id, max_stream_id) from the persist_event
+            call for `event`
+        """
+        events_to_context = {}
+        for e in itertools.chain(auth_events, state):
+            ctx = yield self.state_handler.compute_event_context(
+                e, outlier=True,
+            )
+            events_to_context[e.event_id] = ctx
+            e.internal_metadata.outlier = True
+
+        event_map = {
+            e.event_id: e
+            for e in auth_events
+        }
+
+        create_event = None
+        for e in auth_events:
+            if (e.type, e.state_key) == (EventTypes.Create, ""):
+                create_event = e
+                break
+
+        for e in itertools.chain(auth_events, state, [event]):
+            auth_for_e = {
+                (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
+                for e_id, _ in e.auth_events
+            }
+            if create_event:
+                auth_for_e[(EventTypes.Create, "")] = create_event
+
+            try:
+                self.auth.check(e, auth_events=auth_for_e)
+            except AuthError as err:
+                logger.warn(
+                    "Rejecting %s because %s",
+                    e.event_id, err.msg
+                )
+
+                if e == event:
+                    raise
+                events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
+
+        yield self.store.persist_events(
+            [
+                (e, events_to_context[e.event_id])
+                for e in itertools.chain(auth_events, state)
+            ],
+            is_new_state=False,
+        )
+
+        new_event_context = yield self.state_handler.compute_event_context(
+            event, old_state=state, outlier=False,
+        )
+
+        event_stream_id, max_stream_id = yield self.store.persist_event(
+            event, new_event_context,
+            backfilled=False,
+            is_new_state=True,
+            current_state=state,
+        )
+
+        defer.returnValue((event_stream_id, max_stream_id))
+
+    @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, backfilled=False,
                     current_state=None, auth_events=None):
         outlier = event.internal_metadata.is_outlier()
@@ -1456,52 +1511,3 @@ class FederationHandler(BaseHandler):
             },
             "missing": [e.event_id for e in missing_locals],
         })
-
-    @defer.inlineCallbacks
-    def _handle_auth_events(self, origin, auth_events):
-        auth_ids_to_deferred = {}
-
-        def process_auth_ev(ev):
-            auth_ids = [e_id for e_id, _ in ev.auth_events]
-
-            prev_ds = [
-                auth_ids_to_deferred[i]
-                for i in auth_ids
-                if i in auth_ids_to_deferred
-            ]
-
-            d = defer.Deferred()
-
-            auth_ids_to_deferred[ev.event_id] = d
-
-            @defer.inlineCallbacks
-            def f(*_):
-                ev.internal_metadata.outlier = True
-
-                try:
-                    auth = {
-                        (e.type, e.state_key): e for e in auth_events
-                        if e.event_id in auth_ids
-                    }
-
-                    yield self._handle_new_event(
-                        origin, ev, auth_events=auth
-                    )
-                except:
-                    logger.exception(
-                        "Failed to handle auth event %s",
-                        ev.event_id,
-                    )
-
-                d.callback(None)
-
-            if prev_ds:
-                dx = defer.DeferredList(prev_ds)
-                dx.addBoth(f)
-            else:
-                f()
-
-        for e in auth_events:
-            process_auth_ev(e)
-
-        yield defer.DeferredList(auth_ids_to_deferred.values())
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index bb5eef6bbd..ac636255c2 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -149,12 +149,16 @@ class RoomCreationHandler(BaseHandler):
         for val in raw_initial_state:
             initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
 
+        creation_content = config.get("creation_content", {})
+
         user = UserID.from_string(user_id)
         creation_events = self._create_events_for_new_room(
             user, room_id,
             preset_config=preset_config,
             invite_list=invite_list,
             initial_state=initial_state,
+            creation_content=creation_content,
+            room_alias=room_alias,
         )
 
         msg_handler = self.hs.get_handlers().message_handler
@@ -202,7 +206,8 @@ class RoomCreationHandler(BaseHandler):
         defer.returnValue(result)
 
     def _create_events_for_new_room(self, creator, room_id, preset_config,
-                                    invite_list, initial_state):
+                                    invite_list, initial_state, creation_content,
+                                    room_alias):
         config = RoomCreationHandler.PRESETS_DICT[preset_config]
 
         creator_id = creator.to_string()
@@ -224,9 +229,10 @@ class RoomCreationHandler(BaseHandler):
 
             return e
 
+        creation_content.update({"creator": creator.to_string()})
         creation_event = create(
             etype=EventTypes.Create,
-            content={"creator": creator.to_string()},
+            content=creation_content,
         )
 
         join_event = create(
@@ -271,6 +277,14 @@ class RoomCreationHandler(BaseHandler):
 
             returned_events.append(power_levels_event)
 
+        if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
+            room_alias_event = create(
+                etype=EventTypes.CanonicalAlias,
+                content={"alias": room_alias.to_string()},
+            )
+
+            returned_events.append(room_alias_event)
+
         if (EventTypes.JoinRules, '') not in initial_state:
             join_rules_event = create(
                 etype=EventTypes.JoinRules,
@@ -493,32 +507,6 @@ class RoomMemberHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def _should_invite_join(self, room_id, prev_state, do_auth):
-        logger.debug("_should_invite_join: room_id: %s", room_id)
-
-        # XXX: We don't do an auth check if we are doing an invite
-        # join dance for now, since we're kinda implicitly checking
-        # that we are allowed to join when we decide whether or not we
-        # need to do the invite/join dance.
-
-        # Only do an invite join dance if a) we were invited,
-        # b) the person inviting was from a differnt HS and c) we are
-        # not currently in the room
-        room_host = None
-        if prev_state and prev_state.membership == Membership.INVITE:
-            room = yield self.store.get_room(room_id)
-            inviter = UserID.from_string(
-                prev_state.sender
-            )
-
-            is_remote_invite_join = not self.hs.is_mine(inviter) and not room
-            room_host = inviter.domain
-        else:
-            is_remote_invite_join = False
-
-        defer.returnValue((is_remote_invite_join, room_host))
-
-    @defer.inlineCallbacks
     def get_joined_rooms_for_user(self, user):
         """Returns a list of roomids that the user has any of the given
         membership states in."""