diff options
Diffstat (limited to 'synapse')
22 files changed, 726 insertions, 237 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 5fb3d5083d..359999f680 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -17,8 +17,6 @@ import urllib.parse from typing import ( TYPE_CHECKING, Any, - Awaitable, - Callable, Dict, Iterable, List, @@ -30,7 +28,7 @@ from typing import ( ) from prometheus_client import Counter -from typing_extensions import Concatenate, ParamSpec, TypeGuard +from typing_extensions import ParamSpec, TypeGuard from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException, HttpResponseException @@ -80,9 +78,7 @@ sent_todevice_counter = Counter( HOUR_IN_MS = 60 * 60 * 1000 - APP_SERVICE_PREFIX = "/_matrix/app/v1" -APP_SERVICE_UNSTABLE_PREFIX = "/_matrix/app/unstable" P = ParamSpec("P") R = TypeVar("R") @@ -128,47 +124,6 @@ class ApplicationServiceApi(SimpleHttpClient): hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - async def _send_with_fallbacks( - self, - service: "ApplicationService", - prefixes: List[str], - path: str, - func: Callable[Concatenate[str, P], Awaitable[R]], - *args: P.args, - **kwargs: P.kwargs, - ) -> R: - """ - Attempt to call an application service with multiple paths, falling back - until one succeeds. - - Args: - service: The appliacation service, this provides the base URL. - prefixes: A last of paths to try in order for the requests. - path: A suffix to append to each prefix. - func: The function to call, the first argument will be the full - endpoint to fetch. Other arguments are provided by args/kwargs. - - Returns: - The return value of func. - """ - for i, prefix in enumerate(prefixes, start=1): - uri = f"{service.url}{prefix}{path}" - try: - return await func(uri, *args, **kwargs) - except HttpResponseException as e: - # If an error is received that is due to an unrecognised path, - # fallback to next path (if one exists). Otherwise, consider it - # a legitimate error and raise. - if i < len(prefixes) and is_unknown_endpoint(e): - continue - raise - except Exception: - # Unexpected exceptions get sent to the caller. - raise - - # The function should always exit via the return or raise above this. - raise RuntimeError("Unexpected fallback behaviour. This should never be seen.") - async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: return False @@ -177,11 +132,8 @@ class ApplicationServiceApi(SimpleHttpClient): assert service.hs_token is not None try: - response = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, ""], - f"/users/{urllib.parse.quote(user_id)}", - self.get_json, + response = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}", {"access_token": service.hs_token}, headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) @@ -203,11 +155,8 @@ class ApplicationServiceApi(SimpleHttpClient): assert service.hs_token is not None try: - response = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, ""], - f"/rooms/{urllib.parse.quote(alias)}", - self.get_json, + response = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}", {"access_token": service.hs_token}, headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) @@ -245,11 +194,8 @@ class ApplicationServiceApi(SimpleHttpClient): **fields, b"access_token": service.hs_token, } - response = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, APP_SERVICE_UNSTABLE_PREFIX], - f"/thirdparty/{kind}/{urllib.parse.quote(protocol)}", - self.get_json, + response = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}", args=args, headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) @@ -285,11 +231,8 @@ class ApplicationServiceApi(SimpleHttpClient): # This is required by the configuration. assert service.hs_token is not None try: - info = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, APP_SERVICE_UNSTABLE_PREFIX], - f"/thirdparty/protocol/{urllib.parse.quote(protocol)}", - self.get_json, + info = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}", {"access_token": service.hs_token}, headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) @@ -401,11 +344,8 @@ class ApplicationServiceApi(SimpleHttpClient): } try: - await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, ""], - f"/transactions/{urllib.parse.quote(str(txn_id))}", - self.put_json, + await self.put_json( + f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}", json_body=body, args={"access_token": service.hs_token}, headers={"Authorization": [f"Bearer {service.hs_token}"]}, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index fa61dd8c10..a90d99c4d6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -63,6 +63,7 @@ from synapse.federation.federation_base import ( ) from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction +from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -137,6 +138,7 @@ class FederationServer(FederationBase): self._event_auth_handler = hs.get_event_auth_handler() self._room_member_handler = hs.get_room_member_handler() self._e2e_keys_handler = hs.get_e2e_keys_handler() + self._worker_lock_handler = hs.get_worker_locks_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -1236,9 +1238,18 @@ class FederationServer(FederationBase): logger.info("handling received PDU in room %s: %s", room_id, event) try: with nested_logging_context(event.event_id): - await self._federation_event_handler.on_receive_pdu( - origin, event - ) + # We're taking out a lock within a lock, which could + # lead to deadlocks if we're not careful. However, it is + # safe on this occasion as we only ever take a write + # lock when deleting a room, which we would never do + # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME` + # lock. + async with self._worker_lock_handler.acquire_read_write_lock( + DELETE_ROOM_LOCK_NAME, room_id, write=False + ): + await self._federation_event_handler.on_receive_pdu( + origin, event + ) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index fff0b5fa12..c656e07d37 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler +from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -485,6 +486,7 @@ class EventCreationHandler: self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() + self._worker_lock_handler = hs.get_worker_locks_handler() self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state @@ -876,14 +878,13 @@ class EventCreationHandler: return prev_event return None - async def get_event_from_transaction( + async def get_event_id_from_transaction( self, requester: Requester, txn_id: str, room_id: str, - ) -> Optional[EventBase]: - """For the given transaction ID and room ID, check if there is a matching event. - If so, fetch it and return it. + ) -> Optional[str]: + """For the given transaction ID and room ID, check if there is a matching event ID. Args: requester: The requester making the request in the context of which we want @@ -892,8 +893,9 @@ class EventCreationHandler: room_id: The room ID. Returns: - An event if one could be found, None otherwise. + An event ID if one could be found, None otherwise. """ + existing_event_id = None if self._msc3970_enabled and requester.device_id: # When MSC3970 is enabled, we lookup for events sent by the same device first, @@ -907,7 +909,7 @@ class EventCreationHandler: ) ) if existing_event_id: - return await self.store.get_event(existing_event_id) + return existing_event_id # Pre-MSC3970, we looked up for events that were sent by the same session by # using the access token ID. @@ -920,9 +922,32 @@ class EventCreationHandler: txn_id, ) ) - if existing_event_id: - return await self.store.get_event(existing_event_id) + return existing_event_id + + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + existing_event_id = await self.get_event_id_from_transaction( + requester, txn_id, room_id + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) return None async def create_and_send_nonmember_event( @@ -1010,6 +1035,37 @@ class EventCreationHandler: event.internal_metadata.stream_ordering, ) + async with self._worker_lock_handler.acquire_read_write_lock( + DELETE_ROOM_LOCK_NAME, room_id, write=False + ): + return await self._create_and_send_nonmember_event_locked( + requester=requester, + event_dict=event_dict, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + ratelimit=ratelimit, + txn_id=txn_id, + ignore_shadow_ban=ignore_shadow_ban, + outlier=outlier, + depth=depth, + ) + + async def _create_and_send_nonmember_event_locked( + self, + requester: Requester, + event_dict: dict, + allow_no_prev_events: bool = False, + prev_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, + ratelimit: bool = True, + txn_id: Optional[str] = None, + ignore_shadow_ban: bool = False, + outlier: bool = False, + depth: Optional[int] = None, + ) -> Tuple[EventBase, int]: + room_id = event_dict["room_id"] + # If we don't have any prev event IDs specified then we need to # check that the host is in the room (as otherwise populating the # prev events will fail), at which point we may as well check the @@ -1923,7 +1979,10 @@ class EventCreationHandler: ) for room_id in room_ids: - dummy_event_sent = await self._send_dummy_event_for_room(room_id) + async with self._worker_lock_handler.acquire_read_write_lock( + DELETE_ROOM_LOCK_NAME, room_id, write=False + ): + dummy_event_sent = await self._send_dummy_event_for_room(room_id) if not dummy_event_sent: # Did not find a valid user in the room, so remove from future attempts diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 19b8728db9..da34658470 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -46,6 +46,11 @@ logger = logging.getLogger(__name__) BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3 +PURGE_HISTORY_LOCK_NAME = "purge_history_lock" + +DELETE_ROOM_LOCK_NAME = "delete_room_lock" + + @attr.s(slots=True, auto_attribs=True) class PurgeStatus: """Object tracking the status of a purge request @@ -142,6 +147,7 @@ class PaginationHandler: self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() self._relations_handler = hs.get_relations_handler() + self._worker_locks = hs.get_worker_locks_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -356,7 +362,9 @@ class PaginationHandler: """ self._purges_in_progress_by_room.add(room_id) try: - async with self.pagination_lock.write(room_id): + async with self._worker_locks.acquire_read_write_lock( + PURGE_HISTORY_LOCK_NAME, room_id, write=True + ): await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) @@ -412,7 +420,10 @@ class PaginationHandler: room_id: room to be purged force: set true to skip checking for joined users. """ - async with self.pagination_lock.write(room_id): + async with self._worker_locks.acquire_multi_read_write_lock( + [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)], + write=True, + ): # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) @@ -471,7 +482,9 @@ class PaginationHandler: room_token = from_token.room_key - async with self.pagination_lock.read(room_id): + async with self._worker_locks.acquire_read_write_lock( + PURGE_HISTORY_LOCK_NAME, room_id, write=False + ): (membership, member_event_id) = (None, None) if not use_admin_priviledge: ( @@ -747,7 +760,9 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: - async with self.pagination_lock.write(room_id): + async with self._worker_locks.acquire_read_write_lock( + PURGE_HISTORY_LOCK_NAME, room_id, write=True + ): self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[ delete_id diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index a7f8c5e636..c7fe101cd9 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -68,7 +68,7 @@ class ProfileHandler: if self.hs.is_mine(target_user): profileinfo = await self.store.get_profileinfo(target_user) - if profileinfo.display_name is None: + if profileinfo.display_name is None and profileinfo.avatar_url is None: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) return { diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 496e701f13..e3cdf2bc61 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -39,6 +39,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler +from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.logging import opentracing from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process @@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.event_creation_handler = hs.get_event_creation_handler() self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() + self._worker_lock_handler = hs.get_worker_locks_handler() self.member_linearizer: Linearizer = Linearizer(name="member") self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") @@ -174,8 +176,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.request_ratelimiter = hs.get_request_ratelimiter() hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room) - self._msc3970_enabled = hs.config.experimental.msc3970_enabled - def _on_user_joined_room(self, event_id: str, room_id: str) -> None: """Notify the rate limiter that a room join has occurred. @@ -416,29 +416,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # do this check just before we persist an event as well, but may as well # do it up front for efficiency.) if txn_id: - existing_event_id = None - if self._msc3970_enabled and requester.device_id: - # When MSC3970 is enabled, we lookup for events sent by the same device - # first, and fallback to the old behaviour if none were found. - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_device_id( - room_id, - requester.user.to_string(), - requester.device_id, - txn_id, - ) + existing_event_id = ( + await self.event_creation_handler.get_event_id_from_transaction( + requester, txn_id, room_id ) - - if requester.access_token_id and not existing_event_id: - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_token_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, - ) - ) - + ) if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) return existing_event_id, event_pos.stream @@ -638,26 +620,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # by application services), and then by room ID. async with self.member_as_limiter.queue(as_id): async with self.member_linearizer.queue(key): - with opentracing.start_active_span("update_membership_locked"): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - state_event_ids=state_event_ids, - depth=depth, - origin_server_ts=origin_server_ts, - ) + async with self._worker_lock_handler.acquire_read_write_lock( + DELETE_ROOM_LOCK_NAME, room_id, write=False + ): + with opentracing.start_active_span("update_membership_locked"): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + depth=depth, + origin_server_ts=origin_server_ts, + ) return result diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py new file mode 100644 index 0000000000..72df773a86 --- /dev/null +++ b/synapse/handlers/worker_lock.py @@ -0,0 +1,333 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from types import TracebackType +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + Collection, + Dict, + Optional, + Tuple, + Type, + Union, +) +from weakref import WeakSet + +import attr + +from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime + +from synapse.logging.context import PreserveLoggingContext +from synapse.logging.opentracing import start_active_span +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage.databases.main.lock import Lock, LockStore +from synapse.util.async_helpers import timeout_deferred + +if TYPE_CHECKING: + from synapse.logging.opentracing import opentracing + from synapse.server import HomeServer + + +DELETE_ROOM_LOCK_NAME = "delete_room_lock" + + +class WorkerLocksHandler: + """A class for waiting on taking out locks, rather than using the storage + functions directly (which don't support awaiting). + """ + + def __init__(self, hs: "HomeServer") -> None: + self._reactor = hs.get_reactor() + self._store = hs.get_datastores().main + self._clock = hs.get_clock() + self._notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() + + # Map from lock name/key to set of `WaitingLock` that are active for + # that lock. + self._locks: Dict[ + Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]] + ] = {} + + self._clock.looping_call(self._cleanup_locks, 30_000) + + self._notifier.add_lock_released_callback(self._on_lock_released) + + def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock": + """Acquire a standard lock, returns a context manager that will block + until the lock is acquired. + + Note: Care must be taken to avoid deadlocks. In particular, this + function does *not* timeout. + + Usage: + async with handler.acquire_lock(name, key): + # Do work while holding the lock... + """ + + lock = WaitingLock( + reactor=self._reactor, + store=self._store, + handler=self, + lock_name=lock_name, + lock_key=lock_key, + write=None, + ) + + self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock) + + return lock + + def acquire_read_write_lock( + self, + lock_name: str, + lock_key: str, + *, + write: bool, + ) -> "WaitingLock": + """Acquire a read/write lock, returns a context manager that will block + until the lock is acquired. + + Note: Care must be taken to avoid deadlocks. In particular, this + function does *not* timeout. + + Usage: + async with handler.acquire_read_write_lock(name, key, write=True): + # Do work while holding the lock... + """ + + lock = WaitingLock( + reactor=self._reactor, + store=self._store, + handler=self, + lock_name=lock_name, + lock_key=lock_key, + write=write, + ) + + self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock) + + return lock + + def acquire_multi_read_write_lock( + self, + lock_names: Collection[Tuple[str, str]], + *, + write: bool, + ) -> "WaitingMultiLock": + """Acquires multi read/write locks at once, returns a context manager + that will block until all the locks are acquired. + + This will try and acquire all locks at once, and will never hold on to a + subset of the locks. (This avoids accidentally creating deadlocks). + + Note: Care must be taken to avoid deadlocks. In particular, this + function does *not* timeout. + """ + + lock = WaitingMultiLock( + lock_names=lock_names, + write=write, + reactor=self._reactor, + store=self._store, + handler=self, + ) + + for lock_name, lock_key in lock_names: + self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock) + + return lock + + def notify_lock_released(self, lock_name: str, lock_key: str) -> None: + """Notify that a lock has been released. + + Pokes both the notifier and replication. + """ + + self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key) + + def _on_lock_released( + self, instance_name: str, lock_name: str, lock_key: str + ) -> None: + """Called when a lock has been released. + + Wakes up any locks that might be waiting on this. + """ + locks = self._locks.get((lock_name, lock_key)) + if not locks: + return + + def _wake_deferred(deferred: defer.Deferred) -> None: + if not deferred.called: + deferred.callback(None) + + for lock in locks: + self._clock.call_later(0, _wake_deferred, lock.deferred) + + @wrap_as_background_process("_cleanup_locks") + async def _cleanup_locks(self) -> None: + """Periodically cleans out stale entries in the locks map""" + self._locks = {key: value for key, value in self._locks.items() if value} + + +@attr.s(auto_attribs=True, eq=False) +class WaitingLock: + reactor: IReactorTime + store: LockStore + handler: WorkerLocksHandler + lock_name: str + lock_key: str + write: Optional[bool] + deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred) + _inner_lock: Optional[Lock] = None + _retry_interval: float = 0.1 + _lock_span: "opentracing.Scope" = attr.Factory( + lambda: start_active_span("WaitingLock.lock") + ) + + async def __aenter__(self) -> None: + self._lock_span.__enter__() + + with start_active_span("WaitingLock.waiting_for_lock"): + while self._inner_lock is None: + self.deferred = defer.Deferred() + + if self.write is not None: + lock = await self.store.try_acquire_read_write_lock( + self.lock_name, self.lock_key, write=self.write + ) + else: + lock = await self.store.try_acquire_lock( + self.lock_name, self.lock_key + ) + + if lock: + self._inner_lock = lock + break + + try: + # Wait until the we get notified the lock might have been + # released (by the deferred being resolved). We also + # periodically wake up in case the lock was released but we + # weren't notified. + with PreserveLoggingContext(): + await timeout_deferred( + deferred=self.deferred, + timeout=self._get_next_retry_interval(), + reactor=self.reactor, + ) + except Exception: + pass + + return await self._inner_lock.__aenter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + assert self._inner_lock + + self.handler.notify_lock_released(self.lock_name, self.lock_key) + + try: + r = await self._inner_lock.__aexit__(exc_type, exc, tb) + finally: + self._lock_span.__exit__(exc_type, exc, tb) + + return r + + def _get_next_retry_interval(self) -> float: + next = self._retry_interval + self._retry_interval = max(5, next * 2) + return next * random.uniform(0.9, 1.1) + + +@attr.s(auto_attribs=True, eq=False) +class WaitingMultiLock: + lock_names: Collection[Tuple[str, str]] + + write: bool + + reactor: IReactorTime + store: LockStore + handler: WorkerLocksHandler + + deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred) + + _inner_lock_cm: Optional[AsyncContextManager] = None + _retry_interval: float = 0.1 + _lock_span: "opentracing.Scope" = attr.Factory( + lambda: start_active_span("WaitingLock.lock") + ) + + async def __aenter__(self) -> None: + self._lock_span.__enter__() + + with start_active_span("WaitingLock.waiting_for_lock"): + while self._inner_lock_cm is None: + self.deferred = defer.Deferred() + + lock_cm = await self.store.try_acquire_multi_read_write_lock( + self.lock_names, write=self.write + ) + + if lock_cm: + self._inner_lock_cm = lock_cm + break + + try: + # Wait until the we get notified the lock might have been + # released (by the deferred being resolved). We also + # periodically wake up in case the lock was released but we + # weren't notified. + with PreserveLoggingContext(): + await timeout_deferred( + deferred=self.deferred, + timeout=self._get_next_retry_interval(), + reactor=self.reactor, + ) + except Exception: + pass + + assert self._inner_lock_cm + await self._inner_lock_cm.__aenter__() + return + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + assert self._inner_lock_cm + + for lock_name, lock_key in self.lock_names: + self.handler.notify_lock_released(lock_name, lock_key) + + try: + r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb) + finally: + self._lock_span.__exit__(exc_type, exc, tb) + + return r + + def _get_next_retry_interval(self) -> float: + next = self._retry_interval + self._retry_interval = max(5, next * 2) + return next * random.uniform(0.9, 1.1) diff --git a/synapse/notifier.py b/synapse/notifier.py index 897272ad5b..68115bca70 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -234,6 +234,9 @@ class Notifier: self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules + # List of callbacks to be notified when a lock is released + self._lock_released_callback: List[Callable[[str, str, str], None]] = [] + self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() self._pusher_pool = hs.get_pusherpool() @@ -785,6 +788,19 @@ class Notifier: # that any in flight requests can be immediately retried. self._federation_client.wake_destination(server) + def add_lock_released_callback( + self, callback: Callable[[str, str, str], None] + ) -> None: + """Add a function to be called whenever we are notified about a released lock.""" + self._lock_released_callback.append(callback) + + def notify_lock_released( + self, instance_name: str, lock_name: str, lock_key: str + ) -> None: + """Notify the callbacks that a lock has been released.""" + for cb in self._lock_released_callback: + cb(instance_name, lock_name, lock_key) + @attr.s(auto_attribs=True) class ReplicationNotifier: diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f874f072f9..73f3de3642 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -107,8 +107,7 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on the main process to accomplish this. - Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload - Request format(borrowed and expanded from KeyUploadServlet): + Request format for this endpoint (borrowed and expanded from KeyUploadServlet): POST /_synapse/replication/upload_keys_for_user @@ -117,6 +116,7 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): "device_id": "<device_id>", "keys": { ....this part can be found in KeyUploadServlet in rest/client/keys.py.... + or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload } } diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 32f52e54d8..10f5c98ff8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand): NAME = "REMOTE_SERVER_UP" +class LockReleasedCommand(Command): + """Sent to inform other instances that a given lock has been dropped. + + Format:: + + LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"] + """ + + NAME = "LOCK_RELEASED" + + def __init__( + self, + instance_name: str, + lock_name: str, + lock_key: str, + ): + self.instance_name = instance_name + self.lock_name = lock_name + self.lock_key = lock_key + + @classmethod + def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand": + instance_name, lock_name, lock_key = json_decoder.decode(line) + + return cls(instance_name, lock_name, lock_key) + + def to_line(self) -> str: + return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) + + _COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, @@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = ( UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, + LockReleasedCommand, ) # Map of command name to command type. @@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, RemoteServerUpCommand.NAME, + LockReleasedCommand.NAME, ) # The commands the client is allowed to send @@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = ( UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, + LockReleasedCommand.NAME, ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 5d108fe11b..a2cabba7b1 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, + LockReleasedCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -248,6 +249,9 @@ class ReplicationCommandHandler: if self._is_master or self._should_insert_client_ips: self.subscribe_to_channel("USER_IP") + if hs.config.redis.redis_enabled: + self._notifier.add_lock_released_callback(self.on_lock_released) + def subscribe_to_channel(self, channel_name: str) -> None: """ Indicates that we wish to subscribe to a Redis channel by name. @@ -648,6 +652,17 @@ class ReplicationCommandHandler: self._notifier.notify_remote_server_up(cmd.data) + def on_LOCK_RELEASED( + self, conn: IReplicationConnection, cmd: LockReleasedCommand + ) -> None: + """Called when we get a new LOCK_RELEASED command.""" + if cmd.instance_name == self._instance_name: + return + + self._notifier.notify_lock_released( + cmd.instance_name, cmd.lock_name, cmd.lock_key + ) + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -754,6 +769,13 @@ class ReplicationCommandHandler: """ self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) + def on_lock_released( + self, instance_name: str, lock_name: str, lock_key: str + ) -> None: + """Called when we released a lock and should notify other instances.""" + if instance_name == self._instance_name: + self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) + UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py index 6a7792e18b..4a5d9e13e7 100644 --- a/synapse/rest/client/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet): self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() + self._worker_lock_handler = hs.get_worker_locks_handler() async def on_POST( self, request: SynapseRequest, room_id: str @@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet): ) try: - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) + async with self._worker_lock_handler.acquire_read_write_lock( + DELETE_ROOM_LOCK_NAME, room_id, write=False + ): + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) except ShadowBanError: # Generate a random room ID. new_room_id = stringutils.random_string(18) diff --git a/synapse/server.py b/synapse/server.py index b72b76a38b..8430f99ef2 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler +from synapse.handlers.worker_lock import WorkerLocksHandler from synapse.http.client import ( InsecureInterceptableContextFactory, ReplicationClient, @@ -912,3 +913,7 @@ class HomeServer(metaclass=abc.ABCMeta): def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager: """Usage metrics shared between phone home stats and the prometheus exporter.""" return CommonUsageMetricsManager(self) + + @cache_in_self + def get_worker_locks_handler(self) -> WorkerLocksHandler: + return WorkerLocksHandler(self) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 35c0680365..35cd1089d6 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -45,6 +45,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.opentracing import ( SynapseTags, @@ -338,6 +339,7 @@ class EventsPersistenceStorageController: ) self._state_resolution_handler = hs.get_state_resolution_handler() self._state_controller = state_controller + self.hs = hs async def _process_event_persist_queue_task( self, @@ -350,15 +352,22 @@ class EventsPersistenceStorageController: A dictionary of event ID to event ID we didn't persist as we already had another event persisted with the same TXN ID. """ - if isinstance(task, _PersistEventsTask): - return await self._persist_event_batch(room_id, task) - elif isinstance(task, _UpdateCurrentStateTask): - await self._update_current_state(room_id, task) - return {} - else: - raise AssertionError( - f"Found an unexpected task type in event persistence queue: {task}" - ) + + # Ensure that the room can't be deleted while we're persisting events to + # it. We might already have taken out the lock, but since this is just a + # "read" lock its inherently reentrant. + async with self.hs.get_worker_locks_handler().acquire_read_write_lock( + DELETE_ROOM_LOCK_NAME, room_id, write=False + ): + if isinstance(task, _PersistEventsTask): + return await self._persist_event_batch(room_id, task) + elif isinstance(task, _UpdateCurrentStateTask): + await self._update_current_state(room_id, task) + return {} + else: + raise AssertionError( + f"Found an unexpected task type in event persistence queue: {task}" + ) @trace async def persist_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index b2cda52ce5..534dc32413 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -843,7 +843,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas * because the schema change is in a background update, it's not * necessarily safe to assume that it will have been completed. */ - AND edge.is_state is ? /* False */ + AND edge.is_state is FALSE /** * We only want backwards extremities that are older than or at * the same position of the given `current_depth` (where older @@ -886,7 +886,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas sql, ( room_id, - False, current_depth, self._clock.time_msec(), BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2b83a69426..bd3f14fb71 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1455,8 +1455,8 @@ class PersistEventsStore: }, ) - sql = "UPDATE events SET outlier = ? WHERE event_id = ?" - txn.execute(sql, (False, event.event_id)) + sql = "UPDATE events SET outlier = FALSE WHERE event_id = ?" + txn.execute(sql, (event.event_id,)) # Update the event_backward_extremities table now that this # event isn't an outlier any more. @@ -1549,13 +1549,13 @@ class PersistEventsStore: for event, _ in events_and_contexts if not event.internal_metadata.is_redacted() ] - sql = "UPDATE redactions SET have_censored = ? WHERE " + sql = "UPDATE redactions SET have_censored = FALSE WHERE " clause, args = make_in_list_sql_clause( self.database_engine, "redacts", unredacted_events, ) - txn.execute(sql + clause, [False] + args) + txn.execute(sql + clause, args) self.db_pool.simple_insert_many_txn( txn, @@ -2318,14 +2318,14 @@ class PersistEventsStore: " SELECT 1 FROM events" " LEFT JOIN event_edges edge" " ON edge.event_id = events.event_id" - " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)" + " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)" " )" ) txn.execute_batch( query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id) for ev in events for e_id in ev.prev_event_ids() if not ev.internal_metadata.is_outlier() diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index c89b4f7919..1680bf6168 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from contextlib import AsyncExitStack from types import TracebackType -from typing import TYPE_CHECKING, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -208,76 +209,85 @@ class LockStore(SQLBaseStore): used (otherwise the lock will leak). """ + try: + lock = await self.db_pool.runInteraction( + "try_acquire_read_write_lock", + self._try_acquire_read_write_lock_txn, + lock_name, + lock_key, + write, + ) + except self.database_engine.module.IntegrityError: + return None + + return lock + + def _try_acquire_read_write_lock_txn( + self, + txn: LoggingTransaction, + lock_name: str, + lock_key: str, + write: bool, + ) -> "Lock": + # We attempt to acquire the lock by inserting into + # `worker_read_write_locks` and seeing if that fails any + # constraints. If it doesn't then we have acquired the lock, + # otherwise we haven't. + # + # Before that though we clear the table of any stale locks. + now = self._clock.time_msec() token = random_string(6) - def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None: - # We attempt to acquire the lock by inserting into - # `worker_read_write_locks` and seeing if that fails any - # constraints. If it doesn't then we have acquired the lock, - # otherwise we haven't. - # - # Before that though we clear the table of any stale locks. - - delete_sql = """ - DELETE FROM worker_read_write_locks - WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?; - """ - - insert_sql = """ - INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts) - VALUES (?, ?, ?, ?, ?, ?) - """ - - if isinstance(self.database_engine, PostgresEngine): - # For Postgres we can send these queries at the same time. - txn.execute( - delete_sql + ";" + insert_sql, - ( - # DELETE args - now - _LOCK_TIMEOUT_MS, - lock_name, - lock_key, - # UPSERT args - lock_name, - lock_key, - write, - self._instance_name, - token, - now, - ), - ) - else: - # For SQLite these need to be two queries. - txn.execute( - delete_sql, - ( - now - _LOCK_TIMEOUT_MS, - lock_name, - lock_key, - ), - ) - txn.execute( - insert_sql, - ( - lock_name, - lock_key, - write, - self._instance_name, - token, - now, - ), - ) + delete_sql = """ + DELETE FROM worker_read_write_locks + WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?; + """ - return + insert_sql = """ + INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts) + VALUES (?, ?, ?, ?, ?, ?) + """ - try: - await self.db_pool.runInteraction( - "try_acquire_read_write_lock", - _try_acquire_read_write_lock_txn, + if isinstance(self.database_engine, PostgresEngine): + # For Postgres we can send these queries at the same time. + txn.execute( + delete_sql + ";" + insert_sql, + ( + # DELETE args + now - _LOCK_TIMEOUT_MS, + lock_name, + lock_key, + # UPSERT args + lock_name, + lock_key, + write, + self._instance_name, + token, + now, + ), + ) + else: + # For SQLite these need to be two queries. + txn.execute( + delete_sql, + ( + now - _LOCK_TIMEOUT_MS, + lock_name, + lock_key, + ), + ) + txn.execute( + insert_sql, + ( + lock_name, + lock_key, + write, + self._instance_name, + token, + now, + ), ) - except self.database_engine.module.IntegrityError: - return None lock = Lock( self._reactor, @@ -289,10 +299,58 @@ class LockStore(SQLBaseStore): token=token, ) - self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock + def set_lock() -> None: + self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock + + txn.call_after(set_lock) return lock + async def try_acquire_multi_read_write_lock( + self, + lock_names: Collection[Tuple[str, str]], + write: bool, + ) -> Optional[AsyncExitStack]: + """Try to acquire multiple locks for the given names/keys. Will return + an async context manager if the locks are successfully acquired, which + *must* be used (otherwise the lock will leak). + + If only a subset of the locks can be acquired then it will immediately + drop them and return `None`. + """ + try: + locks = await self.db_pool.runInteraction( + "try_acquire_multi_read_write_lock", + self._try_acquire_multi_read_write_lock_txn, + lock_names, + write, + ) + except self.database_engine.module.IntegrityError: + return None + + stack = AsyncExitStack() + + for lock in locks: + await stack.enter_async_context(lock) + + return stack + + def _try_acquire_multi_read_write_lock_txn( + self, + txn: LoggingTransaction, + lock_names: Collection[Tuple[str, str]], + write: bool, + ) -> Collection["Lock"]: + locks = [] + + for lock_name, lock_key in lock_names: + lock = self._try_acquire_read_write_lock_txn( + txn, lock_name, lock_key, write + ) + locks.append(lock) + + return locks + class Lock: """An async context manager that manages an acquired lock, ensuring it is diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 9773c1fcd2..b52f48cf04 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -249,12 +249,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # Mark all state and own events as outliers logger.info("[purge] marking remaining events as outliers") txn.execute( - "UPDATE events SET outlier = ?" + "UPDATE events SET outlier = TRUE" " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")" ) # synapse tries to take out an exclusive lock on room_depth whenever it diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e098ceea3c..c13c0bc7d7 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -560,19 +560,19 @@ class PushRuleStore(PushRulesWorkerStore): if isinstance(self.database_engine, PostgresEngine): sql = """ INSERT INTO push_rules_enable (id, user_name, rule_id, enabled) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, 1) ON CONFLICT DO NOTHING """ elif isinstance(self.database_engine, Sqlite3Engine): sql = """ INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, 1) """ else: raise RuntimeError("Unknown database engine") new_enable_id = self._push_rules_enable_id_gen.get_next() - txn.execute(sql, (new_enable_id, user_id, rule_id, 1)) + txn.execute(sql, (new_enable_id, user_id, rule_id)) async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 676d03bb7e..c582cf0573 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -454,9 +454,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) -> List[Tuple[str, int]]: sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" - " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + " WHERE email_sent = FALSE AND (expiration_ts_ms - ?) <= ?" ) - values = [False, now_ms, renew_at] + values = [now_ms, renew_at] txn.execute(sql, values) return cast(List[Tuple[str, int]], txn.fetchall()) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 830658f328..719e11aea6 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -936,11 +936,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): JOIN event_json USING (room_id, event_id) WHERE room_id = ? %(where_clause)s - AND contains_url = ? AND outlier = ? + AND contains_url = TRUE AND outlier = FALSE ORDER BY stream_ordering DESC LIMIT ? """ - txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) + txn.execute(sql % {"where_clause": ""}, (room_id, 100)) local_media_mxcs = [] remote_media_mxcs = [] @@ -976,7 +976,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn.execute( sql % {"where_clause": "AND stream_ordering < ?"}, - (room_id, next_token, True, False, 100), + (room_id, next_token, 100), ) return local_media_mxcs, remote_media_mxcs @@ -1086,9 +1086,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): # set quarantine if quarantined_by is not None: - sql += "AND safe_from_quarantine = ?" + sql += "AND safe_from_quarantine = FALSE" txn.executemany( - sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] + sql, [(quarantined_by, media_id) for media_id in local_mxcs] ) # remove from quarantine else: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 92cbe262a6..5a3611c415 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1401,7 +1401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - args = [False, room_id] + args: List[Any] = [room_id] order, from_bound, to_bound = generate_pagination_bounds( direction, from_token, to_token @@ -1475,7 +1475,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event.topological_ordering, event.stream_ordering FROM events AS event %(join_clause)s - WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s + WHERE event.outlier = FALSE AND event.room_id = ? AND %(bounds)s ORDER BY event.topological_ordering %(order)s, event.stream_ordering %(order)s LIMIT ? """ % { |