summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/persist_events.py27
-rw-r--r--synapse/storage/databases/main/lock.py190
2 files changed, 142 insertions, 75 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 35c0680365..35cd1089d6 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,6 +45,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.logging.opentracing import (
     SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._state_controller = state_controller
+        self.hs = hs
 
     async def _process_event_persist_queue_task(
         self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        if isinstance(task, _PersistEventsTask):
-            return await self._persist_event_batch(room_id, task)
-        elif isinstance(task, _UpdateCurrentStateTask):
-            await self._update_current_state(room_id, task)
-            return {}
-        else:
-            raise AssertionError(
-                f"Found an unexpected task type in event persistence queue: {task}"
-            )
+
+        # Ensure that the room can't be deleted while we're persisting events to
+        # it. We might already have taken out the lock, but since this is just a
+        # "read" lock its inherently reentrant.
+        async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
+            DELETE_ROOM_LOCK_NAME, room_id, write=False
+        ):
+            if isinstance(task, _PersistEventsTask):
+                return await self._persist_event_batch(room_id, task)
+            elif isinstance(task, _UpdateCurrentStateTask):
+                await self._update_current_state(room_id, task)
+                return {}
+            else:
+                raise AssertionError(
+                    f"Found an unexpected task type in event persistence queue: {task}"
+                )
 
     @trace
     async def persist_events(
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index c89b4f7919..1680bf6168 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from contextlib import AsyncExitStack
 from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
 from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
@@ -208,76 +209,85 @@ class LockStore(SQLBaseStore):
         used (otherwise the lock will leak).
         """
 
+        try:
+            lock = await self.db_pool.runInteraction(
+                "try_acquire_read_write_lock",
+                self._try_acquire_read_write_lock_txn,
+                lock_name,
+                lock_key,
+                write,
+            )
+        except self.database_engine.module.IntegrityError:
+            return None
+
+        return lock
+
+    def _try_acquire_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_name: str,
+        lock_key: str,
+        write: bool,
+    ) -> "Lock":
+        # We attempt to acquire the lock by inserting into
+        # `worker_read_write_locks` and seeing if that fails any
+        # constraints. If it doesn't then we have acquired the lock,
+        # otherwise we haven't.
+        #
+        # Before that though we clear the table of any stale locks.
+
         now = self._clock.time_msec()
         token = random_string(6)
 
-        def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
-            # We attempt to acquire the lock by inserting into
-            # `worker_read_write_locks` and seeing if that fails any
-            # constraints. If it doesn't then we have acquired the lock,
-            # otherwise we haven't.
-            #
-            # Before that though we clear the table of any stale locks.
-
-            delete_sql = """
-                DELETE FROM worker_read_write_locks
-                    WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
-            """
-
-            insert_sql = """
-                INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            if isinstance(self.database_engine, PostgresEngine):
-                # For Postgres we can send these queries at the same time.
-                txn.execute(
-                    delete_sql + ";" + insert_sql,
-                    (
-                        # DELETE args
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                        # UPSERT args
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
-            else:
-                # For SQLite these need to be two queries.
-                txn.execute(
-                    delete_sql,
-                    (
-                        now - _LOCK_TIMEOUT_MS,
-                        lock_name,
-                        lock_key,
-                    ),
-                )
-                txn.execute(
-                    insert_sql,
-                    (
-                        lock_name,
-                        lock_key,
-                        write,
-                        self._instance_name,
-                        token,
-                        now,
-                    ),
-                )
+        delete_sql = """
+            DELETE FROM worker_read_write_locks
+                WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
+        """
 
-            return
+        insert_sql = """
+            INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
+            VALUES (?, ?, ?, ?, ?, ?)
+        """
 
-        try:
-            await self.db_pool.runInteraction(
-                "try_acquire_read_write_lock",
-                _try_acquire_read_write_lock_txn,
+        if isinstance(self.database_engine, PostgresEngine):
+            # For Postgres we can send these queries at the same time.
+            txn.execute(
+                delete_sql + ";" + insert_sql,
+                (
+                    # DELETE args
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                    # UPSERT args
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
+            )
+        else:
+            # For SQLite these need to be two queries.
+            txn.execute(
+                delete_sql,
+                (
+                    now - _LOCK_TIMEOUT_MS,
+                    lock_name,
+                    lock_key,
+                ),
+            )
+            txn.execute(
+                insert_sql,
+                (
+                    lock_name,
+                    lock_key,
+                    write,
+                    self._instance_name,
+                    token,
+                    now,
+                ),
             )
-        except self.database_engine.module.IntegrityError:
-            return None
 
         lock = Lock(
             self._reactor,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
             token=token,
         )
 
-        self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+        def set_lock() -> None:
+            self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+
+        txn.call_after(set_lock)
 
         return lock
 
+    async def try_acquire_multi_read_write_lock(
+        self,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Optional[AsyncExitStack]:
+        """Try to acquire multiple locks for the given names/keys. Will return
+        an async context manager if the locks are successfully acquired, which
+        *must* be used (otherwise the lock will leak).
+
+        If only a subset of the locks can be acquired then it will immediately
+        drop them and return `None`.
+        """
+        try:
+            locks = await self.db_pool.runInteraction(
+                "try_acquire_multi_read_write_lock",
+                self._try_acquire_multi_read_write_lock_txn,
+                lock_names,
+                write,
+            )
+        except self.database_engine.module.IntegrityError:
+            return None
+
+        stack = AsyncExitStack()
+
+        for lock in locks:
+            await stack.enter_async_context(lock)
+
+        return stack
+
+    def _try_acquire_multi_read_write_lock_txn(
+        self,
+        txn: LoggingTransaction,
+        lock_names: Collection[Tuple[str, str]],
+        write: bool,
+    ) -> Collection["Lock"]:
+        locks = []
+
+        for lock_name, lock_key in lock_names:
+            lock = self._try_acquire_read_write_lock_txn(
+                txn, lock_name, lock_key, write
+            )
+            locks.append(lock)
+
+        return locks
+
 
 class Lock:
     """An async context manager that manages an acquired lock, ensuring it is