diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 5fb3d5083d..359999f680 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,8 +17,6 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
- Awaitable,
- Callable,
Dict,
Iterable,
List,
@@ -30,7 +28,7 @@ from typing import (
)
from prometheus_client import Counter
-from typing_extensions import Concatenate, ParamSpec, TypeGuard
+from typing_extensions import ParamSpec, TypeGuard
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException, HttpResponseException
@@ -80,9 +78,7 @@ sent_todevice_counter = Counter(
HOUR_IN_MS = 60 * 60 * 1000
-
APP_SERVICE_PREFIX = "/_matrix/app/v1"
-APP_SERVICE_UNSTABLE_PREFIX = "/_matrix/app/unstable"
P = ParamSpec("P")
R = TypeVar("R")
@@ -128,47 +124,6 @@ class ApplicationServiceApi(SimpleHttpClient):
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
- async def _send_with_fallbacks(
- self,
- service: "ApplicationService",
- prefixes: List[str],
- path: str,
- func: Callable[Concatenate[str, P], Awaitable[R]],
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> R:
- """
- Attempt to call an application service with multiple paths, falling back
- until one succeeds.
-
- Args:
- service: The appliacation service, this provides the base URL.
- prefixes: A last of paths to try in order for the requests.
- path: A suffix to append to each prefix.
- func: The function to call, the first argument will be the full
- endpoint to fetch. Other arguments are provided by args/kwargs.
-
- Returns:
- The return value of func.
- """
- for i, prefix in enumerate(prefixes, start=1):
- uri = f"{service.url}{prefix}{path}"
- try:
- return await func(uri, *args, **kwargs)
- except HttpResponseException as e:
- # If an error is received that is due to an unrecognised path,
- # fallback to next path (if one exists). Otherwise, consider it
- # a legitimate error and raise.
- if i < len(prefixes) and is_unknown_endpoint(e):
- continue
- raise
- except Exception:
- # Unexpected exceptions get sent to the caller.
- raise
-
- # The function should always exit via the return or raise above this.
- raise RuntimeError("Unexpected fallback behaviour. This should never be seen.")
-
async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
if service.url is None:
return False
@@ -177,11 +132,8 @@ class ApplicationServiceApi(SimpleHttpClient):
assert service.hs_token is not None
try:
- response = await self._send_with_fallbacks(
- service,
- [APP_SERVICE_PREFIX, ""],
- f"/users/{urllib.parse.quote(user_id)}",
- self.get_json,
+ response = await self.get_json(
+ f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}",
{"access_token": service.hs_token},
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
@@ -203,11 +155,8 @@ class ApplicationServiceApi(SimpleHttpClient):
assert service.hs_token is not None
try:
- response = await self._send_with_fallbacks(
- service,
- [APP_SERVICE_PREFIX, ""],
- f"/rooms/{urllib.parse.quote(alias)}",
- self.get_json,
+ response = await self.get_json(
+ f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}",
{"access_token": service.hs_token},
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
@@ -245,11 +194,8 @@ class ApplicationServiceApi(SimpleHttpClient):
**fields,
b"access_token": service.hs_token,
}
- response = await self._send_with_fallbacks(
- service,
- [APP_SERVICE_PREFIX, APP_SERVICE_UNSTABLE_PREFIX],
- f"/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
- self.get_json,
+ response = await self.get_json(
+ f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
args=args,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
@@ -285,11 +231,8 @@ class ApplicationServiceApi(SimpleHttpClient):
# This is required by the configuration.
assert service.hs_token is not None
try:
- info = await self._send_with_fallbacks(
- service,
- [APP_SERVICE_PREFIX, APP_SERVICE_UNSTABLE_PREFIX],
- f"/thirdparty/protocol/{urllib.parse.quote(protocol)}",
- self.get_json,
+ info = await self.get_json(
+ f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}",
{"access_token": service.hs_token},
headers={"Authorization": [f"Bearer {service.hs_token}"]},
)
@@ -401,11 +344,8 @@ class ApplicationServiceApi(SimpleHttpClient):
}
try:
- await self._send_with_fallbacks(
- service,
- [APP_SERVICE_PREFIX, ""],
- f"/transactions/{urllib.parse.quote(str(txn_id))}",
- self.put_json,
+ await self.put_json(
+ f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}",
json_body=body,
args={"access_token": service.hs_token},
headers={"Authorization": [f"Bearer {service.hs_token}"]},
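
With the fallback helper gone, every appservice request targets the stable `/_matrix/app/v1` prefix directly. A minimal sketch of the resulting URL shape (`build_user_query_url` is a hypothetical helper; `base_url` stands in for `service.url` above):

```python
# Hypothetical helper illustrating the URL shape after this change: the
# stable prefix is baked in, with no legacy-path fallback attempted.
import urllib.parse

APP_SERVICE_PREFIX = "/_matrix/app/v1"

def build_user_query_url(base_url: str, user_id: str) -> str:
    """Build the stable /users query URL for an application service."""
    return f"{base_url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}"

# build_user_query_url("http://as.example.com", "@alice:example.com")
# -> "http://as.example.com/_matrix/app/v1/users/%40alice%3Aexample.com"
```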
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index fa61dd8c10..a90d99c4d6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -63,6 +63,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -137,6 +138,7 @@ class FederationServer(FederationBase):
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
self._e2e_keys_handler = hs.get_e2e_keys_handler()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -1236,9 +1238,18 @@ class FederationServer(FederationBase):
logger.info("handling received PDU in room %s: %s", room_id, event)
try:
with nested_logging_context(event.event_id):
- await self._federation_event_handler.on_receive_pdu(
- origin, event
- )
+ # We're taking out a lock within a lock, which could
+ # lead to deadlocks if we're not careful. However, it is
+ # safe on this occasion as we only ever take a write
+ # lock when deleting a room, which we would never do
+ # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
+ # lock.
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ await self._federation_event_handler.on_receive_pdu(
+ origin, event
+ )
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
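
The comment above relies on the usual reader/writer semantics: any number of inbound-PDU handlers can hold the read side of `DELETE_ROOM_LOCK_NAME` for a room concurrently, while deletion takes the write side and excludes them all. A sketch of the pattern, assuming a handler with the `acquire_read_write_lock` API introduced later in this diff (`process` and `purge` are hypothetical):

```python
DELETE_ROOM_LOCK_NAME = "delete_room_lock"

async def handle_pdu(lock_handler, room_id: str, pdu) -> None:
    # Many PDU handlers may hold the read lock for a room at once.
    async with lock_handler.acquire_read_write_lock(
        DELETE_ROOM_LOCK_NAME, room_id, write=False
    ):
        await process(pdu)  # hypothetical PDU processing

async def delete_room(lock_handler, room_id: str) -> None:
    # Deletion takes the write lock: it waits for in-flight readers to
    # finish and blocks new ones until the room is gone.
    async with lock_handler.acquire_read_write_lock(
        DELETE_ROOM_LOCK_NAME, room_id, write=True
    ):
        await purge(room_id)  # hypothetical purge step
```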
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index fff0b5fa12..c656e07d37 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -876,14 +878,13 @@ class EventCreationHandler:
return prev_event
return None
- async def get_event_from_transaction(
+ async def get_event_id_from_transaction(
self,
requester: Requester,
txn_id: str,
room_id: str,
- ) -> Optional[EventBase]:
- """For the given transaction ID and room ID, check if there is a matching event.
- If so, fetch it and return it.
+ ) -> Optional[str]:
+ """For the given transaction ID and room ID, check if there is a matching event ID.
Args:
requester: The requester making the request in the context of which we want
@@ -892,8 +893,9 @@ class EventCreationHandler:
room_id: The room ID.
Returns:
- An event if one could be found, None otherwise.
+ An event ID if one could be found, None otherwise.
"""
+ existing_event_id = None
if self._msc3970_enabled and requester.device_id:
# When MSC3970 is enabled, we lookup for events sent by the same device first,
@@ -907,7 +909,7 @@ class EventCreationHandler:
)
)
if existing_event_id:
- return await self.store.get_event(existing_event_id)
+ return existing_event_id
# Pre-MSC3970, we looked up for events that were sent by the same session by
# using the access token ID.
@@ -920,9 +922,32 @@ class EventCreationHandler:
txn_id,
)
)
- if existing_event_id:
- return await self.store.get_event(existing_event_id)
+ return existing_event_id
+
+ async def get_event_from_transaction(
+ self,
+ requester: Requester,
+ txn_id: str,
+ room_id: str,
+ ) -> Optional[EventBase]:
+ """For the given transaction ID and room ID, check if there is a matching event.
+ If so, fetch it and return it.
+
+ Args:
+ requester: The requester making the request in the context of which we want
+ to fetch the event.
+ txn_id: The transaction ID.
+ room_id: The room ID.
+
+ Returns:
+ An event if one could be found, None otherwise.
+ """
+ existing_event_id = await self.get_event_id_from_transaction(
+ requester, txn_id, room_id
+ )
+ if existing_event_id:
+ return await self.store.get_event(existing_event_id)
return None
async def create_and_send_nonmember_event(
@@ -1010,6 +1035,37 @@ class EventCreationHandler:
event.internal_metadata.stream_ordering,
)
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ return await self._create_and_send_nonmember_event_locked(
+ requester=requester,
+ event_dict=event_dict,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ ratelimit=ratelimit,
+ txn_id=txn_id,
+ ignore_shadow_ban=ignore_shadow_ban,
+ outlier=outlier,
+ depth=depth,
+ )
+
+ async def _create_and_send_nonmember_event_locked(
+ self,
+ requester: Requester,
+ event_dict: dict,
+ allow_no_prev_events: bool = False,
+ prev_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
+ ratelimit: bool = True,
+ txn_id: Optional[str] = None,
+ ignore_shadow_ban: bool = False,
+ outlier: bool = False,
+ depth: Optional[int] = None,
+ ) -> Tuple[EventBase, int]:
+ room_id = event_dict["room_id"]
+
# If we don't have any prev event IDs specified then we need to
# check that the host is in the room (as otherwise populating the
# prev events will fail), at which point we may as well check the
@@ -1923,7 +1979,10 @@ class EventCreationHandler:
)
for room_id in room_ids:
- dummy_event_sent = await self._send_dummy_event_for_room(room_id)
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ dummy_event_sent = await self._send_dummy_event_for_room(room_id)
if not dummy_event_sent:
# Did not find a valid user in the room, so remove from future attempts
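
Splitting `get_event_id_from_transaction` out of `get_event_from_transaction` lets callers that only need the ID (such as the membership path later in this diff) skip fetching the full event. A sketch of idempotent sending built on the refactored helpers; the method names mirror the diff, while `send_message_idempotently` itself is hypothetical:

```python
async def send_message_idempotently(handler, requester, room_id, txn_id, body):
    # A retried request with the same transaction ID maps back to the
    # original event instead of creating a duplicate.
    existing = await handler.get_event_from_transaction(requester, txn_id, room_id)
    if existing is not None:
        return existing.event_id

    event, _stream_id = await handler.create_and_send_nonmember_event(
        requester,
        {
            "type": "m.room.message",
            "room_id": room_id,
            "sender": requester.user.to_string(),
            "content": body,
        },
        txn_id=txn_id,
    )
    return event.event_id
```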
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 19b8728db9..da34658470 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
+PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
+ self._worker_locks = hs.get_worker_locks_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
"""
self._purges_in_progress_by_room.add(room_id)
try:
- async with self.pagination_lock.write(room_id):
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ ):
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
@@ -412,7 +420,10 @@ class PaginationHandler:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
- async with self.pagination_lock.write(room_id):
+ async with self._worker_locks.acquire_multi_read_write_lock(
+ [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+ write=True,
+ ):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:
room_token = from_token.room_key
- async with self.pagination_lock.read(room_id):
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_HISTORY_LOCK_NAME, room_id, write=False
+ ):
(membership, member_event_id) = (None, None)
if not use_admin_priviledge:
(
@@ -747,7 +760,9 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
- async with self.pagination_lock.write(room_id):
+ async with self._worker_locks.acquire_read_write_lock(
+ PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ ):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id
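
Replacing the in-process `ReadWriteLock` with the worker-lock handler makes these exclusions hold across all workers, not just within one process. The purge-room path takes both locks at once so neither a paginator nor an event persister can run mid-purge; a condensed sketch (`do_purge` is hypothetical):

```python
PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
DELETE_ROOM_LOCK_NAME = "delete_room_lock"

async def purge_room(worker_locks, room_id: str) -> None:
    # Both locks are acquired atomically; holding only one of the two
    # could deadlock against another multi-lock caller.
    async with worker_locks.acquire_multi_read_write_lock(
        [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
        write=True,
    ):
        await do_purge(room_id)  # hypothetical purge body
```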
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index a7f8c5e636..c7fe101cd9 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -68,7 +68,7 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
profileinfo = await self.store.get_profileinfo(target_user)
- if profileinfo.display_name is None:
+ if profileinfo.display_name is None and profileinfo.avatar_url is None:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
return {
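
The relaxed check means a profile is only treated as missing when *both* fields are unset, so a user with an avatar but no displayname now resolves. Illustratively (a toy predicate, not the handler's actual code):

```python
# Before: display_name is None -> 404, even with an avatar set.
# After:  404 only when display_name and avatar_url are both None.
def profile_is_missing(display_name, avatar_url) -> bool:
    return display_name is None and avatar_url is None

assert profile_is_missing(None, None)
assert not profile_is_missing(None, "mxc://example/abc")
```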
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 496e701f13..e3cdf2bc61 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -39,6 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -174,8 +176,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.request_ratelimiter = hs.get_request_ratelimiter()
hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
- self._msc3970_enabled = hs.config.experimental.msc3970_enabled
-
def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
"""Notify the rate limiter that a room join has occurred.
@@ -416,29 +416,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# do this check just before we persist an event as well, but may as well
# do it up front for efficiency.)
if txn_id:
- existing_event_id = None
- if self._msc3970_enabled and requester.device_id:
- # When MSC3970 is enabled, we lookup for events sent by the same device
- # first, and fallback to the old behaviour if none were found.
- existing_event_id = (
- await self.store.get_event_id_from_transaction_id_and_device_id(
- room_id,
- requester.user.to_string(),
- requester.device_id,
- txn_id,
- )
+ existing_event_id = (
+ await self.event_creation_handler.get_event_id_from_transaction(
+ requester, txn_id, room_id
)
-
- if requester.access_token_id and not existing_event_id:
- existing_event_id = (
- await self.store.get_event_id_from_transaction_id_and_token_id(
- room_id,
- requester.user.to_string(),
- requester.access_token_id,
- txn_id,
- )
- )
-
+ )
if existing_event_id:
event_pos = await self.store.get_position_for_event(existing_event_id)
return existing_event_id, event_pos.stream
@@ -638,26 +620,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# by application services), and then by room ID.
async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key):
- with opentracing.start_active_span("update_membership_locked"):
- result = await self.update_membership_locked(
- requester,
- target,
- room_id,
- action,
- txn_id=txn_id,
- remote_room_hosts=remote_room_hosts,
- third_party_signed=third_party_signed,
- ratelimit=ratelimit,
- content=content,
- new_room=new_room,
- require_consent=require_consent,
- outlier=outlier,
- allow_no_prev_events=allow_no_prev_events,
- prev_event_ids=prev_event_ids,
- state_event_ids=state_event_ids,
- depth=depth,
- origin_server_ts=origin_server_ts,
- )
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ with opentracing.start_active_span("update_membership_locked"):
+ result = await self.update_membership_locked(
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=txn_id,
+ remote_room_hosts=remote_room_hosts,
+ third_party_signed=third_party_signed,
+ ratelimit=ratelimit,
+ content=content,
+ new_room=new_room,
+ require_consent=require_consent,
+ outlier=outlier,
+ allow_no_prev_events=allow_no_prev_events,
+ prev_event_ids=prev_event_ids,
+ state_event_ids=state_event_ids,
+ depth=depth,
+ origin_server_ts=origin_server_ts,
+ )
return result
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
new file mode 100644
index 0000000000..72df773a86
--- /dev/null
+++ b/synapse/handlers/worker_lock.py
@@ -0,0 +1,333 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ AsyncContextManager,
+ Collection,
+ Dict,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+from weakref import WeakSet
+
+import attr
+
+from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
+
+from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage.databases.main.lock import Lock, LockStore
+from synapse.util.async_helpers import timeout_deferred
+
+if TYPE_CHECKING:
+ from synapse.logging.opentracing import opentracing
+ from synapse.server import HomeServer
+
+
+DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+
+
+class WorkerLocksHandler:
+ """A class for waiting on taking out locks, rather than using the storage
+ functions directly (which don't support awaiting).
+ """
+
+ def __init__(self, hs: "HomeServer") -> None:
+ self._reactor = hs.get_reactor()
+ self._store = hs.get_datastores().main
+ self._clock = hs.get_clock()
+ self._notifier = hs.get_notifier()
+ self._instance_name = hs.get_instance_name()
+
+ # Map from lock name/key to the set of `WaitingLock`s that are active for
+ # that lock.
+ self._locks: Dict[
+ Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
+ ] = {}
+
+ self._clock.looping_call(self._cleanup_locks, 30_000)
+
+ self._notifier.add_lock_released_callback(self._on_lock_released)
+
+ def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
+ """Acquire a standard lock, returns a context manager that will block
+ until the lock is acquired.
+
+ Note: Care must be taken to avoid deadlocks. In particular, this
+ function does *not* time out.
+
+ Usage:
+ async with handler.acquire_lock(name, key):
+ # Do work while holding the lock...
+ """
+
+ lock = WaitingLock(
+ reactor=self._reactor,
+ store=self._store,
+ handler=self,
+ lock_name=lock_name,
+ lock_key=lock_key,
+ write=None,
+ )
+
+ self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+ return lock
+
+ def acquire_read_write_lock(
+ self,
+ lock_name: str,
+ lock_key: str,
+ *,
+ write: bool,
+ ) -> "WaitingLock":
+ """Acquire a read/write lock, returns a context manager that will block
+ until the lock is acquired.
+
+ Note: Care must be taken to avoid deadlocks. In particular, this
+ function does *not* time out.
+
+ Usage:
+ async with handler.acquire_read_write_lock(name, key, write=True):
+ # Do work while holding the lock...
+ """
+
+ lock = WaitingLock(
+ reactor=self._reactor,
+ store=self._store,
+ handler=self,
+ lock_name=lock_name,
+ lock_key=lock_key,
+ write=write,
+ )
+
+ self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+ return lock
+
+ def acquire_multi_read_write_lock(
+ self,
+ lock_names: Collection[Tuple[str, str]],
+ *,
+ write: bool,
+ ) -> "WaitingMultiLock":
+ """Acquires multi read/write locks at once, returns a context manager
+ that will block until all the locks are acquired.
+
+ This will try to acquire all locks at once, and will never hold on to a
+ subset of the locks. (This avoids accidentally creating deadlocks).
+
+ Note: Care must be taken to avoid deadlocks. In particular, this
+ function does *not* time out.
+ """
+
+ lock = WaitingMultiLock(
+ lock_names=lock_names,
+ write=write,
+ reactor=self._reactor,
+ store=self._store,
+ handler=self,
+ )
+
+ for lock_name, lock_key in lock_names:
+ self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
+
+ return lock
+
+ def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
+ """Notify that a lock has been released.
+
+ Pokes both the notifier and replication.
+ """
+
+ self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)
+
+ def _on_lock_released(
+ self, instance_name: str, lock_name: str, lock_key: str
+ ) -> None:
+ """Called when a lock has been released.
+
+ Wakes up any locks that might be waiting on this.
+ """
+ locks = self._locks.get((lock_name, lock_key))
+ if not locks:
+ return
+
+ def _wake_deferred(deferred: defer.Deferred) -> None:
+ if not deferred.called:
+ deferred.callback(None)
+
+ for lock in locks:
+ self._clock.call_later(0, _wake_deferred, lock.deferred)
+
+ @wrap_as_background_process("_cleanup_locks")
+ async def _cleanup_locks(self) -> None:
+ """Periodically cleans out stale entries in the locks map"""
+ self._locks = {key: value for key, value in self._locks.items() if value}
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingLock:
+ reactor: IReactorTime
+ store: LockStore
+ handler: WorkerLocksHandler
+ lock_name: str
+ lock_key: str
+ write: Optional[bool]
+ deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+ _inner_lock: Optional[Lock] = None
+ _retry_interval: float = 0.1
+ _lock_span: "opentracing.Scope" = attr.Factory(
+ lambda: start_active_span("WaitingLock.lock")
+ )
+
+ async def __aenter__(self) -> None:
+ self._lock_span.__enter__()
+
+ with start_active_span("WaitingLock.waiting_for_lock"):
+ while self._inner_lock is None:
+ self.deferred = defer.Deferred()
+
+ if self.write is not None:
+ lock = await self.store.try_acquire_read_write_lock(
+ self.lock_name, self.lock_key, write=self.write
+ )
+ else:
+ lock = await self.store.try_acquire_lock(
+ self.lock_name, self.lock_key
+ )
+
+ if lock:
+ self._inner_lock = lock
+ break
+
+ try:
+ # Wait until we get notified that the lock might have been
+ # released (by the deferred being resolved). We also
+ # periodically wake up in case the lock was released but we
+ # weren't notified.
+ with PreserveLoggingContext():
+ await timeout_deferred(
+ deferred=self.deferred,
+ timeout=self._get_next_retry_interval(),
+ reactor=self.reactor,
+ )
+ except Exception:
+ pass
+
+ return await self._inner_lock.__aenter__()
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ assert self._inner_lock
+
+ self.handler.notify_lock_released(self.lock_name, self.lock_key)
+
+ try:
+ r = await self._inner_lock.__aexit__(exc_type, exc, tb)
+ finally:
+ self._lock_span.__exit__(exc_type, exc, tb)
+
+ return r
+
+ def _get_next_retry_interval(self) -> float:
+ next = self._retry_interval
+ self._retry_interval = max(5, next * 2)
+ return next * random.uniform(0.9, 1.1)
+
+
+@attr.s(auto_attribs=True, eq=False)
+class WaitingMultiLock:
+ lock_names: Collection[Tuple[str, str]]
+
+ write: bool
+
+ reactor: IReactorTime
+ store: LockStore
+ handler: WorkerLocksHandler
+
+ deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
+
+ _inner_lock_cm: Optional[AsyncContextManager] = None
+ _retry_interval: float = 0.1
+ _lock_span: "opentracing.Scope" = attr.Factory(
+ lambda: start_active_span("WaitingLock.lock")
+ )
+
+ async def __aenter__(self) -> None:
+ self._lock_span.__enter__()
+
+ with start_active_span("WaitingLock.waiting_for_lock"):
+ while self._inner_lock_cm is None:
+ self.deferred = defer.Deferred()
+
+ lock_cm = await self.store.try_acquire_multi_read_write_lock(
+ self.lock_names, write=self.write
+ )
+
+ if lock_cm:
+ self._inner_lock_cm = lock_cm
+ break
+
+ try:
+ # Wait until we get notified that the lock might have been
+ # released (by the deferred being resolved). We also
+ # periodically wake up in case the lock was released but we
+ # weren't notified.
+ with PreserveLoggingContext():
+ await timeout_deferred(
+ deferred=self.deferred,
+ timeout=self._get_next_retry_interval(),
+ reactor=self.reactor,
+ )
+ except Exception:
+ pass
+
+ assert self._inner_lock_cm
+ await self._inner_lock_cm.__aenter__()
+ return
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ assert self._inner_lock_cm
+
+ for lock_name, lock_key in self.lock_names:
+ self.handler.notify_lock_released(lock_name, lock_key)
+
+ try:
+ r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
+ finally:
+ self._lock_span.__exit__(exc_type, exc, tb)
+
+ return r
+
+ def _get_next_retry_interval(self) -> float:
+ next = self._retry_interval
+ self._retry_interval = max(5, next * 2)
+ return next * random.uniform(0.9, 1.1)
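
The `__aenter__` loop above combines wake-on-notify (the deferred resolved via `_on_lock_released`) with periodic polling as a safety net. Ignoring the ±10% jitter, `_get_next_retry_interval` yields one fast retry and then backs off from a 5-second floor, doubling thereafter. A toy trace of the interval sequence:

```python
# Trace of the retry intervals produced by _get_next_retry_interval,
# with the random jitter factored out.
interval = 0.1
produced = []
for _ in range(4):
    produced.append(interval)
    interval = max(5, interval * 2)

print(produced)  # [0.1, 5, 10, 20] -- seconds between poll attempts
```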
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 897272ad5b..68115bca70 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -234,6 +234,9 @@ class Notifier:
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
+ # List of callbacks to be notified when a lock is released
+ self._lock_released_callback: List[Callable[[str, str, str], None]] = []
+
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self._pusher_pool = hs.get_pusherpool()
@@ -785,6 +788,19 @@ class Notifier:
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)
+ def add_lock_released_callback(
+ self, callback: Callable[[str, str, str], None]
+ ) -> None:
+ """Add a function to be called whenever we are notified about a released lock."""
+ self._lock_released_callback.append(callback)
+
+ def notify_lock_released(
+ self, instance_name: str, lock_name: str, lock_key: str
+ ) -> None:
+ """Notify the callbacks that a lock has been released."""
+ for cb in self._lock_released_callback:
+ cb(instance_name, lock_name, lock_key)
+
@attr.s(auto_attribs=True)
class ReplicationNotifier:
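
The notifier is deliberately dumb here: it just fans a release event out to whoever registered. A self-contained sketch of the plumbing (`_Notifier` is a toy stand-in for the class above):

```python
class _Notifier:
    """Toy stand-in for the lock-release callback plumbing in this diff."""

    def __init__(self) -> None:
        self._lock_released_callback = []

    def add_lock_released_callback(self, callback) -> None:
        self._lock_released_callback.append(callback)

    def notify_lock_released(self, instance_name, lock_name, lock_key) -> None:
        for cb in self._lock_released_callback:
            cb(instance_name, lock_name, lock_key)

notifier = _Notifier()
notifier.add_lock_released_callback(
    lambda inst, name, key: print(f"{inst} released {name}/{key}")
)
notifier.notify_lock_released("worker1", "delete_room_lock", "!room:example.com")
# -> worker1 released delete_room_lock/!room:example.com
```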
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index f874f072f9..73f3de3642 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -107,8 +107,7 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on
the main process to accomplish this.
- Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
- Request format(borrowed and expanded from KeyUploadServlet):
+ Request format for this endpoint (borrowed and expanded from KeyUploadServlet):
POST /_synapse/replication/upload_keys_for_user
@@ -117,6 +116,7 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
"device_id": "<device_id>",
"keys": {
....this part can be found in KeyUploadServlet in rest/client/keys.py....
+ or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload
}
}
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 32f52e54d8..10f5c98ff8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP"
+class LockReleasedCommand(Command):
+ """Sent to inform other instances that a given lock has been dropped.
+
+ Format::
+
+ LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
+ """
+
+ NAME = "LOCK_RELEASED"
+
+ def __init__(
+ self,
+ instance_name: str,
+ lock_name: str,
+ lock_key: str,
+ ):
+ self.instance_name = instance_name
+ self.lock_name = lock_name
+ self.lock_key = lock_key
+
+ @classmethod
+ def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
+ instance_name, lock_name, lock_key = json_decoder.decode(line)
+
+ return cls(instance_name, lock_name, lock_key)
+
+ def to_line(self) -> str:
+ return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
+
+
_COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand,
RdataCommand,
@@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
+ LockReleasedCommand,
)
# Map of command name to command type.
@@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
RemoteServerUpCommand.NAME,
+ LockReleasedCommand.NAME,
)
# The commands the client is allowed to send
@@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = (
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
+ LockReleasedCommand.NAME,
)
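
On the wire the payload is a JSON array, matching the `Format::` note in the docstring. A roundtrip sketch using the standard `json` module in place of synapse's `json_encoder`/`json_decoder`:

```python
import json

line = json.dumps(["worker1", "delete_room_lock", "!room:example.com"])
# as sent: LOCK_RELEASED ["worker1", "delete_room_lock", "!room:example.com"]

instance_name, lock_name, lock_key = json.loads(line)
assert (instance_name, lock_name, lock_key) == (
    "worker1",
    "delete_room_lock",
    "!room:example.com",
)
```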
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 5d108fe11b..a2cabba7b1 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
+ LockReleasedCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -248,6 +249,9 @@ class ReplicationCommandHandler:
if self._is_master or self._should_insert_client_ips:
self.subscribe_to_channel("USER_IP")
+ if hs.config.redis.redis_enabled:
+ self._notifier.add_lock_released_callback(self.on_lock_released)
+
def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.
@@ -648,6 +652,17 @@ class ReplicationCommandHandler:
self._notifier.notify_remote_server_up(cmd.data)
+ def on_LOCK_RELEASED(
+ self, conn: IReplicationConnection, cmd: LockReleasedCommand
+ ) -> None:
+ """Called when we get a new LOCK_RELEASED command."""
+ if cmd.instance_name == self._instance_name:
+ return
+
+ self._notifier.notify_lock_released(
+ cmd.instance_name, cmd.lock_name, cmd.lock_key
+ )
+
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -754,6 +769,13 @@ class ReplicationCommandHandler:
"""
self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
+ def on_lock_released(
+ self, instance_name: str, lock_name: str, lock_key: str
+ ) -> None:
+ """Called when we released a lock and should notify other instances."""
+ if instance_name == self._instance_name:
+ self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+
UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6a7792e18b..4a5d9e13e7 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet):
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
async def on_POST(
self, request: SynapseRequest, room_id: str
@@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet):
)
try:
- new_room_id = await self._room_creation_handler.upgrade_room(
- requester, room_id, new_version
- )
+ async with self._worker_lock_handler.acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ new_room_id = await self._room_creation_handler.upgrade_room(
+ requester, room_id, new_version
+ )
except ShadowBanError:
# Generate a random room ID.
new_room_id = stringutils.random_string(18)
diff --git a/synapse/server.py b/synapse/server.py
index b72b76a38b..8430f99ef2 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.handlers.worker_lock import WorkerLocksHandler
from synapse.http.client import (
InsecureInterceptableContextFactory,
ReplicationClient,
@@ -912,3 +913,7 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
"""Usage metrics shared between phone home stats and the prometheus exporter."""
return CommonUsageMetricsManager(self)
+
+ @cache_in_self
+ def get_worker_locks_handler(self) -> WorkerLocksHandler:
+ return WorkerLocksHandler(self)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 35c0680365..35cd1089d6 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,6 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import (
SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
)
self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller
+ self.hs = hs
async def _process_event_persist_queue_task(
self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
"""
- if isinstance(task, _PersistEventsTask):
- return await self._persist_event_batch(room_id, task)
- elif isinstance(task, _UpdateCurrentStateTask):
- await self._update_current_state(room_id, task)
- return {}
- else:
- raise AssertionError(
- f"Found an unexpected task type in event persistence queue: {task}"
- )
+
+ # Ensure that the room can't be deleted while we're persisting events to
+ # it. We might already have taken out the lock, but since this is just a
+ # "read" lock its inherently reentrant.
+ async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
+ DELETE_ROOM_LOCK_NAME, room_id, write=False
+ ):
+ if isinstance(task, _PersistEventsTask):
+ return await self._persist_event_batch(room_id, task)
+ elif isinstance(task, _UpdateCurrentStateTask):
+ await self._update_current_state(room_id, task)
+ return {}
+ else:
+ raise AssertionError(
+ f"Found an unexpected task type in event persistence queue: {task}"
+ )
@trace
async def persist_events(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index b2cda52ce5..534dc32413 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -843,7 +843,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
* because the schema change is in a background update, it's not
* necessarily safe to assume that it will have been completed.
*/
- AND edge.is_state is ? /* False */
+ AND edge.is_state is FALSE
/**
* We only want backwards extremities that are older than or at
* the same position of the given `current_depth` (where older
@@ -886,7 +886,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sql,
(
room_id,
- False,
current_depth,
self._clock.time_msec(),
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
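
This is the first of several storage changes in this diff that swap a bound boolean parameter for an SQL literal; the query is semantically unchanged and the argument tuple shrinks to match the remaining placeholders. The equivalence, demonstrated with `sqlite3` (table and values are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, outlier BOOLEAN)")
conn.execute("INSERT INTO events VALUES ('$a', 1), ('$b', 0)")

# Bound parameter (old style) and SQL literal (new style) select the same rows.
old = conn.execute(
    "SELECT event_id FROM events WHERE outlier = ?", (False,)
).fetchall()
new = conn.execute(
    "SELECT event_id FROM events WHERE outlier = FALSE"
).fetchall()
assert old == new == [("$b",)]
```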
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2b83a69426..bd3f14fb71 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1455,8 +1455,8 @@ class PersistEventsStore:
},
)
- sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
- txn.execute(sql, (False, event.event_id))
+ sql = "UPDATE events SET outlier = FALSE WHERE event_id = ?"
+ txn.execute(sql, (event.event_id,))
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
@@ -1549,13 +1549,13 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if not event.internal_metadata.is_redacted()
]
- sql = "UPDATE redactions SET have_censored = ? WHERE "
+ sql = "UPDATE redactions SET have_censored = FALSE WHERE "
clause, args = make_in_list_sql_clause(
self.database_engine,
"redacts",
unredacted_events,
)
- txn.execute(sql + clause, [False] + args)
+ txn.execute(sql + clause, args)
self.db_pool.simple_insert_many_txn(
txn,
@@ -2318,14 +2318,14 @@ class PersistEventsStore:
" SELECT 1 FROM events"
" LEFT JOIN event_edges edge"
" ON edge.event_id = events.event_id"
- " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)"
+ " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)"
" )"
)
txn.execute_batch(
query,
[
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id)
for ev in events
for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier()
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index c89b4f7919..1680bf6168 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from contextlib import AsyncExitStack
from types import TracebackType
-from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -208,76 +209,85 @@ class LockStore(SQLBaseStore):
used (otherwise the lock will leak).
"""
+ try:
+ lock = await self.db_pool.runInteraction(
+ "try_acquire_read_write_lock",
+ self._try_acquire_read_write_lock_txn,
+ lock_name,
+ lock_key,
+ write,
+ )
+ except self.database_engine.module.IntegrityError:
+ return None
+
+ return lock
+
+ def _try_acquire_read_write_lock_txn(
+ self,
+ txn: LoggingTransaction,
+ lock_name: str,
+ lock_key: str,
+ write: bool,
+ ) -> "Lock":
+ # We attempt to acquire the lock by inserting into
+ # `worker_read_write_locks` and seeing if that fails any
+ # constraints. If it doesn't then we have acquired the lock,
+ # otherwise we haven't.
+ #
+ # Before that though we clear the table of any stale locks.
+
now = self._clock.time_msec()
token = random_string(6)
- def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
- # We attempt to acquire the lock by inserting into
- # `worker_read_write_locks` and seeing if that fails any
- # constraints. If it doesn't then we have acquired the lock,
- # otherwise we haven't.
- #
- # Before that though we clear the table of any stale locks.
-
- delete_sql = """
- DELETE FROM worker_read_write_locks
- WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
- """
-
- insert_sql = """
- INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # For Postgres we can send these queries at the same time.
- txn.execute(
- delete_sql + ";" + insert_sql,
- (
- # DELETE args
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- # UPSERT args
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
- else:
- # For SQLite these need to be two queries.
- txn.execute(
- delete_sql,
- (
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- ),
- )
- txn.execute(
- insert_sql,
- (
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
+ delete_sql = """
+ DELETE FROM worker_read_write_locks
+ WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
+ """
- return
+ insert_sql = """
+ INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
+ VALUES (?, ?, ?, ?, ?, ?)
+ """
- try:
- await self.db_pool.runInteraction(
- "try_acquire_read_write_lock",
- _try_acquire_read_write_lock_txn,
+ if isinstance(self.database_engine, PostgresEngine):
+ # For Postgres we can send these queries at the same time.
+ txn.execute(
+ delete_sql + ";" + insert_sql,
+ (
+ # DELETE args
+ now - _LOCK_TIMEOUT_MS,
+ lock_name,
+ lock_key,
+ # UPSERT args
+ lock_name,
+ lock_key,
+ write,
+ self._instance_name,
+ token,
+ now,
+ ),
+ )
+ else:
+ # For SQLite these need to be two queries.
+ txn.execute(
+ delete_sql,
+ (
+ now - _LOCK_TIMEOUT_MS,
+ lock_name,
+ lock_key,
+ ),
+ )
+ txn.execute(
+ insert_sql,
+ (
+ lock_name,
+ lock_key,
+ write,
+ self._instance_name,
+ token,
+ now,
+ ),
)
- except self.database_engine.module.IntegrityError:
- return None
lock = Lock(
self._reactor,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
token=token,
)
- self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+ def set_lock() -> None:
+ self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
+
+ txn.call_after(set_lock)
return lock
+ async def try_acquire_multi_read_write_lock(
+ self,
+ lock_names: Collection[Tuple[str, str]],
+ write: bool,
+ ) -> Optional[AsyncExitStack]:
+ """Try to acquire multiple locks for the given names/keys. Will return
+ an async context manager if the locks are successfully acquired, which
+ *must* be used (otherwise the locks will leak).
+
+ If only a subset of the locks can be acquired then it will immediately
+ drop them and return `None`.
+ """
+ try:
+ locks = await self.db_pool.runInteraction(
+ "try_acquire_multi_read_write_lock",
+ self._try_acquire_multi_read_write_lock_txn,
+ lock_names,
+ write,
+ )
+ except self.database_engine.module.IntegrityError:
+ return None
+
+ stack = AsyncExitStack()
+
+ for lock in locks:
+ await stack.enter_async_context(lock)
+
+ return stack
+
+ def _try_acquire_multi_read_write_lock_txn(
+ self,
+ txn: LoggingTransaction,
+ lock_names: Collection[Tuple[str, str]],
+ write: bool,
+ ) -> Collection["Lock"]:
+ locks = []
+
+ for lock_name, lock_key in lock_names:
+ lock = self._try_acquire_read_write_lock_txn(
+ txn, lock_name, lock_key, write
+ )
+ locks.append(lock)
+
+ return locks
+
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 9773c1fcd2..b52f48cf04 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -249,12 +249,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# Mark all state and own events as outliers
logger.info("[purge] marking remaining events as outliers")
txn.execute(
- "UPDATE events SET outlier = ?"
+ "UPDATE events SET outlier = TRUE"
" WHERE event_id IN ("
- " SELECT event_id FROM events_to_purge "
- " WHERE NOT should_delete"
- ")",
- (True,),
+ " SELECT event_id FROM events_to_purge "
+ " WHERE NOT should_delete"
+ ")"
)
# synapse tries to take out an exclusive lock on room_depth whenever it
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e098ceea3c..c13c0bc7d7 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -560,19 +560,19 @@ class PushRuleStore(PushRulesWorkerStore):
if isinstance(self.database_engine, PostgresEngine):
sql = """
INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
- VALUES (?, ?, ?, ?)
+ VALUES (?, ?, ?, 1)
ON CONFLICT DO NOTHING
"""
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
- VALUES (?, ?, ?, ?)
+ VALUES (?, ?, ?, 1)
"""
else:
raise RuntimeError("Unknown database engine")
new_enable_id = self._push_rules_enable_id_gen.get_next()
- txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+ txn.execute(sql, (new_enable_id, user_id, rule_id))
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 676d03bb7e..c582cf0573 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -454,9 +454,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) -> List[Tuple[str, int]]:
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
- " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ " WHERE email_sent = FALSE AND (expiration_ts_ms - ?) <= ?"
)
- values = [False, now_ms, renew_at]
+ values = [now_ms, renew_at]
txn.execute(sql, values)
return cast(List[Tuple[str, int]], txn.fetchall())
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 830658f328..719e11aea6 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -936,11 +936,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
JOIN event_json USING (room_id, event_id)
WHERE room_id = ?
%(where_clause)s
- AND contains_url = ? AND outlier = ?
+ AND contains_url = TRUE AND outlier = FALSE
ORDER BY stream_ordering DESC
LIMIT ?
"""
- txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))
+ txn.execute(sql % {"where_clause": ""}, (room_id, 100))
local_media_mxcs = []
remote_media_mxcs = []
@@ -976,7 +976,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(
sql % {"where_clause": "AND stream_ordering < ?"},
- (room_id, next_token, True, False, 100),
+ (room_id, next_token, 100),
)
return local_media_mxcs, remote_media_mxcs
@@ -1086,9 +1086,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# set quarantine
if quarantined_by is not None:
- sql += "AND safe_from_quarantine = ?"
+ sql += "AND safe_from_quarantine = FALSE"
txn.executemany(
- sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ sql, [(quarantined_by, media_id) for media_id in local_mxcs]
)
# remove from quarantine
else:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 92cbe262a6..5a3611c415 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1401,7 +1401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- args = [False, room_id]
+ args: List[Any] = [room_id]
order, from_bound, to_bound = generate_pagination_bounds(
direction, from_token, to_token
@@ -1475,7 +1475,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event.topological_ordering, event.stream_ordering
FROM events AS event
%(join_clause)s
- WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
+ WHERE event.outlier = FALSE AND event.room_id = ? AND %(bounds)s
ORDER BY event.topological_ordering %(order)s,
event.stream_ordering %(order)s LIMIT ?
""" % {
|