summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-07-31 10:58:03 +0100
committerGitHub <noreply@github.com>2023-07-31 10:58:03 +0100
commitae55cc1e6bc6527d0e359a823c474f5c9ed4382e (patch)
treee874ad28d0ef94933201fe88511c9c8b93968a32
parentBump types-commonmark from 0.9.2.3 to 0.9.2.4 (#16037) (diff)
downloadsynapse-ae55cc1e6bc6527d0e359a823c474f5c9ed4382e.tar.xz
Add ability to wait for locks and add locks to purge history / room deletion (#15791)
c.f. #13476
-rw-r--r--changelog.d/15791.bugfix1
-rw-r--r--synapse/federation/federation_server.py17
-rw-r--r--synapse/handlers/message.py38
-rw-r--r--synapse/handlers/pagination.py23
-rw-r--r--synapse/handlers/room_member.py45
-rw-r--r--synapse/handlers/worker_lock.py333
-rw-r--r--synapse/notifier.py16
-rw-r--r--synapse/replication/tcp/commands.py33
-rw-r--r--synapse/replication/tcp/handler.py22
-rw-r--r--synapse/rest/client/room_upgrade_rest_servlet.py11
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/controllers/persist_events.py27
-rw-r--r--synapse/storage/databases/main/lock.py190
-rw-r--r--tests/handlers/test_worker_lock.py74
-rw-r--r--tests/rest/client/test_rooms.py4
-rw-r--r--tests/storage/databases/main/test_lock.py52
16 files changed, 783 insertions, 108 deletions
diff --git a/changelog.d/15791.bugfix b/changelog.d/15791.bugfix
new file mode 100644
index 0000000000..182634b62f
--- /dev/null
+++ b/changelog.d/15791.bugfix
@@ -0,0 +1 @@
+Fix bug where purging history and paginating simultaneously could lead to database corruption when using workers.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index fa61dd8c10..a90d99c4d6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -63,6 +63,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -137,6 +138,7 @@ class FederationServer(FederationBase):
         self._event_auth_handler = hs.get_event_auth_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._e2e_keys_handler = hs.get_e2e_keys_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         self._state_storage_controller = hs.get_storage_controllers().state
 
@@ -1236,9 +1238,18 @@ class FederationServer(FederationBase):
                 logger.info("handling received PDU in room %s: %s", room_id, event)
                 try:
                     with nested_logging_context(event.event_id):
-                        await self._federation_event_handler.on_receive_pdu(
-                            origin, event
-                        )
+                        # We're taking out a lock within a lock, which could
+                        # lead to deadlocks if we're not careful. However, it is
+                        # safe on this occasion as we only ever take a write
+                        # lock when deleting a room, which we would never do
+                        # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
+                        # lock.
+                        async with self._worker_lock_handler.acquire_read_write_lock(
+                            DELETE_ROOM_LOCK_NAME, room_id, write=False
+                        ):
+                            await self._federation_event_handler.on_receive_pdu(
+                                origin, event
+                            )
                 except FederationError as e:
                     # XXX: Ideally we'd inform the remote we failed to process
                     # the event, but we can't return an error in the transaction
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index fff0b5fa12..187dedae7d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
         self._notifier = hs.get_notifier()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -1010,6 +1012,37 @@ class EventCreationHandler:
                         event.internal_metadata.stream_ordering,
                     )
 
+        async with self._worker_lock_handler.acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            return await self._create_and_send_nonmember_event_locked(
+                requester=requester,
+                event_dict=event_dict,
+                allow_no_prev_events=allow_no_prev_events,
+                prev_event_ids=prev_event_ids,
+                state_event_ids=state_event_ids,
+                ratelimit=ratelimit,
+                txn_id=txn_id,
+                ignore_shadow_ban=ignore_shadow_ban,
+                outlier=outlier,
+                depth=depth,
+            )
+
+    async def _create_and_send_nonmember_event_locked(
+        self,
+        requester: Requester,
+        event_dict: dict,
+        allow_no_prev_events: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+        state_event_ids: Optional[List[str]] = None,
+        ratelimit: bool = True,
+        txn_id: Optional[str] = None,
+        ignore_shadow_ban: bool = False,
+        outlier: bool = False,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
+        room_id = event_dict["room_id"]
+
         # If we don't have any prev event IDs specified then we need to
         # check that the host is in the room (as otherwise populating the
         # prev events will fail), at which point we may as well check the
@@ -1923,7 +1956,10 @@ class EventCreationHandler:
         )
 
         for room_id in room_ids:
-            dummy_event_sent = await self._send_dummy_event_for_room(room_id)
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                dummy_event_sent = await self._send_dummy_event_for_room(room_id)
 
             if not dummy_event_sent:
                 # Did not find a valid user in the room, so remove from future attempts
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 19b8728db9..da34658470 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)
 BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
 
 
+PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
 @attr.s(slots=True, auto_attribs=True)
 class PurgeStatus:
     """Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
         self._relations_handler = hs.get_relations_handler()
+        self._worker_locks = hs.get_worker_locks_handler()
 
         self.pagination_lock = ReadWriteLock()
         # IDs of rooms in which there currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 await self._storage_controllers.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
@@ -412,7 +420,10 @@ class PaginationHandler:
             room_id: room to be purged
             force: set true to skip checking for joined users.
         """
-        async with self.pagination_lock.write(room_id):
+        async with self._worker_locks.acquire_multi_read_write_lock(
+            [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+            write=True,
+        ):
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:
 
         room_token = from_token.room_key
 
-        async with self.pagination_lock.read(room_id):
+        async with self._worker_locks.acquire_read_write_lock(
+            PURGE_HISTORY_LOCK_NAME, room_id, write=False
+        ):
             (membership, member_event_id) = (None, None)
             if not use_admin_priviledge:
                 (
@@ -747,7 +760,9 @@ class PaginationHandler:
 
         self._purges_in_progress_by_room.add(room_id)
         try:
-            async with self.pagination_lock.write(room_id):
+            async with self._worker_locks.acquire_read_write_lock(
+                PURGE_HISTORY_LOCK_NAME, room_id, write=True
+            ):
                 self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
                 self._delete_by_id[
                     delete_id
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 496e701f13..6cca2ec344 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -39,6 +39,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.event_creation_handler = hs.get_event_creation_handler()
         self.account_data_handler = hs.get_account_data_handler()
         self.event_auth_handler = hs.get_event_auth_handler()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         self.member_linearizer: Linearizer = Linearizer(name="member")
         self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -638,26 +640,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # by application services), and then by room ID.
         async with self.member_as_limiter.queue(as_id):
             async with self.member_linearizer.queue(key):
-                with opentracing.start_active_span("update_membership_locked"):
-                    result = await self.update_membership_locked(
-                        requester,
-                        target,
-                        room_id,
-                        action,
-                        txn_id=txn_id,
-                        remote_room_hosts=remote_room_hosts,
-                        third_party_signed=third_party_signed,
-                        ratelimit=ratelimit,
-                        content=content,
-                        new_room=new_room,
-                        require_consent=require_consent,
-                        outlier=outlier,
-                        allow_no_prev_events=allow_no_prev_events,
-                        prev_event_ids=prev_event_ids,
-                        state_event_ids=state_event_ids,
-                        depth=depth,
-                        origin_server_ts=origin_server_ts,
-                    )
+                async with self._worker_lock_handler.acquire_read_write_lock(
+                    DELETE_ROOM_LOCK_NAME, room_id, write=False
+                ):
+                    with opentracing.start_active_span("update_membership_locked"):
+                        result = await self.update_membership_locked(
+                            requester,
+                            target,
+                            room_id,
+                            action,
+                            txn_id=txn_id,
+                            remote_room_hosts=remote_room_hosts,
+                            third_party_signed=third_party_signed,
+                            ratelimit=ratelimit,
+                            content=content,
+                            new_room=new_room,
+                            require_consent=require_consent,
+                            outlier=outlier,
+                            allow_no_prev_events=allow_no_prev_events,
+                            prev_event_ids=prev_event_ids,
+                            state_event_ids=state_event_ids,
+                            depth=depth,
+                            origin_server_ts=origin_server_ts,
+                        )
 
         return result
 
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
new file mode 100644
index 0000000000..72df773a86
--- /dev/null
+++ b/synapse/handlers/worker_lock.py
@@ -0,0 +1,333 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Collection,
+    Dict,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+from weakref import WeakSet
+
+import attr
+
+from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
+
+from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage.databases.main.lock import Lock, LockStore
+from synapse.util.async_helpers import timeout_deferred
+
+if TYPE_CHECKING:
+    from synapse.logging.opentracing import opentracing
+    from synapse.server import HomeServer
+
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
+class WorkerLocksHandler:
+    """A class for waiting on taking out locks, rather than using the storage
+    functions directly (which don't support awaiting).
+    """
+
+    def __init__(self, hs: "HomeServer") -> None:
+        self._reactor = hs.get_reactor()
+        self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
+        self._notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
+
+        # Map from lock name/key to set of `WaitingLock` that are active for
+        # that lock.
+        self._locks: Dict[
+            Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
+        ] = {}
+
+        self._clock.looping_call(self._cleanup_locks, 30_000)
+
+        self._notifier.add_lock_released_callback(self._on_lock_released)
+
+    def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
+        """Acquire a standard lock, returns a context manager that will block
+        until the lock is acquired.
+
+        Note: Care must be taken to avoid deadlocks. In particular, this
+        function does *not* timeout.
+
+        Usage:
+            async with handler.acquire_lock(name, key):
+                # Do work while holding the lock...
+        """
+
+        lock = WaitingLock(
+            reactor=self._reactor,
+            store=self._store,
+            handler=self,
+            lock_name=lock_name,
+            lock_key=lock_key,
+            write=None,
+        )
+
+        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+        return lock
+
+    def acquire_read_write_lock(
+        self,
+        lock_name: str,
+        lock_key: str,
+        *,
+        write: bool,
+    ) -> "WaitingLock":
+        """Acquire a read/write lock, returns a context manager that will block
+        until the lock is acquired.
+
+        Note: Care must be taken to avoid deadlocks. In particular, this
+        function does *not* timeout.
+
+        Usage:
+            async with handler.acquire_read_write_lock(name, key, write=True):
+                # Do work while holding the lock...
+        """
+
+        lock = WaitingLock(
+            reactor=self._reactor,
+            store=self._store,
+            handler=self,
+            lock_name=lock_name,
+            lock_key=lock_key,
+            write=write,
+        )
+
+        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+        return lock
+
+    def acquire_multi_read_write_lock(
+        self,
+        lock_names: Collection[Tuple[str, str]],
+        *,
+        write: bool,
+    ) -> "WaitingMultiLock":
+        """Acquires multi read/write locks at once, returns a context manager
+        that will block until all the locks are acquired.
+
+        This will try and acquire all locks at once, and will never hold on to a
+        subset of the locks. (This avoids accidentally creating deadlocks).
+
+        Note: Care must be taken to avoid deadlocks. In particular, this
+        function does *not* timeout.
+        """
+
+        lock = WaitingMultiLock(
+            lock_names=lock_names,
+            write=write,
+            reactor=self._reactor,
+            store=self._store,
+            handler=self,
+        )
+
+        for lock_name, lock_key in lock_names:
+            self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+        return lock
+
+    def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
+        """Notify that a lock has been released.
+
+        Pokes both the notifier and replication.
+        """
+
+        self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)
+
+    def _on_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Called when a lock has been released.
+
+        Wakes up any locks that might be waiting on this.
+        """
+        locks = self._locks.get((lock_name, lock_key))
+        if not locks:
+            return
+
+        def _wake_deferred(deferred: defer.Deferred) -> None:
+            if not deferred.called:
+                deferred.callback(None)
+
+        for lock in locks:
+            self._clock.call_later(0, _wake_deferred, lock.deferred)
+
+    @wrap_as_background_process("_cleanup_locks")
+    async def _cleanup_locks(self) -> None:
+        """Periodically cleans out stale entries in the locks map"""
+        self._locks = {key: value for key, value in self._locks.items() if value}
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingLock:
+    reactor: IReactorTime
+    store: LockStore
+    handler: WorkerLocksHandler
+    lock_name: str
+    lock_key: str
+    write: Optional[bool]
+    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+    _inner_lock: Optional[Lock] = None
+    _retry_interval: float = 0.1
+    _lock_span: "opentracing.Scope" = attr.Factory(
+        lambda: start_active_span("WaitingLock.lock")
+    )
+
+    async def __aenter__(self) -> None:
+        self._lock_span.__enter__()
+
+        with start_active_span("WaitingLock.waiting_for_lock"):
+            while self._inner_lock is None:
+                self.deferred = defer.Deferred()
+
+                if self.write is not None:
+                    lock = await self.store.try_acquire_read_write_lock(
+                        self.lock_name, self.lock_key, write=self.write
+                    )
+                else:
+                    lock = await self.store.try_acquire_lock(
+                        self.lock_name, self.lock_key
+                    )
+
+                if lock:
+                    self._inner_lock = lock
+                    break
+
+                try:
+                    # Wait until the we get notified the lock might have been
+                    # released (by the deferred being resolved). We also
+                    # periodically wake up in case the lock was released but we
+                    # weren't notified.
+                    with PreserveLoggingContext():
+                        await timeout_deferred(
+                            deferred=self.deferred,
+                            timeout=self._get_next_retry_interval(),
+                            reactor=self.reactor,
+                        )
+                except Exception:
+                    pass
+
+        return await self._inner_lock.__aenter__()
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        assert self._inner_lock
+
+        self.handler.notify_lock_released(self.lock_name, self.lock_key)
+
+        try:
+            r = await self._inner_lock.__aexit__(exc_type, exc, tb)
+        finally:
+            self._lock_span.__exit__(exc_type, exc, tb)
+
+        return r
+
+    def _get_next_retry_interval(self) -> float:
+        next = self._retry_interval
+        self._retry_interval = max(5, next * 2)
+        return next * random.uniform(0.9, 1.1)
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingMultiLock:
+    lock_names: Collection[Tuple[str, str]]
+
+    write: bool
+
+    reactor: IReactorTime
+    store: LockStore
+    handler: WorkerLocksHandler
+
+    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+
+    _inner_lock_cm: Optional[AsyncContextManager] = None
+    _retry_interval: float = 0.1
+    _lock_span: "opentracing.Scope" = attr.Factory(
+        lambda: start_active_span("WaitingLock.lock")
+    )
+
+    async def __aenter__(self) -> None:
+        self._lock_span.__enter__()
+
+        with start_active_span("WaitingLock.waiting_for_lock"):
+            while self._inner_lock_cm is None:
+                self.deferred = defer.Deferred()
+
+                lock_cm = await self.store.try_acquire_multi_read_write_lock(
+                    self.lock_names, write=self.write
+                )
+
+                if lock_cm:
+                    self._inner_lock_cm = lock_cm
+                    break
+
+                try:
+                    # Wait until the we get notified the lock might have been
+                    # released (by the deferred being resolved). We also
+                    # periodically wake up in case the lock was released but we
+                    # weren't notified.
+                    with PreserveLoggingContext():
+                        await timeout_deferred(
+                            deferred=self.deferred,
+                            timeout=self._get_next_retry_interval(),
+                            reactor=self.reactor,
+                        )
+                except Exception:
+                    pass
+
+        assert self._inner_lock_cm
+        await self._inner_lock_cm.__aenter__()
+        return
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        assert self._inner_lock_cm
+
+        for lock_name, lock_key in self.lock_names:
+            self.handler.notify_lock_released(lock_name, lock_key)
+
+        try:
+            r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
+        finally:
+            self._lock_span.__exit__(exc_type, exc, tb)
+
+        return r
+
+    def _get_next_retry_interval(self) -> float:
+        next = self._retry_interval
+        self._retry_interval = max(5, next * 2)
+        return next * random.uniform(0.9, 1.1)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 897272ad5b..68115bca70 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -234,6 +234,9 @@ class Notifier:
 
         self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
 
+        # List of callbacks to be notified when a lock is released
+        self._lock_released_callback: List[Callable[[str, str, str], None]] = []
+
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
         self._pusher_pool = hs.get_pusherpool()
@@ -785,6 +788,19 @@ class Notifier:
         # that any in flight requests can be immediately retried.
         self._federation_client.wake_destination(server)
 
+    def add_lock_released_callback(
+        self, callback: Callable[[str, str, str], None]
+    ) -> None:
+        """Add a function to be called whenever we are notified about a released lock."""
+        self._lock_released_callback.append(callback)
+
+    def notify_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Notify the callbacks that a lock has been released."""
+        for cb in self._lock_released_callback:
+            cb(instance_name, lock_name, lock_key)
+
 
 @attr.s(auto_attribs=True)
 class ReplicationNotifier:
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 32f52e54d8..10f5c98ff8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand):
     NAME = "REMOTE_SERVER_UP"
 
 
+class LockReleasedCommand(Command):
+    """Sent to inform other instances that a given lock has been dropped.
+
+    Format::
+
+        LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
+    """
+
+    NAME = "LOCK_RELEASED"
+
+    def __init__(
+        self,
+        instance_name: str,
+        lock_name: str,
+        lock_key: str,
+    ):
+        self.instance_name = instance_name
+        self.lock_name = lock_name
+        self.lock_key = lock_key
+
+    @classmethod
+    def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
+        instance_name, lock_name, lock_key = json_decoder.decode(line)
+
+        return cls(instance_name, lock_name, lock_key)
+
+    def to_line(self) -> str:
+        return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
+
+
 _COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
@@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
+    LockReleasedCommand,
 )
 
 # Map of command name to command type.
@@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     RemoteServerUpCommand.NAME,
+    LockReleasedCommand.NAME,
 )
 
 # The commands the client is allowed to send
@@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = (
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
+    LockReleasedCommand.NAME,
 )
 
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 5d108fe11b..a2cabba7b1 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
+    LockReleasedCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -248,6 +249,9 @@ class ReplicationCommandHandler:
         if self._is_master or self._should_insert_client_ips:
             self.subscribe_to_channel("USER_IP")
 
+        if hs.config.redis.redis_enabled:
+            self._notifier.add_lock_released_callback(self.on_lock_released)
+
     def subscribe_to_channel(self, channel_name: str) -> None:
         """
         Indicates that we wish to subscribe to a Redis channel by name.
@@ -648,6 +652,17 @@ class ReplicationCommandHandler:
 
         self._notifier.notify_remote_server_up(cmd.data)
 
+    def on_LOCK_RELEASED(
+        self, conn: IReplicationConnection, cmd: LockReleasedCommand
+    ) -> None:
+        """Called when we get a new LOCK_RELEASED command."""
+        if cmd.instance_name == self._instance_name:
+            return
+
+        self._notifier.notify_lock_released(
+            cmd.instance_name, cmd.lock_name, cmd.lock_key
+        )
+
     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -754,6 +769,13 @@ class ReplicationCommandHandler:
         """
         self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
 
+    def on_lock_released(
+        self, instance_name: str, lock_name: str, lock_key: str
+    ) -> None:
+        """Called when we released a lock and should notify other instances."""
+        if instance_name == self._instance_name:
+            self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+
 
 UpdateToken = TypeVar("UpdateToken")
 UpdateRow = TypeVar("UpdateRow")
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6a7792e18b..4a5d9e13e7 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, ShadowBanError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet):
         self._hs = hs
         self._room_creation_handler = hs.get_room_creation_handler()
         self._auth = hs.get_auth()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
     async def on_POST(
         self, request: SynapseRequest, room_id: str
@@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet):
             )
 
         try:
-            new_room_id = await self._room_creation_handler.upgrade_room(
-                requester, room_id, new_version
-            )
+            async with self._worker_lock_handler.acquire_read_write_lock(
+                DELETE_ROOM_LOCK_NAME, room_id, write=False
+            ):
+                new_room_id = await self._room_creation_handler.upgrade_room(
+                    requester, room_id, new_version
+                )
         except ShadowBanError:
             # Generate a random room ID.
             new_room_id = stringutils.random_string(18)
diff --git a/synapse/server.py b/synapse/server.py
index b72b76a38b..8430f99ef2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.handlers.worker_lock import WorkerLocksHandler
 from synapse.http.client import (
     InsecureInterceptableContextFactory,
     ReplicationClient,
@@ -912,3 +913,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
         """Usage metrics shared between phone home stats and the prometheus exporter."""
         return CommonUsageMetricsManager(self)
+
+    @cache_in_self
+    def get_worker_locks_handler(self) -> WorkerLocksHandler:
+        return WorkerLocksHandler(self)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 35c0680365..35cd1089d6 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,6 +45,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.logging.opentracing import (
     SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._state_controller = state_controller
+        self.hs = hs
 
     async def _process_event_persist_queue_task(
         self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        if isinstance(task, _PersistEventsTask):
-            return await self._persist_event_batch(room_id, task)
-        elif isinstance(task, _UpdateCurrentStateTask):
-            await self._update_current_state(room_id, task)
-            return {}
-        else:
-            raise AssertionError(
-                f"Found an unexpected task type in event persistence queue: {task}"
-            )
+
+        # Ensure that the room can't be deleted while we're persisting events to
+        # it. We might already have taken out the lock, but since this is just a
+        # "read" lock its inherently reentrant.
+        async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            if isinstance(task, _PersistEventsTask):
+                return await self._persist_event_batch(room_id, task)
+            elif isinstance(task, _UpdateCurrentStateTask):
+                await self._update_current_state(room_id, task)
+                return {}
+            else:
+                raise AssertionError(
+                    f"Found an unexpected task type in event persistence queue: {task}"
+                )
 
     @trace
     async def persist_events(
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index c89b4f7919..1680bf6168 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from contextlib import AsyncExitStack
 from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
 from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
@@ -208,76 +209,85 @@ class LockStore(SQLBaseStore):
         used (otherwise the lock will leak).
         """
 
+        try:
+            lock = await self.db_pool.runInteraction(
+                "try_acquire_read_write_lock",
+                self._try_acquire_read_write_lock_txn,
+                lock_name,
+                lock_key,
+                write,
+            )
+        except self.database_engine.module.IntegrityError:
+            return None
+
+        return lock
+
+    def _try_acquire_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_name: str,
+        lock_key: str,
+        write: bool,
+    ) -> "Lock":
+        # We attempt to acquire the lock by inserting into
+        # `worker_read_write_locks` and seeing if that fails any
+        # constraints. If it doesn't then we have acquired the lock,
+        # otherwise we haven't.
+        #
+        # Before that though we clear the table of any stale locks.
+
         now = self._clock.time_msec()
         token = random_string(6)
 
-        def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
-            # We attempt to acquire the lock by inserting into
-            # `worker_read_write_locks` and seeing if that fails any
-            # constraints. If it doesn't then we have acquired the lock,
-            # otherwise we haven't.
-            #
-            # Before that though we clear the table of any stale locks.
-
-            delete_sql = """
-                DELETE FROM worker_read_write_locks
-                    WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
-            """
-
-            insert_sql = """
-                INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            if isinstance(self.database_engine, PostgresEngine):
-                # For Postgres we can send these queries at the same time.
-                txn.execute(
-                    delete_sql + ";" + insert_sql,
-                    (
-                        # DELETE args
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                        # UPSERT args
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
-            else:
-                # For SQLite these need to be two queries.
-                txn.execute(
-                    delete_sql,
-                    (
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                    ),
-                )
-                txn.execute(
-                    insert_sql,
-                    (
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
+        delete_sql = """
+            DELETE FROM worker_read_write_locks
+                WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
+        """
 
-            return
+        insert_sql = """
+            INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
+            VALUES (?, ?, ?, ?, ?, ?)
+        """
 
-        try:
-            await self.db_pool.runInteraction(
-                "try_acquire_read_write_lock",
-                _try_acquire_read_write_lock_txn,
+        if isinstance(self.database_engine, PostgresEngine):
+            # For Postgres we can send these queries at the same time.
+            txn.execute(
+                delete_sql + ";" + insert_sql,
+                (
+                    # DELETE args
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                    # UPSERT args
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
+            )
+        else:
+            # For SQLite these need to be two queries.
+            txn.execute(
+                delete_sql,
+                (
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                ),
+            )
+            txn.execute(
+                insert_sql,
+                (
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
             )
-        except self.database_engine.module.IntegrityError:
-            return None
 
         lock = Lock(
             self._reactor,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
             token=token,
         )
 
-        self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+        def set_lock() -> None:
+            self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+
+        txn.call_after(set_lock)
 
         return lock
 
+    async def try_acquire_multi_read_write_lock(
+        self,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Optional[AsyncExitStack]:
+        """Try to acquire multiple locks for the given names/keys. Will return
+        an async context manager if the locks are successfully acquired, which
+        *must* be used (otherwise the lock will leak).
+
+        If only a subset of the locks can be acquired then it will immediately
+        drop them and return `None`.
+        """
+        try:
+            locks = await self.db_pool.runInteraction(
+                "try_acquire_multi_read_write_lock",
+                self._try_acquire_multi_read_write_lock_txn,
+                lock_names,
+                write,
+            )
+        except self.database_engine.module.IntegrityError:
+            return None
+
+        stack = AsyncExitStack()
+
+        for lock in locks:
+            await stack.enter_async_context(lock)
+
+        return stack
+
+    def _try_acquire_multi_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Collection["Lock"]:
+        locks = []
+
+        for lock_name, lock_key in lock_names:
+            lock = self._try_acquire_read_write_lock_txn(
+                txn, lock_name, lock_key, write
+            )
+            locks.append(lock)
+
+        return locks
+
 
 class Lock:
     """An async context manager that manages an acquired lock, ensuring it is
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
new file mode 100644
index 0000000000..73e548726c
--- /dev/null
+++ b/tests/handlers/test_worker_lock.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+
+class WorkerLockTestCase(unittest.HomeserverTestCase):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.worker_lock_handler = self.hs.get_worker_locks_handler()
+
+    def test_wait_for_lock_locally(self) -> None:
+        """Test waiting for a lock on a single worker"""
+
+        lock1 = self.worker_lock_handler.acquire_lock("name", "key")
+        self.get_success(lock1.__aenter__())
+
+        lock2 = self.worker_lock_handler.acquire_lock("name", "key")
+        d2 = defer.ensureDeferred(lock2.__aenter__())
+        self.assertNoResult(d2)
+
+        self.get_success(lock1.__aexit__(None, None, None))
+
+        self.get_success(d2)
+        self.get_success(lock2.__aexit__(None, None, None))
+
+
+class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.main_worker_lock_handler = self.hs.get_worker_locks_handler()
+
+    def test_wait_for_lock_worker(self) -> None:
+        """Test waiting for a lock on another worker"""
+
+        worker = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={
+                "redis": {"enabled": True},
+            },
+        )
+        worker_lock_handler = worker.get_worker_locks_handler()
+
+        lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
+        self.get_success(lock1.__aenter__())
+
+        lock2 = worker_lock_handler.acquire_lock("name", "key")
+        d2 = defer.ensureDeferred(lock2.__aenter__())
+        self.assertNoResult(d2)
+
+        self.get_success(lock1.__aexit__(None, None, None))
+
+        self.get_success(d2)
+        self.get_success(lock2.__aexit__(None, None, None))
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index d013e75d55..4f6347be15 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -711,7 +711,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(30, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)
 
     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -724,7 +724,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(32, channel.resource_usage.db_txn_count)
+        self.assertEqual(34, channel.resource_usage.db_txn_count)
 
     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index ad454f6dd8..383da83dfb 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -448,3 +448,55 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store._on_shutdown())
 
         self.assertEqual(self.store._live_read_write_lock_tokens, {})
+
+    def test_acquire_multiple_locks(self) -> None:
+        """Tests that acquiring multiple locks at once works."""
+
+        # Take out multiple locks and ensure that we can't get those locks out
+        # again.
+        lock = self.get_success(
+            self.store.try_acquire_multi_read_write_lock(
+                [("name1", "key1"), ("name2", "key2")], write=True
+            )
+        )
+        self.assertIsNotNone(lock)
+
+        assert lock is not None
+        self.get_success(lock.__aenter__())
+
+        lock2 = self.get_success(
+            self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+        )
+        self.assertIsNone(lock2)
+
+        lock3 = self.get_success(
+            self.store.try_acquire_read_write_lock("name2", "key2", write=False)
+        )
+        self.assertIsNone(lock3)
+
+        # Overlapping locks attempts will fail, and won't lock any locks.
+        lock4 = self.get_success(
+            self.store.try_acquire_multi_read_write_lock(
+                [("name1", "key1"), ("name3", "key3")], write=True
+            )
+        )
+        self.assertIsNone(lock4)
+
+        lock5 = self.get_success(
+            self.store.try_acquire_read_write_lock("name3", "key3", write=True)
+        )
+        self.assertIsNotNone(lock5)
+        assert lock5 is not None
+        self.get_success(lock5.__aenter__())
+        self.get_success(lock5.__aexit__(None, None, None))
+
+        # Once we release the lock we can take out the locks again.
+        self.get_success(lock.__aexit__(None, None, None))
+
+        lock6 = self.get_success(
+            self.store.try_acquire_read_write_lock("name1", "key1", write=True)
+        )
+        self.assertIsNotNone(lock6)
+        assert lock6 is not None
+        self.get_success(lock6.__aenter__())
+        self.get_success(lock6.__aexit__(None, None, None))