diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 35c0680365..35cd1089d6 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,6 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import (
SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
)
self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller
+ self.hs = hs
async def _process_event_persist_queue_task(
self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
"""
- if isinstance(task, _PersistEventsTask):
- return await self._persist_event_batch(room_id, task)
- elif isinstance(task, _UpdateCurrentStateTask):
- await self._update_current_state(room_id, task)
- return {}
- else:
- raise AssertionError(
- f"Found an unexpected task type in event persistence queue: {task}"
- )
+
+ # Ensure that the room can't be deleted while we're persisting events to
+ # it. We might already have taken out the lock, but since this is just a
+ # "read" lock its inherently reentrant.
+ async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ if isinstance(task, _PersistEventsTask):
+ return await self._persist_event_batch(room_id, task)
+ elif isinstance(task, _UpdateCurrentStateTask):
+ await self._update_current_state(room_id, task)
+ return {}
+ else:
+ raise AssertionError(
+ f"Found an unexpected task type in event persistence queue: {task}"
+ )
@trace
async def persist_events(
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index c89b4f7919..1680bf6168 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from contextlib import AsyncExitStack
from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -208,76 +209,85 @@ class LockStore(SQLBaseStore):
used (otherwise the lock will leak).
"""
+ try:
+ lock = await self.db_pool.runInteraction(
+ "try_acquire_read_write_lock",
+ self._try_acquire_read_write_lock_txn,
+ lock_name,
+ lock_key,
+ write,
+ )
+ except self.database_engine.module.IntegrityError:
+ return None
+
+ return lock
+
+ def _try_acquire_read_write_lock_txn(
+ self,
+ txn: LoggingTransaction,
+ lock_name: str,
+ lock_key: str,
+ write: bool,
+ ) -> "Lock":
+ # We attempt to acquire the lock by inserting into
+ # `worker_read_write_locks` and seeing if that fails any
+ # constraints. If it doesn't then we have acquired the lock,
+ # otherwise we haven't.
+ #
+ # Before that though we clear the table of any stale locks.
+
now = self._clock.time_msec()
token = random_string(6)
- def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
- # We attempt to acquire the lock by inserting into
- # `worker_read_write_locks` and seeing if that fails any
- # constraints. If it doesn't then we have acquired the lock,
- # otherwise we haven't.
- #
- # Before that though we clear the table of any stale locks.
-
- delete_sql = """
- DELETE FROM worker_read_write_locks
- WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
- """
-
- insert_sql = """
- INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # For Postgres we can send these queries at the same time.
- txn.execute(
- delete_sql + ";" + insert_sql,
- (
- # DELETE args
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- # UPSERT args
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
- else:
- # For SQLite these need to be two queries.
- txn.execute(
- delete_sql,
- (
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- ),
- )
- txn.execute(
- insert_sql,
- (
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
+ delete_sql = """
+ DELETE FROM worker_read_write_locks
+ WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
+ """
- return
+ insert_sql = """
+ INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
+ VALUES (?, ?, ?, ?, ?, ?)
+ """
- try:
- await self.db_pool.runInteraction(
- "try_acquire_read_write_lock",
- _try_acquire_read_write_lock_txn,
+ if isinstance(self.database_engine, PostgresEngine):
+ # For Postgres we can send these queries at the same time.
+ txn.execute(
+ delete_sql + ";" + insert_sql,
+ (
+ # DELETE args
+ now - _LOCK_TIMEOUT_MS,
+ lock_name,
+ lock_key,
+ # UPSERT args
+ lock_name,
+ lock_key,
+ write,
+ self._instance_name,
+ token,
+ now,
+ ),
+ )
+ else:
+ # For SQLite these need to be two queries.
+ txn.execute(
+ delete_sql,
+ (
+ now - _LOCK_TIMEOUT_MS,
+ lock_name,
+ lock_key,
+ ),
+ )
+ txn.execute(
+ insert_sql,
+ (
+ lock_name,
+ lock_key,
+ write,
+ self._instance_name,
+ token,
+ now,
+ ),
)
- except self.database_engine.module.IntegrityError:
- return None
lock = Lock(
self._reactor,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
token=token,
)
- self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+ def set_lock() -> None:
+ self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+
+ txn.call_after(set_lock)
return lock
+ async def try_acquire_multi_read_write_lock(
+ self,
+ lock_names: Collection[Tuple[str, str]],
+ write: bool,
+ ) -> Optional[AsyncExitStack]:
+ """Try to acquire multiple locks for the given names/keys. Will return
+ an async context manager if the locks are successfully acquired, which
+ *must* be used (otherwise the lock will leak).
+
+ If only a subset of the locks can be acquired then it will immediately
+ drop them and return `None`.
+ """
+ try:
+ locks = await self.db_pool.runInteraction(
+ "try_acquire_multi_read_write_lock",
+ self._try_acquire_multi_read_write_lock_txn,
+ lock_names,
+ write,
+ )
+ except self.database_engine.module.IntegrityError:
+ return None
+
+ stack = AsyncExitStack()
+
+ for lock in locks:
+ await stack.enter_async_context(lock)
+
+ return stack
+
+ def _try_acquire_multi_read_write_lock_txn(
+ self,
+ txn: LoggingTransaction,
+ lock_names: Collection[Tuple[str, str]],
+ write: bool,
+ ) -> Collection["Lock"]:
+ locks = []
+
+ for lock_name, lock_key in lock_names:
+ lock = self._try_acquire_read_write_lock_txn(
+ txn, lock_name, lock_key, write
+ )
+ locks.append(lock)
+
+ return locks
+
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
|