summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/lock.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index bedacaf0d7..2d7633fbd5 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
 from weakref import WeakValueDictionary
 
 from twisted.internet.interfaces import IReactorCore
@@ -84,6 +84,8 @@ class LockStore(SQLBaseStore):
             self._on_shutdown,
         )
 
+        self._acquiring_locks: Set[Tuple[str, str]] = set()
+
     @wrap_as_background_process("LockStore._on_shutdown")
     async def _on_shutdown(self) -> None:
         """Called when the server is shutting down"""
@@ -103,6 +105,21 @@ class LockStore(SQLBaseStore):
         context manager if the lock is successfully acquired, which *must* be
         used (otherwise the lock will leak).
         """
+        if (lock_name, lock_key) in self._acquiring_locks:
+            return None
+        try:
+            self._acquiring_locks.add((lock_name, lock_key))
+            return await self._try_acquire_lock(lock_name, lock_key)
+        finally:
+            self._acquiring_locks.discard((lock_name, lock_key))
+
+    async def _try_acquire_lock(
+        self, lock_name: str, lock_key: str
+    ) -> Optional["Lock"]:
+        """Try to acquire a lock for the given name/key. Will return an async
+        context manager if the lock is successfully acquired, which *must* be
+        used (otherwise the lock will leak).
+        """
 
         # Check if this process has taken out a lock and if it's still valid.
         lock = self._live_tokens.get((lock_name, lock_key))