summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9d2c897341..bf0fef1510 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -27,6 +27,7 @@ from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
     RelationTypes,
     UserTypes,
@@ -426,7 +427,7 @@ class EventCreationHandler:
 
         self.send_event = ReplicationSendEventRestServlet.make_client(hs)
 
-        # This is only used to get at ratelimit function, and maybe_kick_guest_users
+        # This is only used to get at ratelimit function
         self.base_handler = BaseHandler(hs)
 
         # We arbitrarily limit concurrent event creation for a room to 5.
@@ -1306,7 +1307,7 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
-        await self.base_handler.maybe_kick_guest_users(event, context)
+        await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Validate a newly added alias or newly added alt_aliases.
@@ -1493,6 +1494,28 @@ class EventCreationHandler:
 
         return event
 
+    async def _maybe_kick_guest_users(
+        self, event: EventBase, context: EventContext
+    ) -> None:
+        if event.type != EventTypes.GuestAccess:
+            return
+
+        guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+        if guest_access == GuestAccess.CAN_JOIN:
+            return
+
+        current_state_ids = await context.get_current_state_ids()
+
+        # since this is a client-generated event, it cannot be an outlier and we must
+        # therefore have the state ids.
+        assert current_state_ids is not None
+        current_state_dict = await self.store.get_events(
+            list(current_state_ids.values())
+        )
+        current_state = list(current_state_dict.values())
+        logger.info("maybe_kick_guest_users %r", current_state)
+        await self.hs.get_room_member_handler().kick_guest_users(current_state)
+
     async def _bump_active_time(self, user: UserID) -> None:
         try:
             presence = self.hs.get_presence_handler()