summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py87
1 files changed, 53 insertions, 34 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 654ecd2b37..14051aee99 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -167,7 +167,7 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def create_and_send_event(self, event_dict, ratelimit=True,
-                              token_id=None, txn_id=None):
+                              token_id=None, txn_id=None, is_guest=False):
         """ Given a dict from a client, create and handle a new event.
 
         Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -213,7 +213,7 @@ class MessageHandler(BaseHandler):
 
         if event.type == EventTypes.Member:
             member_handler = self.hs.get_handlers().room_member_handler
-            yield member_handler.change_membership(event, context)
+            yield member_handler.change_membership(event, context, is_guest=is_guest)
         else:
             yield self.handle_new_client_event(
                 event=event,
@@ -258,20 +258,30 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
-        if is_guest:
+        try:
+            # check_user_was_in_room will return the most recent membership
+            # event for the user if:
+            #  * The user is a non-guest user, and was ever in the room
+            #  * The user is a guest user, and has joined the room
+            # else it will throw.
+            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            defer.returnValue((member_event.membership, member_event.event_id))
+            return
+        except AuthError, auth_error:
             visibility = yield self.state_handler.get_current_state(
                 room_id, EventTypes.RoomHistoryVisibility, ""
             )
-            if visibility.content["history_visibility"] == "world_readable":
+            if (
+                visibility and
+                visibility.content["history_visibility"] == "world_readable"
+            ):
                 defer.returnValue((Membership.JOIN, None))
                 return
-            else:
-                raise AuthError(
-                    403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
-                )
-        else:
-            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
-            defer.returnValue((member_event.membership, member_event.event_id))
+            if not is_guest:
+                raise auth_error
+            raise AuthError(
+                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+            )
 
     @defer.inlineCallbacks
     def get_state_events(self, user_id, room_id, is_guest=False):
@@ -456,7 +466,7 @@ class MessageHandler(BaseHandler):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def room_initial_sync(self, user_id, room_id, pagin_config=None):
+    def room_initial_sync(self, user_id, room_id, pagin_config=None, is_guest=False):
         """Capture the a snapshot of a room. If user is currently a member of
         the room this will be what is currently in the room. If the user left
         the room this will be what was in the room when they left.
@@ -473,15 +483,19 @@ class MessageHandler(BaseHandler):
             A JSON serialisable dict with the snapshot of the room.
         """
 
-        member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+        membership, member_event_id = yield self._check_in_room_or_world_readable(
+            room_id,
+            user_id,
+            is_guest
+        )
 
-        if member_event.membership == Membership.JOIN:
+        if membership == Membership.JOIN:
             result = yield self._room_initial_sync_joined(
-                user_id, room_id, pagin_config, member_event
+                user_id, room_id, pagin_config, membership, is_guest
             )
-        elif member_event.membership == Membership.LEAVE:
+        elif membership == Membership.LEAVE:
             result = yield self._room_initial_sync_parted(
-                user_id, room_id, pagin_config, member_event
+                user_id, room_id, pagin_config, membership, member_event_id, is_guest
             )
 
         private_user_data = []
@@ -497,19 +511,19 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
-                                  member_event):
+                                  membership, member_event_id, is_guest):
         room_state = yield self.store.get_state_for_events(
-            [member_event.event_id], None
+            [member_event_id], None
         )
 
-        room_state = room_state[member_event.event_id]
+        room_state = room_state[member_event_id]
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
             limit = 10
 
         stream_token = yield self.store.get_stream_token_for_event(
-            member_event.event_id
+            member_event_id
         )
 
         messages, token = yield self.store.get_recent_events_for_room(
@@ -519,7 +533,7 @@ class MessageHandler(BaseHandler):
         )
 
         messages = yield self._filter_events_for_client(
-            user_id, messages
+            user_id, messages, is_guest=is_guest
         )
 
         start_token = StreamToken(token[0], 0, 0, 0, 0)
@@ -528,7 +542,7 @@ class MessageHandler(BaseHandler):
         time_now = self.clock.time_msec()
 
         defer.returnValue({
-            "membership": member_event.membership,
+            "membership": membership,
             "room_id": room_id,
             "messages": {
                 "chunk": [serialize_event(m, time_now) for m in messages],
@@ -542,7 +556,7 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
-                                  member_event):
+                                  membership, is_guest):
         current_state = yield self.state.get_current_state(
             room_id=room_id,
         )
@@ -574,12 +588,14 @@ class MessageHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def get_presence():
-            states = yield presence_handler.get_states(
-                target_users=[UserID.from_string(m.user_id) for m in room_members],
-                auth_user=auth_user,
-                as_event=True,
-                check_auth=False,
-            )
+            states = {}
+            if not is_guest:
+                states = yield presence_handler.get_states(
+                    target_users=[UserID.from_string(m.user_id) for m in room_members],
+                    auth_user=auth_user,
+                    as_event=True,
+                    check_auth=False,
+                )
 
             defer.returnValue(states.values())
 
@@ -599,7 +615,7 @@ class MessageHandler(BaseHandler):
         ).addErrback(unwrapFirstError)
 
         messages = yield self._filter_events_for_client(
-            user_id, messages
+            user_id, messages, is_guest=is_guest, require_all_visible_for_guests=False
         )
 
         start_token = now_token.copy_and_replace("room_key", token[0])
@@ -607,8 +623,7 @@ class MessageHandler(BaseHandler):
 
         time_now = self.clock.time_msec()
 
-        defer.returnValue({
-            "membership": member_event.membership,
+        ret = {
             "room_id": room_id,
             "messages": {
                 "chunk": [serialize_event(m, time_now) for m in messages],
@@ -618,4 +633,8 @@ class MessageHandler(BaseHandler):
             "state": state,
             "presence": presence,
             "receipts": receipts,
-        })
+        }
+        if not is_guest:
+            ret["membership"] = membership
+
+        defer.returnValue(ret)