summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py68
-rw-r--r--synapse/handlers/federation_event.py17
-rw-r--r--synapse/handlers/message.py27
-rw-r--r--synapse/handlers/room.py5
-rw-r--r--synapse/handlers/room_list.py12
-rw-r--r--synapse/handlers/room_member.py65
-rw-r--r--synapse/handlers/stats.py6
7 files changed, 115 insertions, 85 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6a05a65305..955cfa2207 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,10 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Optional
 
-import synapse.types
-from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.types import UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -115,68 +112,3 @@ class BaseHandler:
                 burst_count=burst_count,
                 update=update,
             )
-
-    async def maybe_kick_guest_users(self, event, context=None):
-        # Technically this function invalidates current_state by changing it.
-        # Hopefully this isn't that important to the caller.
-        if event.type == EventTypes.GuestAccess:
-            guest_access = event.content.get("guest_access", "forbidden")
-            if guest_access != "can_join":
-                if context:
-                    current_state_ids = await context.get_current_state_ids()
-                    current_state_dict = await self.store.get_events(
-                        list(current_state_ids.values())
-                    )
-                    current_state = list(current_state_dict.values())
-                else:
-                    current_state_map = await self.state_handler.get_current_state(
-                        event.room_id
-                    )
-                    current_state = list(current_state_map.values())
-
-                logger.info("maybe_kick_guest_users %r", current_state)
-                await self.kick_guest_users(current_state)
-
-    async def kick_guest_users(self, current_state):
-        for member_event in current_state:
-            try:
-                if member_event.type != EventTypes.Member:
-                    continue
-
-                target_user = UserID.from_string(member_event.state_key)
-                if not self.hs.is_mine(target_user):
-                    continue
-
-                if member_event.content["membership"] not in {
-                    Membership.JOIN,
-                    Membership.INVITE,
-                }:
-                    continue
-
-                if (
-                    "kind" not in member_event.content
-                    or member_event.content["kind"] != "guest"
-                ):
-                    continue
-
-                # We make the user choose to leave, rather than have the
-                # event-sender kick them. This is partially because we don't
-                # need to worry about power levels, and partially because guest
-                # users are a concept which doesn't hugely work over federation,
-                # and having homeservers have their own users leave keeps more
-                # of that decision-making and control local to the guest-having
-                # homeserver.
-                requester = synapse.types.create_requester(
-                    target_user, is_guest=True, authenticated_entity=self.server_name
-                )
-                handler = self.hs.get_room_member_handler()
-                await handler.update_membership(
-                    requester,
-                    target_user,
-                    member_event.room_id,
-                    "leave",
-                    ratelimit=False,
-                    require_consent=False,
-                )
-            except Exception as e:
-                logger.exception("Error kicking guest user: %s" % (e,))
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b622e3ae2d..3414747f49 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -36,6 +36,7 @@ from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
     RejectedReason,
     RoomEncryptionAlgorithms,
@@ -1327,9 +1328,7 @@ class FederationEventHandler(BaseHandler):
 
         if not context.rejected:
             await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
-        if event.type == EventTypes.GuestAccess and not context.rejected:
-            await self.maybe_kick_guest_users(event)
+            await self._maybe_kick_guest_users(event)
 
         # If we are going to send this event over federation we precaclculate
         # the joined hosts.
@@ -1340,6 +1339,18 @@ class FederationEventHandler(BaseHandler):
 
         return context
 
+    async def _maybe_kick_guest_users(self, event: EventBase) -> None:
+        if event.type != EventTypes.GuestAccess:
+            return
+
+        guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+        if guest_access == GuestAccess.CAN_JOIN:
+            return
+
+        current_state_map = await self.state_handler.get_current_state(event.room_id)
+        current_state = list(current_state_map.values())
+        await self.hs.get_room_member_handler().kick_guest_users(current_state)
+
     async def _check_for_soft_fail(
         self,
         event: EventBase,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9d2c897341..bf0fef1510 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -27,6 +27,7 @@ from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
     RelationTypes,
     UserTypes,
@@ -426,7 +427,7 @@ class EventCreationHandler:
 
         self.send_event = ReplicationSendEventRestServlet.make_client(hs)
 
-        # This is only used to get at ratelimit function, and maybe_kick_guest_users
+        # This is only used to get at ratelimit function
         self.base_handler = BaseHandler(hs)
 
         # We arbitrarily limit concurrent event creation for a room to 5.
@@ -1306,7 +1307,7 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
-        await self.base_handler.maybe_kick_guest_users(event, context)
+        await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Validate a newly added alias or newly added alt_aliases.
@@ -1493,6 +1494,28 @@ class EventCreationHandler:
 
         return event
 
+    async def _maybe_kick_guest_users(
+        self, event: EventBase, context: EventContext
+    ) -> None:
+        if event.type != EventTypes.GuestAccess:
+            return
+
+        guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+        if guest_access == GuestAccess.CAN_JOIN:
+            return
+
+        current_state_ids = await context.get_current_state_ids()
+
+        # since this is a client-generated event, it cannot be an outlier and we must
+        # therefore have the state ids.
+        assert current_state_ids is not None
+        current_state_dict = await self.store.get_events(
+            list(current_state_ids.values())
+        )
+        current_state = list(current_state_dict.values())
+        logger.info("maybe_kick_guest_users %r", current_state)
+        await self.hs.get_room_member_handler().kick_guest_users(current_state)
+
     async def _bump_active_time(self, user: UserID) -> None:
         try:
             presence = self.hs.get_presence_handler()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ed780bb41f..0235fd09b4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,7 +25,9 @@ from collections import OrderedDict
 from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
 
 from synapse.api.constants import (
+    EventContentFields,
     EventTypes,
+    GuestAccess,
     HistoryVisibility,
     JoinRules,
     Membership,
@@ -993,7 +995,8 @@ class RoomCreationHandler(BaseHandler):
         if config["guest_can_join"]:
             if (EventTypes.GuestAccess, "") not in initial_state:
                 last_sent_stream_id = await send(
-                    etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
+                    etype=EventTypes.GuestAccess,
+                    content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
                 )
 
         for (etype, state_key), content in initial_state.items():
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 6d433fad41..92bb75c848 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -19,7 +19,13 @@ from typing import TYPE_CHECKING, Optional, Tuple
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
-from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    GuestAccess,
+    HistoryVisibility,
+    JoinRules,
+)
 from synapse.api.errors import (
     Codes,
     HttpResponseException,
@@ -336,8 +342,8 @@ class RoomListHandler(BaseHandler):
         guest_event = current_state.get((EventTypes.GuestAccess, ""))
         guest = None
         if guest_event:
-            guest = guest_event.content.get("guest_access", None)
-        result["guest_can_join"] = guest == "can_join"
+            guest = guest_event.content.get(EventContentFields.GUEST_ACCESS)
+        result["guest_can_join"] = guest == GuestAccess.CAN_JOIN
 
         avatar_event = current_state.get(("m.room.avatar", ""))
         if avatar_event:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 401b84aad1..4390201641 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -23,6 +23,7 @@ from synapse.api.constants import (
     AccountDataTypes,
     EventContentFields,
     EventTypes,
+    GuestAccess,
     Membership,
 )
 from synapse.api.errors import (
@@ -44,6 +45,7 @@ from synapse.types import (
     RoomID,
     StateMap,
     UserID,
+    create_requester,
     get_domain_from_id,
 )
 from synapse.util.async_helpers import Linearizer
@@ -70,6 +72,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.auth = hs.get_auth()
         self.state_handler = hs.get_state_handler()
         self.config = hs.config
+        self._server_name = hs.hostname
 
         self.federation_handler = hs.get_federation_handler()
         self.directory_handler = hs.get_directory_handler()
@@ -115,9 +118,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
         )
 
-        # This is only used to get at ratelimit function, and
-        # maybe_kick_guest_users. It's fine there are multiple of these as
-        # it doesn't store state.
+        # This is only used to get at the ratelimit function. It's fine there are
+        # multiple of these as it doesn't store state.
         self.base_handler = BaseHandler(hs)
 
     @abc.abstractmethod
@@ -1095,10 +1097,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         return bool(
             guest_access
             and guest_access.content
-            and "guest_access" in guest_access.content
-            and guest_access.content["guest_access"] == "can_join"
+            and guest_access.content.get(EventContentFields.GUEST_ACCESS)
+            == GuestAccess.CAN_JOIN
         )
 
+    async def kick_guest_users(self, current_state: Iterable[EventBase]) -> None:
+        """Kick any local guest users from the room.
+
+        This is called when the room state changes from guests allowed to not-allowed.
+
+        Params:
+            current_state: the current state of the room. We will iterate this to look
+               for guest users to kick.
+        """
+        for member_event in current_state:
+            try:
+                if member_event.type != EventTypes.Member:
+                    continue
+
+                if not self.hs.is_mine_id(member_event.state_key):
+                    continue
+
+                if member_event.content["membership"] not in {
+                    Membership.JOIN,
+                    Membership.INVITE,
+                }:
+                    continue
+
+                if (
+                    "kind" not in member_event.content
+                    or member_event.content["kind"] != "guest"
+                ):
+                    continue
+
+                # We make the user choose to leave, rather than have the
+                # event-sender kick them. This is partially because we don't
+                # need to worry about power levels, and partially because guest
+                # users are a concept which doesn't hugely work over federation,
+                # and having homeservers have their own users leave keeps more
+                # of that decision-making and control local to the guest-having
+                # homeserver.
+                target_user = UserID.from_string(member_event.state_key)
+                requester = create_requester(
+                    target_user, is_guest=True, authenticated_entity=self._server_name
+                )
+                handler = self.hs.get_room_member_handler()
+                await handler.update_membership(
+                    requester,
+                    target_user,
+                    member_event.room_id,
+                    "leave",
+                    ratelimit=False,
+                    require_consent=False,
+                )
+            except Exception as e:
+                logger.exception("Error kicking guest user: %s" % (e,))
+
     async def lookup_room_alias(
         self, room_alias: RoomAlias
     ) -> Tuple[RoomID, List[str]]:
@@ -1352,7 +1406,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         self.distributor = hs.get_distributor()
         self.distributor.declare("user_left_room")
-        self._server_name = hs.hostname
 
     async def _is_remote_room_too_complex(
         self, room_id: str, remote_room_hosts: List[str]
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 3fd89af2a4..3a4c41c9ff 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
 
 from typing_extensions import Counter as CounterType
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict
@@ -273,7 +273,9 @@ class StatsHandler:
             elif typ == EventTypes.CanonicalAlias:
                 room_state["canonical_alias"] = event_content.get("alias")
             elif typ == EventTypes.GuestAccess:
-                room_state["guest_access"] = event_content.get("guest_access")
+                room_state["guest_access"] = event_content.get(
+                    EventContentFields.GUEST_ACCESS
+                )
 
         for room_id, state in room_to_state_updates.items():
             logger.debug("Updating room_stats_state for %s: %s", room_id, state)