summary refs log tree commit diff
path: root/synapse/handlers/room_member.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/room_member.py')
-rw-r--r--synapse/handlers/room_member.py124
1 files changed, 84 insertions, 40 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8b17632fdc..dd4b90ee24 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event = context.current_state.get(
+        prev_member_event_id = context.current_state_ids.get(
             (EventTypes.Member, target.to_string()),
             None
         )
 
         if event.membership == Membership.JOIN:
-            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally joined the
-                # room. Don't bother if the user is just changing their profile
-                # info.
+            # Only fire user_joined_room if the user has acutally joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            newly_joined = True
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                newly_joined = prev_member_event.membership != Membership.JOIN
+            if newly_joined:
                 yield user_joined_room(self.distributor, target, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event and prev_member_event.membership == Membership.JOIN:
-                user_left_room(self.distributor, target, room_id)
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                if prev_member_event.membership == Membership.JOIN:
+                    user_left_room(self.distributor, target, room_id)
 
     @defer.inlineCallbacks
     def remote_join(self, remote_room_hosts, room_id, user, content):
@@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler):
             remote_room_hosts = []
 
         latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        current_state = yield self.state_handler.get_current_state(
+        current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids,
         )
 
-        old_state = current_state.get((EventTypes.Member, target.to_string()))
-        old_membership = old_state.content.get("membership") if old_state else None
-        if action == "unban" and old_membership != "ban":
-            raise SynapseError(
-                403,
-                "Cannot unban user who was not banned (membership=%s)" % old_membership,
-                errcode=Codes.BAD_STATE
-            )
-        if old_membership == "ban" and action != "unban":
-            raise SynapseError(
-                403,
-                "Cannot %s user who was banned" % (action,),
-                errcode=Codes.BAD_STATE
-            )
+        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        if old_state_id:
+            old_state = yield self.store.get_event(old_state_id, allow_none=True)
+            old_membership = old_state.content.get("membership") if old_state else None
+            if action == "unban" and old_membership != "ban":
+                raise SynapseError(
+                    403,
+                    "Cannot unban user who was not banned"
+                    " (membership=%s)" % old_membership,
+                    errcode=Codes.BAD_STATE
+                )
+            if old_membership == "ban" and action != "unban":
+                raise SynapseError(
+                    403,
+                    "Cannot %s user who was banned" % (action,),
+                    errcode=Codes.BAD_STATE
+                )
 
-        is_host_in_room = self.is_host_in_room(current_state)
+        is_host_in_room = yield self._is_host_in_room(current_state_ids)
 
         if effective_membership_state == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(current_state):
+            if requester.is_guest and not self._can_guest_join(current_state_ids):
                 # This should be an auth check, but guests are a local concept,
                 # so don't really fit into the general auth process.
                 raise AuthError(403, "Guest access not allowed")
@@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler):
             requester = synapse.types.create_requester(target_user)
 
         message_handler = self.hs.get_handlers().message_handler
-        prev_event = message_handler.deduplicate_state_event(event, context)
+        prev_event = yield message_handler.deduplicate_state_event(event, context)
         if prev_event is not None:
             return
 
         if event.membership == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(context.current_state):
-                # This should be an auth check, but guests are a local concept,
-                # so don't really fit into the general auth process.
-                raise AuthError(403, "Guest access not allowed")
+            if requester.is_guest:
+                guest_can_join = yield self._can_guest_join(context.current_state_ids)
+                if not guest_can_join:
+                    # This should be an auth check, but guests are a local concept,
+                    # so don't really fit into the general auth process.
+                    raise AuthError(403, "Guest access not allowed")
 
         yield message_handler.handle_new_client_event(
             requester,
@@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event = context.current_state.get(
-            (EventTypes.Member, target_user.to_string()),
+        prev_member_event_id = context.current_state_ids.get(
+            (EventTypes.Member, event.state_key),
             None
         )
 
         if event.membership == Membership.JOIN:
-            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
-                # Only fire user_joined_room if the user has acutally joined the
-                # room. Don't bother if the user is just changing their profile
-                # info.
+            # Only fire user_joined_room if the user has acutally joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            newly_joined = True
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                newly_joined = prev_member_event.membership != Membership.JOIN
+            if newly_joined:
                 yield user_joined_room(self.distributor, target_user, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event and prev_member_event.membership == Membership.JOIN:
-                user_left_room(self.distributor, target_user, room_id)
+            if prev_member_event_id:
+                prev_member_event = yield self.store.get_event(prev_member_event_id)
+                if prev_member_event.membership == Membership.JOIN:
+                    user_left_room(self.distributor, target_user, room_id)
 
-    def _can_guest_join(self, current_state):
+    @defer.inlineCallbacks
+    def _can_guest_join(self, current_state_ids):
         """
         Returns whether a guest can join a room based on its current state.
         """
-        guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
-        return (
+        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+        if not guest_access_id:
+            defer.returnValue(False)
+
+        guest_access = yield self.store.get_event(guest_access_id)
+
+        defer.returnValue(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
@@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
 
         if membership:
             yield self.store.forget(user_id, room_id)
+
+    @defer.inlineCallbacks
+    def _is_host_in_room(self, current_state_ids):
+        # Have we just created the room, and is this about to be the very
+        # first member event?
+        create_event_id = current_state_ids.get(("m.room.create", ""))
+        if len(current_state_ids) == 1 and create_event_id:
+            defer.returnValue(self.hs.is_mine_id(create_event_id))
+
+        for (etype, state_key), event_id in current_state_ids.items():
+            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
+                continue
+
+            event = yield self.store.get_event(event_id, allow_none=True)
+            if not event:
+                continue
+
+            if event.membership == Membership.JOIN:
+                defer.returnValue(True)
+
+        defer.returnValue(False)