diff --git a/changelog.d/10269.misc b/changelog.d/10269.misc
new file mode 100644
index 0000000000..23e590490c
--- /dev/null
+++ b/changelog.d/10269.misc
@@ -0,0 +1 @@
+Add a distributed lock implementation.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index af8a1833f3..5b041fcaad 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -108,6 +108,7 @@ from synapse.server import HomeServer
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
+from synapse.storage.databases.main.lock import LockStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
@@ -249,6 +250,7 @@ class GenericWorkerSlavedStore(
ServerMetricsStore,
SearchStore,
TransactionWorkerStore,
+ LockStore,
BaseSlavedStore,
):
pass
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9cce62ae6c..a3fddea042 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -46,6 +46,7 @@ from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
+from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore
@@ -119,6 +120,7 @@ class DataStore(
CacheInvalidationWorkerStore,
ServerMetricsStore,
EventForwardExtremitiesStore,
+ LockStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
new file mode 100644
index 0000000000..e76188328c
--- /dev/null
+++ b/synapse/storage/databases/main/lock.py
@@ -0,0 +1,334 @@
+# Copyright 2021 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from types import TracebackType
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+
+from twisted.internet.interfaces import IReactorCore
+
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.types import Connection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+# How often to renew an acquired lock by updating the `last_renewed_ts` time in
+# the lock table.
+_RENEWAL_INTERVAL_MS = 30 * 1000
+
+# How long before an acquired lock times out.
+_LOCK_TIMEOUT_MS = 2 * 60 * 1000
+
+
+class LockStore(SQLBaseStore):
+ """Provides a best effort distributed lock between worker instances.
+
+ Locks are identified by a name and key. A lock is acquired by inserting into
+ the `worker_locks` table if a) there is no existing row for the name/key or
+ b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`.
+
+ When a lock is taken out the instance inserts a random `token`, the instance
+ that holds that token holds the lock until it drops (or times out).
+
+ The instance that holds the lock should regularly update the
+ `last_renewed_ts` column with the current time.
+ """
+
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self._reactor = hs.get_reactor()
+ self._instance_name = hs.get_instance_id()
+
+ # A map from `(lock_name, lock_key)` to the token of any locks that we
+ # think we currently hold.
+ self._live_tokens: Dict[Tuple[str, str], str] = {}
+
+ # When we shut down we want to remove the locks. Technically this can
+ # lead to a race, as we may drop the lock while we are still processing.
+ # However, a) it should be a small window, b) the lock is best effort
+ # anyway and c) we want to really avoid leaking locks when we restart.
+ hs.get_reactor().addSystemEventTrigger(
+ "before",
+ "shutdown",
+ self._on_shutdown,
+ )
+
+ @wrap_as_background_process("LockStore._on_shutdown")
+ async def _on_shutdown(self) -> None:
+ """Called when the server is shutting down"""
+ logger.info("Dropping held locks due to shutdown")
+
+ for (lock_name, lock_key), token in self._live_tokens.items():
+ await self._drop_lock(lock_name, lock_key, token)
+
+ logger.info("Dropped locks due to shutdown")
+
+ async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]:
+ """Try to acquire a lock for the given name/key. Will return an async
+ context manager if the lock is successfully acquired, which *must* be
+ used (otherwise the lock will leak).
+ """
+
+ now = self._clock.time_msec()
+ token = random_string(6)
+
+ if self.db_pool.engine.can_native_upsert:
+
+ def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
+ # We take out the lock if either a) there is no row for the lock
+ # already or b) the existing row has timed out.
+ sql = """
+ INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (lock_name, lock_key)
+ DO UPDATE
+ SET
+ token = EXCLUDED.token,
+ instance_name = EXCLUDED.instance_name,
+ last_renewed_ts = EXCLUDED.last_renewed_ts
+ WHERE
+ worker_locks.last_renewed_ts < ?
+ """
+ txn.execute(
+ sql,
+ (
+ lock_name,
+ lock_key,
+ self._instance_name,
+ token,
+ now,
+ now - _LOCK_TIMEOUT_MS,
+ ),
+ )
+
+ # We only acquired the lock if we inserted or updated the table.
+ return bool(txn.rowcount)
+
+ did_lock = await self.db_pool.runInteraction(
+ "try_acquire_lock",
+ _try_acquire_lock_txn,
+ # We can autocommit here as we're executing a single query, this
+ # will avoid serialization errors.
+ db_autocommit=True,
+ )
+ if not did_lock:
+ return None
+
+ else:
+ # If we're on an old SQLite we emulate the above logic by first
+ # clearing out any existing stale locks and then upserting.
+
+ def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool:
+ sql = """
+ DELETE FROM worker_locks
+ WHERE
+ lock_name = ?
+ AND lock_key = ?
+ AND last_renewed_ts < ?
+ """
+ txn.execute(
+ sql,
+ (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+ )
+
+ inserted = self.db_pool.simple_upsert_txn_emulated(
+ txn,
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ },
+ values={},
+ insertion_values={
+ "token": token,
+ "last_renewed_ts": self._clock.time_msec(),
+ "instance_name": self._instance_name,
+ },
+ )
+
+ return inserted
+
+ did_lock = await self.db_pool.runInteraction(
+ "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn
+ )
+
+ if not did_lock:
+ return None
+
+ self._live_tokens[(lock_name, lock_key)] = token
+
+ return Lock(
+ self._reactor,
+ self._clock,
+ self,
+ lock_name=lock_name,
+ lock_key=lock_key,
+ token=token,
+ )
+
+ async def _is_lock_still_valid(
+ self, lock_name: str, lock_key: str, token: str
+ ) -> bool:
+ """Checks whether this instance still holds the lock."""
+ last_renewed_ts = await self.db_pool.simple_select_one_onecol(
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "token": token,
+ },
+ retcol="last_renewed_ts",
+ allow_none=True,
+ desc="is_lock_still_valid",
+ )
+ return (
+ last_renewed_ts is not None
+ and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
+ )
+
+ async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None:
+ """Attempt to renew the lock if we still hold it."""
+ await self.db_pool.simple_update(
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "token": token,
+ },
+ updatevalues={"last_renewed_ts": self._clock.time_msec()},
+ desc="renew_lock",
+ )
+
+ async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None:
+ """Attempt to drop the lock, if we still hold it"""
+ await self.db_pool.simple_delete(
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "token": token,
+ },
+ desc="drop_lock",
+ )
+
+ self._live_tokens.pop((lock_name, lock_key), None)
+
+
+class Lock:
+ """An async context manager that manages an acquired lock, ensuring it is
+ regularly renewed and dropping it when the context manager exits.
+
+ The lock object has an `is_still_valid` method which can be used to
+ double-check the lock is still valid, if e.g. processing work in a loop.
+
+ For example:
+
+ lock = await self.store.try_acquire_lock(...)
+ if not lock:
+ return
+
+ async with lock:
+ for item in work:
+ await process(item)
+
+ if not await lock.is_still_valid():
+ break
+ """
+
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ clock: Clock,
+ store: LockStore,
+ lock_name: str,
+ lock_key: str,
+ token: str,
+ ) -> None:
+ self._reactor = reactor
+ self._clock = clock
+ self._store = store
+ self._lock_name = lock_name
+ self._lock_key = lock_key
+
+ self._token = token
+
+ self._looping_call = clock.looping_call(
+ self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token
+ )
+
+ self._dropped = False
+
+ @staticmethod
+ @wrap_as_background_process("Lock._renew")
+ async def _renew(
+ store: LockStore,
+ lock_name: str,
+ lock_key: str,
+ token: str,
+ ) -> None:
+ """Renew the lock.
+
+ Note: this is a static method, rather than using self.*, so that we
+ don't end up with a reference to `self` in the reactor, which would stop
+ this from being cleaned up if we dropped the context manager.
+ """
+ await store._renew_lock(lock_name, lock_key, token)
+
+ async def is_still_valid(self) -> bool:
+ """Check if the lock is still held by us"""
+ return await self._store._is_lock_still_valid(
+ self._lock_name, self._lock_key, self._token
+ )
+
+ async def __aenter__(self) -> None:
+ if self._dropped:
+ raise Exception("Cannot reuse a Lock object")
+
+ async def __aexit__(
+ self,
+ _exctype: Optional[Type[BaseException]],
+ _excinst: Optional[BaseException],
+ _exctb: Optional[TracebackType],
+ ) -> bool:
+ if self._looping_call.running:
+ self._looping_call.stop()
+
+ await self._store._drop_lock(self._lock_name, self._lock_key, self._token)
+ self._dropped = True
+
+ return False
+
+ def __del__(self) -> None:
+ if not self._dropped:
+ # We should not be dropped without the lock being released (unless
+ # we're shutting down), but if we are then let's at least stop
+ # renewing the lock.
+ if self._looping_call.running:
+ self._looping_call.stop()
+
+ if self._reactor.running:
+ logger.error(
+ "Lock for (%s, %s) dropped without being released",
+ self._lock_name,
+ self._lock_key,
+ )
diff --git a/synapse/storage/schema/main/delta/59/15locks.sql b/synapse/storage/schema/main/delta/59/15locks.sql
new file mode 100644
index 0000000000..8b2999ff3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/15locks.sql
@@ -0,0 +1,37 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A noddy implementation of a distributed lock across workers. While a worker
+-- has taken a lock out they should regularly update the `last_renewed_ts`
+-- column, a lock will be considered dropped if `last_renewed_ts` is from ages
+-- ago.
+CREATE TABLE worker_locks (
+ lock_name TEXT NOT NULL,
+ lock_key TEXT NOT NULL,
+ -- We write the instance name to ease manual debugging, we don't ever read
+ -- from it.
+ -- Note: instance names aren't guarenteed to be unique.
+ instance_name TEXT NOT NULL,
+ -- A random string generated each time an instance takes out a lock. Used by
+ -- the instance to tell whether the lock is still held by it (e.g. in the
+ -- case where the process stalls for a long time the lock may time out and
+ -- be taken out by another instance, at which point the original instance
+ -- can tell it no longer holds the lock as the tokens no longer match).
+ token TEXT NOT NULL,
+ last_renewed_ts BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX worker_locks_key ON worker_locks (lock_name, lock_key);
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
new file mode 100644
index 0000000000..9ca70e7367
--- /dev/null
+++ b/tests/storage/databases/main/test_lock.py
@@ -0,0 +1,100 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.server import HomeServer
+from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+
+from tests import unittest
+
+
+class LockTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs: HomeServer):
+ self.store = hs.get_datastore()
+
+ def test_simple_lock(self):
+ """Test that we can take out a lock and that while we hold it nobody
+ else can take it out.
+ """
+ # First to acquire this lock, so it should complete
+ lock = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock)
+
+ # Enter the context manager
+ self.get_success(lock.__aenter__())
+
+ # Attempting to acquire the lock again fails.
+ lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNone(lock2)
+
+ # Calling `is_still_valid` reports true.
+ self.assertTrue(self.get_success(lock.is_still_valid()))
+
+ # Drop the lock
+ self.get_success(lock.__aexit__(None, None, None))
+
+ # We can now acquire the lock again.
+ lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock3)
+ self.get_success(lock3.__aenter__())
+ self.get_success(lock3.__aexit__(None, None, None))
+
+ def test_maintain_lock(self):
+ """Test that we don't time out locks while they're still active"""
+
+ lock = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock)
+
+ self.get_success(lock.__aenter__())
+
+ # Wait for ages with the lock, we should not be able to get the lock.
+ self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000)
+
+ lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNone(lock2)
+
+ self.get_success(lock.__aexit__(None, None, None))
+
+ def test_timeout_lock(self):
+ """Test that we time out locks if they're not updated for ages"""
+
+ lock = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock)
+
+ self.get_success(lock.__aenter__())
+
+ # We simulate the process getting stuck by cancelling the looping call
+ # that keeps the lock active.
+ lock._looping_call.stop()
+
+ # Wait for the lock to timeout.
+ self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000)
+
+ lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock2)
+
+ self.assertFalse(self.get_success(lock.is_still_valid()))
+
+ def test_drop(self):
+ """Test that dropping the context manager means we stop renewing the lock"""
+
+ lock = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock)
+
+ del lock
+
+ # Wait for the lock to timeout.
+ self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000)
+
+ lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
+ self.assertIsNotNone(lock2)
|