summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/message.py38
-rw-r--r--synapse/handlers/pagination.py23
-rw-r--r--synapse/handlers/room_member.py45
-rw-r--r--synapse/handlers/worker_lock.py333
4 files changed, 414 insertions, 25 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index fff0b5fa12..187dedae7d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
         self._notifier = hs.get_notifier()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -1010,6 +1012,37 @@ class EventCreationHandler:
                         event.internal_metadata.stream_ordering,
                     )
 
+        async with self._worker_lock_handler.acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            return await self._create_and_send_nonmember_event_locked(
+                requester=requester,
+                event_dict=event_dict,
+                allow_no_prev_events=allow_no_prev_events,
+                prev_event_ids=prev_event_ids,
+                state_event_ids=state_event_ids,
+                ratelimit=ratelimit,
+                txn_id=txn_id,
+                ignore_shadow_ban=ignore_shadow_ban,
+                outlier=outlier,
+                depth=depth,
+            )
+
+    async def _create_and_send_nonmember_event_locked(
+        self,
+        requester: Requester,
+        event_dict: dict,
+        allow_no_prev_events: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+        state_event_ids: Optional[List[str]] = None,
+        ratelimit: bool = True,
+        txn_id: Optional[str] = None,
+        ignore_shadow_ban: bool = False,
+        outlier: bool = False,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
+        room_id = event_dict["room_id"]
+
         # If we don't have any prev event IDs specified then we need to
         # check that the host is in the room (as otherwise populating the
         # prev events will fail), at which point we may as well check the
@@ -1923,7 +1956,10 @@ class EventCreationHandler:
         )
 
         for room_id in room_ids:
-            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                dummy_event_sent = await self._send_dummy_event_for_room(room_id)
 
             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 19b8728db9..da34658470 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)
 BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
 
 
+PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
 @attr.s(slots=True, auto_attribs=True)
 class PurgeStatus:
     """Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
         self._relations_handler = hs.get_relations_handler()
+        self._worker_locks = hs.get_worker_locks_handler()
 
         self.pagination_lock = ReadWriteLock()
         # IDs of rooms in which there currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 await self._storage_controllers.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
@@ -412,7 +420,10 @@ class PaginationHandler:
             room_id: room to be purged
             force: set true to skip checking for joined users.
         """
-        async with self.pagination_lock.write(room_id):
+        async with self._worker_locks.acquire_multi_read_write_lock(
+            [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+            write=True,
+        ):
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:
 
         room_token = from_token.room_key
 
-        async with self.pagination_lock.read(room_id):
+        async with self._worker_locks.acquire_read_write_lock(
+            PURGE_HISTORY_LOCK_NAME, room_id, write=False
+        ):
             (membership, member_event_id) = (None, None)
             if not use_admin_priviledge:
                 (
@@ -747,7 +760,9 @@ class PaginationHandler:
 
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
                 self._delete_by_id[
                     delete_id
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 496e701f13..6cca2ec344 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -39,6 +39,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.event_creation_handler = hs.get_event_creation_handler()
         self.account_data_handler = hs.get_account_data_handler()
         self.event_auth_handler = hs.get_event_auth_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         self.member_linearizer: Linearizer = Linearizer(name="member")
         self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -638,26 +640,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # by application services), and then by room ID.
         async with self.member_as_limiter.queue(as_id):
             async with self.member_linearizer.queue(key):
-                with opentracing.start_active_span("update_membership_locked"):
-                    result = await self.update_membership_locked(
-                        requester,
-                        target,
-                        room_id,
-                        action,
-                        txn_id=txn_id,
-                        remote_room_hosts=remote_room_hosts,
-                        third_party_signed=third_party_signed,
-                        ratelimit=ratelimit,
-                        content=content,
-                        new_room=new_room,
-                        require_consent=require_consent,
-                        outlier=outlier,
-                        allow_no_prev_events=allow_no_prev_events,
-                        prev_event_ids=prev_event_ids,
-                        state_event_ids=state_event_ids,
-                        depth=depth,
-                        origin_server_ts=origin_server_ts,
-                    )
+                async with self._worker_lock_handler.acquire_read_write_lock(
+                    DELETE_ROOM_LOCK_NAME, room_id, write=False
+                ):
+                    with opentracing.start_active_span("update_membership_locked"):
+                        result = await self.update_membership_locked(
+                            requester,
+                            target,
+                            room_id,
+                            action,
+                            txn_id=txn_id,
+                            remote_room_hosts=remote_room_hosts,
+                            third_party_signed=third_party_signed,
+                            ratelimit=ratelimit,
+                            content=content,
+                            new_room=new_room,
+                            require_consent=require_consent,
+                            outlier=outlier,
+                            allow_no_prev_events=allow_no_prev_events,
+                            prev_event_ids=prev_event_ids,
+                            state_event_ids=state_event_ids,
+                            depth=depth,
+                            origin_server_ts=origin_server_ts,
+                        )
 
         return result
 
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
new file mode 100644
index 0000000000..72df773a86
--- /dev/null
+++ b/synapse/handlers/worker_lock.py
@@ -0,0 +1,333 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Collection,
+    Dict,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+from weakref import WeakSet
+
+import attr
+
+from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
+
+from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage.databases.main.lock import Lock, LockStore
+from synapse.util.async_helpers import timeout_deferred
+
+if TYPE_CHECKING:
+    from synapse.logging.opentracing import opentracing
+    from synapse.server import HomeServer
+
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
+class WorkerLocksHandler:
+    """A class for waiting on taking out locks, rather than using the storage
+    functions directly (which don't support awaiting).
+    """
+
+    def __init__(self, hs: "HomeServer") -> None:
+        self._reactor = hs.get_reactor()
+        self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
+
+        # Map from lock name/key to set of `WaitingLock` that are active for
+        # that lock.
+        self._locks: Dict[
+            Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
+        ] = {}
+
+        self._clock.looping_call(self._cleanup_locks, 30_000)
+
+        self._notifier.add_lock_released_callback(self._on_lock_released)
+
+    def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
+        """Acquire a standard lock, returns a context manager that will block
+        until the lock is acquired.
+
+        Note: Care must be taken to avoid deadlocks. In particular, this
+        function does *not* timeout.
+
+        Usage:
+            async with handler.acquire_lock(name, key):
+                # Do work while holding the lock...
+        """
+
+        lock = WaitingLock(
+            reactor=self._reactor,
+            store=self._store,
+            handler=self,
+            lock_name=lock_name,
+            lock_key=lock_key,
+            write=None,
+        )
+
+        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+        return lock
+
+    def acquire_read_write_lock(
+        self,
+        lock_name: str,
+        lock_key: str,
+        *,
+        write: bool,
+    ) -> "WaitingLock":
+        """Acquire a read/write lock, returns a context manager that will block
+        until the lock is acquired.
+
+        Note: Care must be taken to avoid deadlocks. In particular, this
+        function does *not* timeout.
+
+        Usage:
+            async with handler.acquire_read_write_lock(name, key, write=True):
+                # Do work while holding the lock...
+        """
+
+        lock = WaitingLock(
+            reactor=self._reactor,
+            store=self._store,
+            handler=self,
+            lock_name=lock_name,
+            lock_key=lock_key,
+            write=write,
+        )
+
+        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+        return lock
+
+    def acquire_multi_read_write_lock(
+        self,
+        lock_names: Collection[Tuple[str, str]],
+        *,
+        write: bool,
+    ) -> "WaitingMultiLock":
+        """Acquires multi read/write locks at once, returns a context manager
+        that will block until all the locks are acquired.
+
+        This will try and acquire all locks at once, and will never hold on to a
+        subset of the locks. (This avoids accidentally creating deadlocks).
+
+        Note: Care must be taken to avoid deadlocks. In particular, this
+        function does *not* timeout.
+        """
+
+        lock = WaitingMultiLock(
+            lock_names=lock_names,
+            write=write,
+            reactor=self._reactor,
+            store=self._store,
+            handler=self,
+        )
+
+        for lock_name, lock_key in lock_names:
+            self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+        return lock
+
+    def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
+        """Notify that a lock has been released.
+
+        Pokes both the notifier and replication.
+        """
+
+        self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)
+
+    def _on_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Called when a lock has been released.
+
+        Wakes up any locks that might be waiting on this.
+        """
+        locks = self._locks.get((lock_name, lock_key))
+        if not locks:
+            return
+
+        def _wake_deferred(deferred: defer.Deferred) -> None:
+            if not deferred.called:
+                deferred.callback(None)
+
+        for lock in locks:
+            self._clock.call_later(0, _wake_deferred, lock.deferred)
+
+    @wrap_as_background_process("_cleanup_locks")
+    async def _cleanup_locks(self) -> None:
+        """Periodically cleans out stale entries in the locks map"""
+        self._locks = {key: value for key, value in self._locks.items() if value}
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingLock:
+    reactor: IReactorTime
+    store: LockStore
+    handler: WorkerLocksHandler
+    lock_name: str
+    lock_key: str
+    write: Optional[bool]
+    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+    _inner_lock: Optional[Lock] = None
+    _retry_interval: float = 0.1
+    _lock_span: "opentracing.Scope" = attr.Factory(
+        lambda: start_active_span("WaitingLock.lock")
+    )
+
+    async def __aenter__(self) -> None:
+        self._lock_span.__enter__()
+
+        with start_active_span("WaitingLock.waiting_for_lock"):
+            while self._inner_lock is None:
+                self.deferred = defer.Deferred()
+
+                if self.write is not None:
+                    lock = await self.store.try_acquire_read_write_lock(
+                        self.lock_name, self.lock_key, write=self.write
+                    )
+                else:
+                    lock = await self.store.try_acquire_lock(
+                        self.lock_name, self.lock_key
+                    )
+
+                if lock:
+                    self._inner_lock = lock
+                    break
+
+                try:
+                    # Wait until the we get notified the lock might have been
+                    # released (by the deferred being resolved). We also
+                    # periodically wake up in case the lock was released but we
+                    # weren't notified.
+                    with PreserveLoggingContext():
+                        await timeout_deferred(
+                            deferred=self.deferred,
+                            timeout=self._get_next_retry_interval(),
+                            reactor=self.reactor,
+                        )
+                except Exception:
+                    pass
+
+        return await self._inner_lock.__aenter__()
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        assert self._inner_lock
+
+        self.handler.notify_lock_released(self.lock_name, self.lock_key)
+
+        try:
+            r = await self._inner_lock.__aexit__(exc_type, exc, tb)
+        finally:
+            self._lock_span.__exit__(exc_type, exc, tb)
+
+        return r
+
+    def _get_next_retry_interval(self) -> float:
+        next = self._retry_interval
+        self._retry_interval = max(5, next * 2)
+        return next * random.uniform(0.9, 1.1)
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingMultiLock:
+    lock_names: Collection[Tuple[str, str]]
+
+    write: bool
+
+    reactor: IReactorTime
+    store: LockStore
+    handler: WorkerLocksHandler
+
+    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+
+    _inner_lock_cm: Optional[AsyncContextManager] = None
+    _retry_interval: float = 0.1
+    _lock_span: "opentracing.Scope" = attr.Factory(
+        lambda: start_active_span("WaitingLock.lock")
+    )
+
+    async def __aenter__(self) -> None:
+        self._lock_span.__enter__()
+
+        with start_active_span("WaitingLock.waiting_for_lock"):
+            while self._inner_lock_cm is None:
+                self.deferred = defer.Deferred()
+
+                lock_cm = await self.store.try_acquire_multi_read_write_lock(
+                    self.lock_names, write=self.write
+                )
+
+                if lock_cm:
+                    self._inner_lock_cm = lock_cm
+                    break
+
+                try:
+                    # Wait until the we get notified the lock might have been
+                    # released (by the deferred being resolved). We also
+                    # periodically wake up in case the lock was released but we
+                    # weren't notified.
+                    with PreserveLoggingContext():
+                        await timeout_deferred(
+                            deferred=self.deferred,
+                            timeout=self._get_next_retry_interval(),
+                            reactor=self.reactor,
+                        )
+                except Exception:
+                    pass
+
+        assert self._inner_lock_cm
+        await self._inner_lock_cm.__aenter__()
+        return
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        assert self._inner_lock_cm
+
+        for lock_name, lock_key in self.lock_names:
+            self.handler.notify_lock_released(lock_name, lock_key)
+
+        try:
+            r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
+        finally:
+            self._lock_span.__exit__(exc_type, exc, tb)
+
+        return r
+
+    def _get_next_retry_interval(self) -> float:
+        next = self._retry_interval
+        self._retry_interval = max(5, next * 2)
+        return next * random.uniform(0.9, 1.1)