summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py30
-rw-r--r--synapse/handlers/federation.py210
-rw-r--r--synapse/handlers/message.py4
-rw-r--r--synapse/handlers/room.py23
4 files changed, 184 insertions, 83 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 60ac6617ae..c488ee0f6d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -123,24 +123,39 @@ class BaseHandler(object):
                         )
                     )
 
-        (event_stream_id, max_stream_id) = yield self.store.persist_event(
-            event, context=context
-        )
-
         federation_handler = self.hs.get_handlers().federation_handler
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
+                event.unsigned["invite_room_state"] = [
+                    {
+                        "type": e.type,
+                        "state_key": e.state_key,
+                        "content": e.content,
+                        "sender": e.sender,
+                    }
+                    for k, e in context.current_state.items()
+                    if e.type in (
+                        EventTypes.JoinRules,
+                        EventTypes.CanonicalAlias,
+                        EventTypes.RoomAvatar,
+                        EventTypes.Name,
+                    )
+                ]
+
                 invitee = UserID.from_string(event.state_key)
                 if not self.hs.is_mine(invitee):
                     # TODO: Can we add signature from remote server in a nicer
                     # way? If we have been invited by a remote server, we need
                     # to get them to sign the event.
+
                     returned_invite = yield federation_handler.send_invite(
                         invitee.domain,
                         event,
                     )
 
+                    event.unsigned.pop("room_state", None)
+
                     # TODO: Make sure the signatures actually are correct.
                     event.signatures.update(
                         returned_invite.signatures
@@ -161,6 +176,10 @@ class BaseHandler(object):
                         "You don't have permission to redact events"
                     )
 
+        (event_stream_id, max_stream_id) = yield self.store.persist_event(
+            event, context=context
+        )
+
         destinations = set(extra_destinations)
         for k, s in context.current_state.items():
             try:
@@ -189,6 +208,9 @@ class BaseHandler(object):
 
         notify_d.addErrback(log_failure)
 
+        # If invite, remove room_state from unsigned before sending.
+        event.unsigned.pop("invite_room_state", None)
+
         federation_handler.handle_new_event(
             event, destinations=destinations,
         )
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f4dce712f9..3882ba79ed 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -125,60 +125,72 @@ class FederationHandler(BaseHandler):
         )
         if not is_in_room and not event.internal_metadata.is_outlier():
             logger.debug("Got event for room we're not in.")
-            current_state = state
 
-        event_ids = set()
-        if state:
-            event_ids |= {e.event_id for e in state}
-        if auth_chain:
-            event_ids |= {e.event_id for e in auth_chain}
+            try:
+                event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                    auth_chain, state, event
+                )
+            except AuthError as e:
+                raise FederationError(
+                    "ERROR",
+                    e.code,
+                    e.msg,
+                    affected=event.event_id,
+                )
 
-        seen_ids = set(
-            (yield self.store.have_events(event_ids)).keys()
-        )
+        else:
+            event_ids = set()
+            if state:
+                event_ids |= {e.event_id for e in state}
+            if auth_chain:
+                event_ids |= {e.event_id for e in auth_chain}
+
+            seen_ids = set(
+                (yield self.store.have_events(event_ids)).keys()
+            )
 
-        if state and auth_chain is not None:
-            # If we have any state or auth_chain given to us by the replication
-            # layer, then we should handle them (if we haven't before.)
+            if state and auth_chain is not None:
+                # If we have any state or auth_chain given to us by the replication
+                # layer, then we should handle them (if we haven't before.)
 
-            event_infos = []
+                event_infos = []
 
-            for e in itertools.chain(auth_chain, state):
-                if e.event_id in seen_ids:
-                    continue
-                e.internal_metadata.outlier = True
-                auth_ids = [e_id for e_id, _ in e.auth_events]
-                auth = {
-                    (e.type, e.state_key): e for e in auth_chain
-                    if e.event_id in auth_ids
-                }
-                event_infos.append({
-                    "event": e,
-                    "auth_events": auth,
-                })
-                seen_ids.add(e.event_id)
+                for e in itertools.chain(auth_chain, state):
+                    if e.event_id in seen_ids:
+                        continue
+                    e.internal_metadata.outlier = True
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids or e.type == EventTypes.Create
+                    }
+                    event_infos.append({
+                        "event": e,
+                        "auth_events": auth,
+                    })
+                    seen_ids.add(e.event_id)
 
-            yield self._handle_new_events(
-                origin,
-                event_infos,
-                outliers=True
-            )
+                yield self._handle_new_events(
+                    origin,
+                    event_infos,
+                    outliers=True
+                )
 
-        try:
-            _, event_stream_id, max_stream_id = yield self._handle_new_event(
-                origin,
-                event,
-                state=state,
-                backfilled=backfilled,
-                current_state=current_state,
-            )
-        except AuthError as e:
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
+            try:
+                _, event_stream_id, max_stream_id = yield self._handle_new_event(
+                    origin,
+                    event,
+                    state=state,
+                    backfilled=backfilled,
+                    current_state=current_state,
+                )
+            except AuthError as e:
+                raise FederationError(
+                    "ERROR",
+                    e.code,
+                    e.msg,
+                    affected=event.event_id,
+                )
 
         # if we're receiving valid events from an origin,
         # it's probably a good idea to mark it as not in retry-state
@@ -649,35 +661,8 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            ev_infos = []
-            for e in itertools.chain(state, auth_chain):
-                if e.event_id == event.event_id:
-                    continue
-
-                e.internal_metadata.outlier = True
-                auth_ids = [e_id for e_id, _ in e.auth_events]
-                ev_infos.append({
-                    "event": e,
-                    "auth_events": {
-                        (e.type, e.state_key): e for e in auth_chain
-                        if e.event_id in auth_ids
-                    }
-                })
-
-            yield self._handle_new_events(origin, ev_infos, outliers=True)
-
-            auth_ids = [e_id for e_id, _ in event.auth_events]
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_chain
-                if e.event_id in auth_ids
-            }
-
-            _, event_stream_id, max_stream_id = yield self._handle_new_event(
-                origin,
-                new_event,
-                state=state,
-                current_state=state,
-                auth_events=auth_events,
+            event_stream_id, max_stream_id = yield self._persist_auth_tree(
+                auth_chain, state, event
             )
 
             with PreserveLoggingContext():
@@ -1027,6 +1012,76 @@ class FederationHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
+    def _persist_auth_tree(self, auth_events, state, event):
+        """Checks the auth chain is valid (and passes auth checks) for the
+        state and event. Then persists the auth chain and state atomically.
+        Persists the event seperately.
+
+        Returns:
+            2-tuple of (event_stream_id, max_stream_id) from the persist_event
+            call for `event`
+        """
+        events_to_context = {}
+        for e in itertools.chain(auth_events, state):
+            ctx = yield self.state_handler.compute_event_context(
+                e, outlier=True,
+            )
+            events_to_context[e.event_id] = ctx
+            e.internal_metadata.outlier = True
+
+        event_map = {
+            e.event_id: e
+            for e in auth_events
+        }
+
+        create_event = None
+        for e in auth_events:
+            if (e.type, e.state_key) == (EventTypes.Create, ""):
+                create_event = e
+                break
+
+        for e in itertools.chain(auth_events, state, [event]):
+            auth_for_e = {
+                (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
+                for e_id, _ in e.auth_events
+            }
+            if create_event:
+                auth_for_e[(EventTypes.Create, "")] = create_event
+
+            try:
+                self.auth.check(e, auth_events=auth_for_e)
+            except AuthError as err:
+                logger.warn(
+                    "Rejecting %s because %s",
+                    e.event_id, err.msg
+                )
+
+                if e == event:
+                    raise
+                events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
+
+        yield self.store.persist_events(
+            [
+                (e, events_to_context[e.event_id])
+                for e in itertools.chain(auth_events, state)
+            ],
+            is_new_state=False,
+        )
+
+        new_event_context = yield self.state_handler.compute_event_context(
+            event, old_state=state, outlier=False,
+        )
+
+        event_stream_id, max_stream_id = yield self.store.persist_event(
+            event, new_event_context,
+            backfilled=False,
+            is_new_state=True,
+            current_state=state,
+        )
+
+        defer.returnValue((event_stream_id, max_stream_id))
+
+    @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, backfilled=False,
                     current_state=None, auth_events=None):
         outlier = event.internal_metadata.is_outlier()
@@ -1166,7 +1221,7 @@ class FederationHandler(BaseHandler):
                         auth_ids = [e_id for e_id, _ in e.auth_events]
                         auth = {
                             (e.type, e.state_key): e for e in remote_auth_chain
-                            if e.event_id in auth_ids
+                            if e.event_id in auth_ids or e.type == EventTypes.Create
                         }
                         e.internal_metadata.outlier = True
 
@@ -1284,6 +1339,7 @@ class FederationHandler(BaseHandler):
                                 (e.type, e.state_key): e
                                 for e in result["auth_chain"]
                                 if e.event_id in auth_ids
+                                or event.type == EventTypes.Create
                             }
                             ev.internal_metadata.outlier = True
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bda8eb5f3f..30949ff7a6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -383,8 +383,12 @@ class MessageHandler(BaseHandler):
             }
 
             if event.membership == Membership.INVITE:
+                time_now = self.clock.time_msec()
                 d["inviter"] = event.sender
 
+                invite_event = yield self.store.get_event(event.event_id)
+                d["invite"] = serialize_event(invite_event, time_now, as_client_event)
+
             rooms_ret.append(d)
 
             if event.membership not in (Membership.JOIN, Membership.LEAVE):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 773f0a2e92..3364a5de14 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -41,6 +41,11 @@ class RoomCreationHandler(BaseHandler):
             "history_visibility": "shared",
             "original_invitees_have_ops": False,
         },
+        RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
+            "join_rules": JoinRules.INVITE,
+            "history_visibility": "shared",
+            "original_invitees_have_ops": True,
+        },
         RoomCreationPreset.PUBLIC_CHAT: {
             "join_rules": JoinRules.PUBLIC,
             "history_visibility": "shared",
@@ -149,12 +154,16 @@ class RoomCreationHandler(BaseHandler):
         for val in raw_initial_state:
             initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
 
+        creation_content = config.get("creation_content", {})
+
         user = UserID.from_string(user_id)
         creation_events = self._create_events_for_new_room(
             user, room_id,
             preset_config=preset_config,
             invite_list=invite_list,
             initial_state=initial_state,
+            creation_content=creation_content,
+            room_alias=room_alias,
         )
 
         msg_handler = self.hs.get_handlers().message_handler
@@ -202,7 +211,8 @@ class RoomCreationHandler(BaseHandler):
         defer.returnValue(result)
 
     def _create_events_for_new_room(self, creator, room_id, preset_config,
-                                    invite_list, initial_state):
+                                    invite_list, initial_state, creation_content,
+                                    room_alias):
         config = RoomCreationHandler.PRESETS_DICT[preset_config]
 
         creator_id = creator.to_string()
@@ -224,9 +234,10 @@ class RoomCreationHandler(BaseHandler):
 
             return e
 
+        creation_content.update({"creator": creator.to_string()})
         creation_event = create(
             etype=EventTypes.Create,
-            content={"creator": creator.to_string()},
+            content=creation_content,
         )
 
         join_event = create(
@@ -271,6 +282,14 @@ class RoomCreationHandler(BaseHandler):
 
             returned_events.append(power_levels_event)
 
+        if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
+            room_alias_event = create(
+                etype=EventTypes.CanonicalAlias,
+                content={"alias": room_alias.to_string()},
+            )
+
+            returned_events.append(room_alias_event)
+
         if (EventTypes.JoinRules, '') not in initial_state:
             join_rules_event = create(
                 etype=EventTypes.JoinRules,