summary refs log tree commit diff
path: root/synapse/handlers/room_member.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/room_member.py')
-rw-r--r--synapse/handlers/room_member.py173
1 files changed, 154 insertions, 19 deletions
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ed805d6ec8..fbef600acd 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse import types
 from synapse.api.constants import (
@@ -38,7 +38,10 @@ from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
 from synapse.logging import opentracing
+from synapse.metrics import event_processing_positions
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import NOT_SPAM
 from synapse.types import (
     JsonDict,
@@ -280,9 +283,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
-    @abc.abstractmethod
     async def forget(self, user: UserID, room_id: str) -> None:
-        raise NotImplementedError()
+        user_id = user.to_string()
+
+        member = await self._storage_controllers.state.get_current_state_event(
+            room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+        )
+        membership = member.membership if member else None
+
+        if membership is not None and membership not in [
+            Membership.LEAVE,
+            Membership.BAN,
+        ]:
+            raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
+
+        # In normal case this call is only required if `membership` is not `None`.
+        # But: After the last member had left the room, the background update
+        # `_background_remove_left_rooms` is deleting rows related to this room from
+        # the table `current_state_events` and `get_current_state_events` is `None`.
+        await self.store.forget(user_id, room_id)
 
     async def ratelimit_multiple_invites(
         self,
@@ -2046,25 +2065,141 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         """Implements RoomMemberHandler._user_left_room"""
         user_left_room(self.distributor, target, room_id)
 
-    async def forget(self, user: UserID, room_id: str) -> None:
-        user_id = user.to_string()
 
-        member = await self._storage_controllers.state.get_current_state_event(
-            room_id=room_id, event_type=EventTypes.Member, state_key=user_id
-        )
-        membership = member.membership if member else None
+class RoomForgetterHandler(StateDeltasHandler):
+    """Forgets rooms when they are left, when enabled in the homeserver config.
 
-        if membership is not None and membership not in [
-            Membership.LEAVE,
-            Membership.BAN,
-        ]:
-            raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
+    For the purposes of this feature, kicks, bans and "leaves" via state resolution
+    weirdness are all considered to be leaves.
 
-        # In normal case this call is only required if `membership` is not `None`.
-        # But: After the last member had left the room, the background update
-        # `_background_remove_left_rooms` is deleting rows related to this room from
-        # the table `current_state_events` and `get_current_state_events` is `None`.
-        await self.store.forget(user_id, room_id)
+    Derived from `StatsHandler` and `UserDirectoryHandler`.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._hs = hs
+        self._store = hs.get_datastores().main
+        self._storage_controllers = hs.get_storage_controllers()
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._room_member_handler = hs.get_room_member_handler()
+
+        # The current position in the current_state_delta stream
+        self.pos: Optional[int] = None
+
+        # Guard to ensure we only process deltas one at a time
+        self._is_processing = False
+
+        if hs.config.worker.run_background_tasks:
+            self._notifier.add_replication_callback(self.notify_new_event)
+
+            # We kick this off to pick up outstanding work from before the last restart.
+            self._clock.call_later(0, self.notify_new_event)
+
+    def notify_new_event(self) -> None:
+        """Called when there may be more deltas to process"""
+        if self._is_processing:
+            return
+
+        self._is_processing = True
+
+        async def process() -> None:
+            try:
+                await self._unsafe_process()
+            finally:
+                self._is_processing = False
+
+        run_as_background_process("room_forgetter.notify_new_event", process)
+
+    async def _unsafe_process(self) -> None:
+        # If self.pos is None then means we haven't fetched it from DB
+        if self.pos is None:
+            self.pos = await self._store.get_room_forgetter_stream_pos()
+            room_max_stream_ordering = self._store.get_room_max_stream_ordering()
+            if self.pos > room_max_stream_ordering:
+                # apparently, we've processed more events than exist in the database!
+                # this can happen if events are removed with history purge or similar.
+                logger.warning(
+                    "Event stream ordering appears to have gone backwards (%i -> %i): "
+                    "rewinding room forgetter processor",
+                    self.pos,
+                    room_max_stream_ordering,
+                )
+                self.pos = room_max_stream_ordering
+
+        if not self._hs.config.room.forget_on_leave:
+            # Update the processing position, so that if the server admin turns the
+            # feature on at a later date, we don't decide to forget every room that
+            # has ever been left in the past.
+            self.pos = self._store.get_room_max_stream_ordering()
+            await self._store.update_room_forgetter_stream_pos(self.pos)
+            return
+
+        # Loop round handling deltas until we're up to date
+
+        while True:
+            # Be sure to read the max stream_ordering *before* checking if there are any outstanding
+            # deltas, since there is otherwise a chance that we could miss updates which arrive
+            # after we check the deltas.
+            room_max_stream_ordering = self._store.get_room_max_stream_ordering()
+            if self.pos == room_max_stream_ordering:
+                break
+
+            logger.debug(
+                "Processing room forgetting %s->%s", self.pos, room_max_stream_ordering
+            )
+            (
+                max_pos,
+                deltas,
+            ) = await self._storage_controllers.state.get_current_state_deltas(
+                self.pos, room_max_stream_ordering
+            )
+
+            logger.debug("Handling %d state deltas", len(deltas))
+            await self._handle_deltas(deltas)
+
+            self.pos = max_pos
+
+            # Expose current event processing position to prometheus
+            event_processing_positions.labels("room_forgetter").set(max_pos)
+
+            await self._store.update_room_forgetter_stream_pos(max_pos)
+
+    async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
+        """Called with the state deltas to process"""
+        for delta in deltas:
+            typ = delta["type"]
+            state_key = delta["state_key"]
+            room_id = delta["room_id"]
+            event_id = delta["event_id"]
+            prev_event_id = delta["prev_event_id"]
+
+            if typ != EventTypes.Member:
+                continue
+
+            if not self._hs.is_mine_id(state_key):
+                continue
+
+            change = await self._get_key_change(
+                prev_event_id,
+                event_id,
+                key_name="membership",
+                public_value=Membership.JOIN,
+            )
+            is_leave = change is MatchChange.now_false
+
+            if is_leave:
+                try:
+                    await self._room_member_handler.forget(
+                        UserID.from_string(state_key), room_id
+                    )
+                except SynapseError as e:
+                    if e.code == 400:
+                        # The user is back in the room.
+                        pass
+                    else:
+                        raise
 
 
 def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]: