diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index fff0b5fa12..187dedae7d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -1010,6 +1012,37 @@ class EventCreationHandler:
event.internal_metadata.stream_ordering,
)
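+        # Take the read side of the delete-room lock: event sends can proceed
+        # concurrently with each other, but are blocked while a deletion of the
+        # room (which takes the write side of the same lock) is in progress.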
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ return await self._create_and_send_nonmember_event_locked(
+ requester=requester,
+ event_dict=event_dict,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ ratelimit=ratelimit,
+ txn_id=txn_id,
+ ignore_shadow_ban=ignore_shadow_ban,
+ outlier=outlier,
+ depth=depth,
+ )
+
+ async def _create_and_send_nonmember_event_locked(
+ self,
+ requester: Requester,
+ event_dict: dict,
+ allow_no_prev_events: bool = False,
+ prev_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
+ ratelimit: bool = True,
+ txn_id: Optional[str] = None,
+ ignore_shadow_ban: bool = False,
+ outlier: bool = False,
+ depth: Optional[int] = None,
+ ) -> Tuple[EventBase, int]:
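+        """Actual implementation of create_and_send_nonmember_event.
+
+        Callers are expected to already hold (the read side of) the delete
+        room lock for the room.
+        """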
+ room_id = event_dict["room_id"]
+
# If we don't have any prev event IDs specified then we need to
# check that the host is in the room (as otherwise populating the
# prev events will fail), at which point we may as well check the
@@ -1923,7 +1956,10 @@ class EventCreationHandler:
)
for room_id in room_ids:
- dummy_event_sent = await self._send_dummy_event_for_room(room_id)
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ dummy_event_sent = await self._send_dummy_event_for_room(room_id)
if not dummy_event_sent:
# Did not find a valid user in the room, so remove from future attempts
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 19b8728db9..da34658470 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
+PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
+
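+# Note: this must match DELETE_ROOM_LOCK_NAME in synapse.handlers.worker_lock,
+# as both names refer to the same lock.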
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
+ self._worker_locks = hs.get_worker_locks_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
"""
self._purges_in_progress_by_room.add(room_id)
try:
- async with self.pagination_lock.write(room_id):
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ ):
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
@@ -412,7 +420,10 @@ class PaginationHandler:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
- async with self.pagination_lock.write(room_id):
+ async with self._worker_locks.acquire_multi_read_write_lock(
+ [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+ write=True,
+ ):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:
room_token = from_token.room_key
- async with self.pagination_lock.read(room_id):
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_HISTORY_LOCK_NAME, room_id, write=False
+ ):
(membership, member_event_id) = (None, None)
if not use_admin_priviledge:
(
@@ -747,7 +760,9 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
- async with self.pagination_lock.write(room_id):
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ ):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 496e701f13..6cca2ec344 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -39,6 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -638,26 +640,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# by application services), and then by room ID.
async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key):
- with opentracing.start_active_span("update_membership_locked"):
- result = await self.update_membership_locked(
- requester,
- target,
- room_id,
- action,
- txn_id=txn_id,
- remote_room_hosts=remote_room_hosts,
- third_party_signed=third_party_signed,
- ratelimit=ratelimit,
- content=content,
- new_room=new_room,
- require_consent=require_consent,
- outlier=outlier,
- allow_no_prev_events=allow_no_prev_events,
- prev_event_ids=prev_event_ids,
- state_event_ids=state_event_ids,
- depth=depth,
- origin_server_ts=origin_server_ts,
- )
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ with opentracing.start_active_span("update_membership_locked"):
+ result = await self.update_membership_locked(
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=txn_id,
+ remote_room_hosts=remote_room_hosts,
+ third_party_signed=third_party_signed,
+ ratelimit=ratelimit,
+ content=content,
+ new_room=new_room,
+ require_consent=require_consent,
+ outlier=outlier,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ depth=depth,
+ origin_server_ts=origin_server_ts,
+ )
return result
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
new file mode 100644
index 0000000000..72df773a86
--- /dev/null
+++ b/synapse/handlers/worker_lock.py
@@ -0,0 +1,333 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ AsyncContextManager,
+ Collection,
+ Dict,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+from weakref import WeakSet
+
+import attr
+
+from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
+
+from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage.databases.main.lock import Lock, LockStore
+from synapse.util.async_helpers import timeout_deferred
+
+if TYPE_CHECKING:
+ from synapse.logging.opentracing import opentracing
+ from synapse.server import HomeServer
+
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
+class WorkerLocksHandler:
+    """A class for taking out locks, waiting until they are free, rather than
+    calling the storage lock functions directly (which do not wait for a held
+    lock to be released).
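+
+    Usage:
+        handler = hs.get_worker_locks_handler()
+        async with handler.acquire_lock(name, key):
+            # Do work while holding the lock...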
+ """
+
+ def __init__(self, hs: "HomeServer") -> None:
+ self._reactor = hs.get_reactor()
+ self._store = hs.get_datastores().main
+ self._clock = hs.get_clock()
+ self._notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
+
+        # Map from lock name/key to the set of `WaitingLock`s and
+        # `WaitingMultiLock`s that are active for that lock.
+ self._locks: Dict[
+ Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
+ ] = {}
+
+ self._clock.looping_call(self._cleanup_locks, 30_000)
+
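+        # Wake up any waiting locks when we hear (locally or over replication
+        # from another worker) that a lock has been released.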
+ self._notifier.add_lock_released_callback(self._on_lock_released)
+
+ def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
+        """Acquire a standard lock; returns a context manager that will block
+ until the lock is acquired.
+
+ Note: Care must be taken to avoid deadlocks. In particular, this
+ function does *not* timeout.
+
+ Usage:
+ async with handler.acquire_lock(name, key):
+ # Do work while holding the lock...
+ """
+
+ lock = WaitingLock(
+ reactor=self._reactor,
+ store=self._store,
+ handler=self,
+ lock_name=lock_name,
+ lock_key=lock_key,
+ write=None,
+ )
+
+ self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+ return lock
+
+ def acquire_read_write_lock(
+ self,
+ lock_name: str,
+ lock_key: str,
+ *,
+ write: bool,
+ ) -> "WaitingLock":
+        """Acquire a read/write lock; returns a context manager that will block
+ until the lock is acquired.
+
+ Note: Care must be taken to avoid deadlocks. In particular, this
+ function does *not* timeout.
+
+ Usage:
+ async with handler.acquire_read_write_lock(name, key, write=True):
+ # Do work while holding the lock...
+ """
+
+ lock = WaitingLock(
+ reactor=self._reactor,
+ store=self._store,
+ handler=self,
+ lock_name=lock_name,
+ lock_key=lock_key,
+ write=write,
+ )
+
+ self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+ return lock
+
+ def acquire_multi_read_write_lock(
+ self,
+ lock_names: Collection[Tuple[str, str]],
+ *,
+ write: bool,
+ ) -> "WaitingMultiLock":
+        """Acquire multiple read/write locks at once; returns a context manager
+        that will block until all the locks are acquired.
+
+        This will try to acquire all of the locks at once, and will never hold
+        on to only a subset of them (which could otherwise lead to deadlocks).
+
+ Note: Care must be taken to avoid deadlocks. In particular, this
+ function does *not* timeout.
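+
+        Usage:
+            async with handler.acquire_multi_read_write_lock(
+                [(name1, key1), (name2, key2)], write=True
+            ):
+                # Do work while holding all of the locks...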
+ """
+
+ lock = WaitingMultiLock(
+ lock_names=lock_names,
+ write=write,
+ reactor=self._reactor,
+ store=self._store,
+ handler=self,
+ )
+
+ for lock_name, lock_key in lock_names:
+ self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+ return lock
+
+ def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
+ """Notify that a lock has been released.
+
+ Pokes both the notifier and replication.
+ """
+
+ self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)
+
+ def _on_lock_released(
+ self, instance_name: str, lock_name: str, lock_key: str
+ ) -> None:
+ """Called when a lock has been released.
+
+ Wakes up any locks that might be waiting on this.
+ """
+ locks = self._locks.get((lock_name, lock_key))
+ if not locks:
+ return
+
+ def _wake_deferred(deferred: defer.Deferred) -> None:
+ if not deferred.called:
+ deferred.callback(None)
+
+ for lock in locks:
+ self._clock.call_later(0, _wake_deferred, lock.deferred)
+
+ @wrap_as_background_process("_cleanup_locks")
+ async def _cleanup_locks(self) -> None:
+ """Periodically cleans out stale entries in the locks map"""
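+        # The values are WeakSets, which become empty (and hence falsey) once
+        # all of their locks have been garbage collected.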
+ self._locks = {key: value for key, value in self._locks.items() if value}
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingLock:
+ reactor: IReactorTime
+ store: LockStore
+ handler: WorkerLocksHandler
+ lock_name: str
+ lock_key: str
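+    # `write` is None for a plain lock, or True/False for the write/read side
+    # of a read/write lock (see `__aenter__` below).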
+ write: Optional[bool]
+ deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+ _inner_lock: Optional[Lock] = None
+ _retry_interval: float = 0.1
+ _lock_span: "opentracing.Scope" = attr.Factory(
+ lambda: start_active_span("WaitingLock.lock")
+ )
+
+ async def __aenter__(self) -> None:
+ self._lock_span.__enter__()
+
+ with start_active_span("WaitingLock.waiting_for_lock"):
+ while self._inner_lock is None:
+ self.deferred = defer.Deferred()
+
+ if self.write is not None:
+ lock = await self.store.try_acquire_read_write_lock(
+ self.lock_name, self.lock_key, write=self.write
+ )
+ else:
+ lock = await self.store.try_acquire_lock(
+ self.lock_name, self.lock_key
+ )
+
+ if lock:
+ self._inner_lock = lock
+ break
+
+ try:
+                    # Wait until we get notified that the lock might have been
+                    # released (by the deferred being resolved). We also
+                    # periodically wake up in case the lock was released but we
+                    # weren't notified.
+ with PreserveLoggingContext():
+ await timeout_deferred(
+ deferred=self.deferred,
+ timeout=self._get_next_retry_interval(),
+ reactor=self.reactor,
+ )
+ except Exception:
+ pass
+
+ return await self._inner_lock.__aenter__()
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ assert self._inner_lock
+
+ self.handler.notify_lock_released(self.lock_name, self.lock_key)
+
+ try:
+ r = await self._inner_lock.__aexit__(exc_type, exc, tb)
+ finally:
+ self._lock_span.__exit__(exc_type, exc, tb)
+
+ return r
+
+ def _get_next_retry_interval(self) -> float:
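+        # Exponential backoff with jitter, capped at 5 seconds between retries.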
+ next = self._retry_interval
+        self._retry_interval = min(5, next * 2)
+ return next * random.uniform(0.9, 1.1)
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingMultiLock:
+ lock_names: Collection[Tuple[str, str]]
+
+ write: bool
+
+ reactor: IReactorTime
+ store: LockStore
+ handler: WorkerLocksHandler
+
+ deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+
+ _inner_lock_cm: Optional[AsyncContextManager] = None
+ _retry_interval: float = 0.1
+ _lock_span: "opentracing.Scope" = attr.Factory(
+        lambda: start_active_span("WaitingMultiLock.lock")
+ )
+
+ async def __aenter__(self) -> None:
+ self._lock_span.__enter__()
+
+        with start_active_span("WaitingMultiLock.waiting_for_lock"):
+ while self._inner_lock_cm is None:
+ self.deferred = defer.Deferred()
+
+ lock_cm = await self.store.try_acquire_multi_read_write_lock(
+ self.lock_names, write=self.write
+ )
+
+ if lock_cm:
+ self._inner_lock_cm = lock_cm
+ break
+
+ try:
+                    # Wait until we get notified that the lock might have been
+                    # released (by the deferred being resolved). We also
+                    # periodically wake up in case the lock was released but we
+                    # weren't notified.
+ with PreserveLoggingContext():
+ await timeout_deferred(
+ deferred=self.deferred,
+ timeout=self._get_next_retry_interval(),
+ reactor=self.reactor,
+ )
+ except Exception:
+ pass
+
+ assert self._inner_lock_cm
+ await self._inner_lock_cm.__aenter__()
+ return
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ assert self._inner_lock_cm
+
+ for lock_name, lock_key in self.lock_names:
+ self.handler.notify_lock_released(lock_name, lock_key)
+
+ try:
+ r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
+ finally:
+ self._lock_span.__exit__(exc_type, exc, tb)
+
+ return r
+
+ def _get_next_retry_interval(self) -> float:
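+        # Exponential backoff with jitter, capped at 5 seconds between retries.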
+ next = self._retry_interval
+        self._retry_interval = min(5, next * 2)
+ return next * random.uniform(0.9, 1.1)