diff options
author | Mathieu Velten <mathieuv@matrix.org> | 2023-10-09 11:20:08 +0200 |
---|---|---|
committer | Mathieu Velten <mathieuv@matrix.org> | 2023-10-09 15:19:19 +0200 |
commit | 99fefd5501858ba90a10f35af757f424bc189f8c (patch) | |
tree | d86eb4f21a9d995bea57c79aabceeaf83f8c72ff /synapse | |
parent | argggggghhhh (diff) | |
parent | Fix possible AttributeError when account-api is called over unix socket (#16404) (diff) | |
download | synapse-99fefd5501858ba90a10f35af757f424bc189f8c.tar.xz |
Merge remote-tracking branch 'origin/develop' into anoa/public_rooms_module_api
Diffstat (limited to 'synapse')
287 files changed, 12875 insertions, 6638 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index b97ee59f15..4a9bbc4d57 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -21,12 +21,21 @@ import os import sys from typing import Any, Dict +from PIL import ImageFile + from synapse.util.rust import check_rust_lib_up_to_date from synapse.util.stringutils import strtobool +# Allow truncated JPEG images to be thumbnailed. +ImageFile.LOAD_TRUNCATED_IMAGES = True + # Check that we're not running on an unsupported Python version. -if sys.version_info < (3, 7): - print("Synapse requires Python 3.7 or above.") +# +# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the +# if-statement completely. +py_version = sys.version_info +if py_version < (3, 8): + print("Synapse requires Python 3.8 or above.") sys.exit(1) # Allow using the asyncio reactor via env var. @@ -78,7 +87,7 @@ try: except ImportError: pass -import synapse.util +import synapse.util # noqa: E402 __version__ = synapse.util.SYNAPSE_VERSION diff --git a/synapse/_pydantic_compat.py b/synapse/_pydantic_compat.py new file mode 100644 index 0000000000..ddff72afa1 --- /dev/null +++ b/synapse/_pydantic_compat.py @@ -0,0 +1,26 @@ +# Copyright 2023 Maxwell G <maxwell@gtmx.me> +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from packaging.version import Version + +try: + from pydantic import __version__ as pydantic_version +except ImportError: + import importlib.metadata + + pydantic_version = importlib.metadata.version("pydantic") + +HAS_PYDANTIC_V2: bool = Version(pydantic_version).major == 2 + +__all__ = ("HAS_PYDANTIC_V2",) diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 27fee3d9a9..ab2b29cf1b 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -61,6 +61,7 @@ from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpda from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore from synapse.storage.databases.main.event_push_actions import EventPushActionsStore from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, @@ -122,7 +123,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "rooms": ["is_public", "has_auth_chain_index"], - "users": ["shadow_banned", "approved"], + "users": ["shadow_banned", "approved", "locked"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], "per_user_experimental_features": ["enabled"], @@ -196,6 +197,11 @@ IGNORED_TABLES = { "ui_auth_sessions", "ui_auth_sessions_credentials", "ui_auth_sessions_ips", + # Ignore the worker locks table, as a) there shouldn't be any acquired locks + # after porting, and b) the circular foreign key constraints make it hard to + # port. + "worker_read_write_locks_mode", + "worker_read_write_locks", } @@ -239,6 +245,7 @@ class Store( PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, RelationsWorkerStore, + EventFederationWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) @@ -475,7 +482,10 @@ class Porter: do_backward[0] = False if forward_rows or backward_rows: - headers = [column[0] for column in txn.description] + assert txn.description is not None + headers: Optional[List[str]] = [ + column[0] for column in txn.description + ] else: headers = None @@ -537,6 +547,7 @@ class Porter: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() + assert txn.description is not None headers = [column[0] for column in txn.description] return headers, rows @@ -754,7 +765,7 @@ class Porter: # Step 2. Set up sequences # - # We do this before porting the tables so that event if we fail half + # We do this before porting the tables so that even if we fail half # way through the postgres DB always have sequences that are greater # than their respective tables. If we don't then creating the # `DataStore` object will fail due to the inconsistency. @@ -763,6 +774,10 @@ class Porter: await self._setup_user_id_seq() await self._setup_events_stream_seqs() await self._setup_sequence( + "un_partial_stated_event_stream_sequence", + ("un_partial_stated_event_stream",), + ) + await self._setup_sequence( "device_inbox_sequence", ("device_inbox", "device_federation_outbox") ) await self._setup_sequence( @@ -772,6 +787,11 @@ class Porter: await self._setup_sequence("receipts_sequence", ("receipts_linearized",)) await self._setup_sequence("presence_stream_sequence", ("presence_stream",)) await self._setup_auth_chain_sequence() + await self._setup_sequence( + "application_services_txn_id_seq", + ("application_services_txns",), + "txn_id", + ) # Step 3. Get tables. self.progress.set_state("Fetching tables") @@ -803,7 +823,9 @@ class Porter: ) # Map from table name to args passed to `handle_table`, i.e. a tuple # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`. - tables_to_port_info_map = {r[0]: r[1:] for r in setup_res} + tables_to_port_info_map = { + r[0]: r[1:] for r in setup_res if r[0] not in IGNORED_TABLES + } # Step 5. Do the copying. # @@ -901,7 +923,8 @@ class Porter: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select) rows = txn.fetchall() - headers: List[str] = [column[0] for column in txn.description] + assert txn.description is not None + headers = [column[0] for column in txn.description] ts_ind = headers.index("ts") @@ -1074,7 +1097,10 @@ class Porter: ) async def _setup_sequence( - self, sequence_name: str, stream_id_tables: Iterable[str] + self, + sequence_name: str, + stream_id_tables: Iterable[str], + column_name: str = "stream_id", ) -> None: """Set a sequence to the correct value.""" current_stream_ids = [] @@ -1084,7 +1110,7 @@ class Porter: await self.sqlite_store.db_pool.simple_select_one_onecol( table=stream_id_table, keyvalues={}, - retcol="COALESCE(MAX(stream_id), 1)", + retcol=f"COALESCE(MAX({column_name}), 1)", allow_none=True, ), ) @@ -1184,10 +1210,10 @@ class CursesProgress(Progress): self.total_processed = 0 self.total_remaining = 0 - super(CursesProgress, self).__init__() + super().__init__() def update(self, table: str, num_done: int) -> None: - super(CursesProgress, self).update(table, num_done) + super().update(table, num_done) self.total_processed = 0 self.total_remaining = 0 @@ -1283,7 +1309,7 @@ class TerminalProgress(Progress): """Just prints progress to the terminal""" def update(self, table: str, num_done: int) -> None: - super(TerminalProgress, self).update(table, num_done) + super().update(table, num_done) data = self.tables[table] @@ -1369,6 +1395,9 @@ def main() -> None: sys.stderr.write("Database must use the 'psycopg2' connector.\n") sys.exit(3) + # Don't run the background tasks that get started by the data stores. + hs_config["run_background_tasks_on"] = "some_other_process" + config = HomeServerConfig() config.parse_config_dict(hs_config, "", "") diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index 0adf94bba6..992ae43881 100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,7 +37,7 @@ class MockHomeserver(HomeServer): DATASTORE_CLASS = DataStore # type: ignore [assignment] def __init__(self, config: HomeServerConfig): - super(MockHomeserver, self).__init__( + super().__init__( hostname=config.server.server_name, config=config, reactor=reactor, diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py new file mode 100644 index 0000000000..bb3f50f2dd --- /dev/null +++ b/synapse/api/auth/__init__.py @@ -0,0 +1,176 @@ +# Copyright 2023 The Matrix.org Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple + +from typing_extensions import Protocol + +from twisted.web.server import Request + +from synapse.appservice import ApplicationService +from synapse.http.site import SynapseRequest +from synapse.types import Requester + +# guests always get this device id. +GUEST_DEVICE_ID = "guest_device" + + +class Auth(Protocol): + """The interface that an auth provider must implement.""" + + async def check_user_in_room( + self, + room_id: str, + requester: Requester, + allow_departed_users: bool = False, + ) -> Tuple[str, Optional[str]]: + """Check if the user is in the room, or was at some point. + Args: + room_id: The room to check. + + user_id: The user to check. + + current_state: Optional map of the current state of the room. + If provided then that map is used to check whether they are a + member of the room. Otherwise the current membership is + loaded from the database. + + allow_departed_users: if True, accept users that were previously + members but have now departed. + + Raises: + AuthError if the user is/was not in the room. + Returns: + The current membership of the user in the room and the + membership event ID of the user. + """ + + async def get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + allow_locked: bool = False, + ) -> Requester: + """Get a registered user's ID. + + Args: + request: An HTTP request with an access_token query parameter. + allow_guest: If False, will raise an AuthError if the user making the + request is a guest. + allow_expired: If True, allow the request through even if the account + is expired, or session token lifetime has ended. Note that + /login will deliver access tokens regardless of expiration. + + Returns: + Resolves to the requester + Raises: + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. + AuthError if access is denied for the user in the access token + """ + + async def validate_appservice_can_control_user_id( + self, app_service: ApplicationService, user_id: str + ) -> None: + """Validates that the app service is allowed to control + the given user. + + Args: + app_service: The app service that controls the user + user_id: The author MXID that the app service is controlling + + Raises: + AuthError: If the application service is not allowed to control the user + (user namespace regex does not match, wrong homeserver, etc) + or if the user has not been registered yet. + """ + + async def get_user_by_access_token( + self, + token: str, + allow_expired: bool = False, + ) -> Requester: + """Validate access token and get user_id from it + + Args: + token: The access token to get the user by + allow_expired: If False, raises an InvalidClientTokenError + if the token is expired + + Raises: + InvalidClientTokenError if a user by that token exists, but the token is + expired + InvalidClientCredentialsError if no user by that token exists or the token + is invalid + """ + + async def is_server_admin(self, requester: Requester) -> bool: + """Check if the given user is a local server admin. + + Args: + requester: user to check + + Returns: + True if the user is an admin + """ + + async def check_can_change_room_list( + self, room_id: str, requester: Requester + ) -> bool: + """Determine whether the user is allowed to edit the room's entry in the + published room list. + + Args: + room_id + user + """ + + @staticmethod + def has_access_token(request: Request) -> bool: + """Checks if the request has an access_token. + + Returns: + False if no access_token was given, True otherwise. + """ + + @staticmethod + def get_access_token_from_request(request: Request) -> str: + """Extracts the access_token from the request. + + Args: + request: The http request. + Returns: + The access_token + Raises: + MissingClientTokenError: If there isn't a single access_token in the + request + """ + + async def check_user_in_room_or_world_readable( + self, room_id: str, requester: Requester, allow_departed_users: bool = False + ) -> Tuple[str, Optional[str]]: + """Checks that the user is or was in the room or the room is world + readable. If it isn't then an exception is raised. + + Args: + room_id: room to check + user_id: user to check + allow_departed_users: if True, accept users that were previously + members but have now departed + + Returns: + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. + """ diff --git a/synapse/api/auth.py b/synapse/api/auth/base.py index 66e869bc2d..9321d6f186 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth/base.py @@ -1,4 +1,4 @@ -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2023 The Matrix.org Foundation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ import logging from typing import TYPE_CHECKING, Optional, Tuple -import pymacaroons from netaddr import IPAddress from twisted.web.server import Request @@ -24,19 +23,11 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.errors import ( AuthError, Codes, - InvalidClientTokenError, MissingClientTokenError, UnstableSpecAuthError, ) from synapse.appservice import ApplicationService -from synapse.http import get_request_user_agent -from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import ( - active_span, - force_tracing, - start_active_span, - trace, -) +from synapse.logging.opentracing import trace from synapse.types import Requester, create_requester from synapse.util.cancellation import cancellable @@ -46,26 +37,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# guests always get this device id. -GUEST_DEVICE_ID = "guest_device" - - -class Auth: - """ - This class contains functions for authenticating users of our client-server API. - """ +class BaseAuth: + """Common base class for all auth implementations.""" def __init__(self, hs: "HomeServer"): self.hs = hs - self.clock = hs.get_clock() self.store = hs.get_datastores().main - self._account_validity_handler = hs.get_account_validity_handler() self._storage_controllers = hs.get_storage_controllers() - self._macaroon_generator = hs.get_macaroon_generator() - - self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips - self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips - self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users async def check_user_in_room( self, @@ -119,139 +97,49 @@ class Auth: errcode=Codes.NOT_JOINED, ) - @cancellable - async def get_user_by_req( - self, - request: SynapseRequest, - allow_guest: bool = False, - allow_expired: bool = False, - ) -> Requester: - """Get a registered user's ID. + @trace + async def check_user_in_room_or_world_readable( + self, room_id: str, requester: Requester, allow_departed_users: bool = False + ) -> Tuple[str, Optional[str]]: + """Checks that the user is or was in the room or the room is world + readable. If it isn't then an exception is raised. Args: - request: An HTTP request with an access_token query parameter. - allow_guest: If False, will raise an AuthError if the user making the - request is a guest. - allow_expired: If True, allow the request through even if the account - is expired, or session token lifetime has ended. Note that - /login will deliver access tokens regardless of expiration. + room_id: room to check + user_id: user to check + allow_departed_users: if True, accept users that were previously + members but have now departed Returns: - Resolves to the requester - Raises: - InvalidClientCredentialsError if no user by that token exists or the token - is invalid. - AuthError if access is denied for the user in the access token + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. """ - parent_span = active_span() - with start_active_span("get_user_by_req"): - requester = await self._wrapped_get_user_by_req( - request, allow_guest, allow_expired - ) - - if parent_span: - if requester.authenticated_entity in self._force_tracing_for_users: - # request tracing is enabled for this user, so we need to force it - # tracing on for the parent span (which will be the servlet span). - # - # It's too late for the get_user_by_req span to inherit the setting, - # so we also force it on for that. - force_tracing() - force_tracing(parent_span) - parent_span.set_tag( - "authenticated_entity", requester.authenticated_entity - ) - parent_span.set_tag("user_id", requester.user.to_string()) - if requester.device_id is not None: - parent_span.set_tag("device_id", requester.device_id) - if requester.app_service is not None: - parent_span.set_tag("appservice_id", requester.app_service.id) - return requester - @cancellable - async def _wrapped_get_user_by_req( - self, - request: SynapseRequest, - allow_guest: bool, - allow_expired: bool, - ) -> Requester: - """Helper for get_user_by_req - - Once get_user_by_req has set up the opentracing span, this does the actual work. - """ try: - ip_addr = request.getClientAddress().host - user_agent = get_request_user_agent(request) - - access_token = self.get_access_token_from_request(request) - - # First check if it could be a request from an appservice - requester = await self._get_appservice_user(request) - if not requester: - # If not, it should be from a regular user - requester = await self.get_user_by_access_token( - access_token, allow_expired=allow_expired - ) - - # Deny the request if the user account has expired. - # This check is only done for regular users, not appservice ones. - if not allow_expired: - if await self._account_validity_handler.is_user_expired( - requester.user.to_string() - ): - # Raise the error if either an account validity module has determined - # the account has expired, or the legacy account validity - # implementation is enabled and determined the account has expired - raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, - ) - - if ip_addr and ( - not requester.app_service or self._track_appservice_user_ips + # check_user_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. + return await self.check_user_in_room( + room_id, requester, allow_departed_users=allow_departed_users + ) + except AuthError: + visibility = await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if ( + visibility + and visibility.content.get("history_visibility") + == HistoryVisibility.WORLD_READABLE ): - # XXX(quenting): I'm 95% confident that we could skip setting the - # device_id to "dummy-device" for appservices, and that the only impact - # would be some rows which whould not deduplicate in the 'user_ips' - # table during the transition - recorded_device_id = ( - "dummy-device" - if requester.device_id is None and requester.app_service is not None - else requester.device_id - ) - await self.store.insert_client_ip( - user_id=requester.authenticated_entity, - access_token=access_token, - ip=ip_addr, - user_agent=user_agent, - device_id=recorded_device_id, - ) - - # Track also the puppeted user client IP if enabled and the user is puppeting - if ( - requester.user.to_string() != requester.authenticated_entity - and self._track_puppeted_user_ips - ): - await self.store.insert_client_ip( - user_id=requester.user.to_string(), - access_token=access_token, - ip=ip_addr, - user_agent=user_agent, - device_id=requester.device_id, - ) - - if requester.is_guest and not allow_guest: - raise AuthError( - 403, - "Guest access not allowed", - errcode=Codes.GUEST_ACCESS_FORBIDDEN, - ) - - request.requester = requester - return requester - except KeyError: - raise MissingClientTokenError() + return Membership.JOIN, None + raise AuthError( + 403, + "User %r not in room %s, and room previews are disabled" + % (requester.user, room_id), + ) async def validate_appservice_can_control_user_id( self, app_service: ApplicationService, user_id: str @@ -284,184 +172,16 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) - @cancellable - async def _get_appservice_user(self, request: Request) -> Optional[Requester]: - """ - Given a request, reads the request parameters to determine: - - whether it's an application service that's making this request - - what user the application service should be treated as controlling - (the user_id URI parameter allows an application service to masquerade - any applicable user in its namespace) - - what device the application service should be treated as controlling - (the device_id[^1] URI parameter allows an application service to masquerade - as any device that exists for the relevant user) - - [^1] Unstable and provided by MSC3202. - Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. - - Returns: - the application service `Requester` of that request - - Postconditions: - - The `app_service` field in the returned `Requester` is set - - The `user_id` field in the returned `Requester` is either the application - service sender or the controlled user set by the `user_id` URI parameter - - The returned application service is permitted to control the returned user ID. - - The returned device ID, if present, has been checked to be a valid device ID - for the returned user ID. - """ - DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id" - - app_service = self.store.get_app_service_by_token( - self.get_access_token_from_request(request) - ) - if app_service is None: - return None - - if app_service.ip_range_whitelist: - ip_address = IPAddress(request.getClientAddress().host) - if ip_address not in app_service.ip_range_whitelist: - return None - - # This will always be set by the time Twisted calls us. - assert request.args is not None - - if b"user_id" in request.args: - effective_user_id = request.args[b"user_id"][0].decode("utf8") - await self.validate_appservice_can_control_user_id( - app_service, effective_user_id - ) - else: - effective_user_id = app_service.sender - - effective_device_id: Optional[str] = None - - if ( - self.hs.config.experimental.msc3202_device_masquerading_enabled - and DEVICE_ID_ARG_NAME in request.args - ): - effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8") - # We only just set this so it can't be None! - assert effective_device_id is not None - device_opt = await self.store.get_device( - effective_user_id, effective_device_id - ) - if device_opt is None: - # For now, use 400 M_EXCLUSIVE if the device doesn't exist. - # This is an open thread of discussion on MSC3202 as of 2021-12-09. - raise AuthError( - 400, - f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})", - Codes.EXCLUSIVE, - ) - - return create_requester( - effective_user_id, app_service=app_service, device_id=effective_device_id - ) - - async def get_user_by_access_token( - self, - token: str, - allow_expired: bool = False, - ) -> Requester: - """Validate access token and get user_id from it - - Args: - token: The access token to get the user by - allow_expired: If False, raises an InvalidClientTokenError - if the token is expired - - Raises: - InvalidClientTokenError if a user by that token exists, but the token is - expired - InvalidClientCredentialsError if no user by that token exists or the token - is invalid - """ - - # First look in the database to see if the access token is present - # as an opaque token. - user_info = await self.store.get_user_by_access_token(token) - if user_info: - valid_until_ms = user_info.valid_until_ms - if ( - not allow_expired - and valid_until_ms is not None - and valid_until_ms < self.clock.time_msec() - ): - # there was a valid access token, but it has expired. - # soft-logout the user. - raise InvalidClientTokenError( - msg="Access token has expired", soft_logout=True - ) - - # Mark the token as used. This is used to invalidate old refresh - # tokens after some time. - await self.store.mark_access_token_as_used(user_info.token_id) - - requester = create_requester( - user_id=user_info.user_id, - access_token_id=user_info.token_id, - is_guest=user_info.is_guest, - shadow_banned=user_info.shadow_banned, - device_id=user_info.device_id, - authenticated_entity=user_info.token_owner, - ) - - return requester - - # If the token isn't found in the database, then it could still be a - # macaroon for a guest, so we check that here. - try: - user_id = self._macaroon_generator.verify_guest_token(token) - - # Guest access tokens are not stored in the database (there can - # only be one access token per guest, anyway). - # - # In order to prevent guest access tokens being used as regular - # user access tokens (and hence getting around the invalidation - # process), we look up the user id and check that it is indeed - # a guest user. - # - # It would of course be much easier to store guest access - # tokens in the database as well, but that would break existing - # guest tokens. - stored_user = await self.store.get_user_by_id(user_id) - if not stored_user: - raise InvalidClientTokenError("Unknown user_id %s" % user_id) - if not stored_user["is_guest"]: - raise InvalidClientTokenError( - "Guest access token used for regular user" - ) - - return create_requester( - user_id=user_id, - is_guest=True, - # all guests get the same device id - device_id=GUEST_DEVICE_ID, - authenticated_entity=user_id, - ) - except ( - pymacaroons.exceptions.MacaroonException, - TypeError, - ValueError, - ) as e: - logger.warning( - "Invalid access token in auth: %s %s.", - type(e), - e, - ) - raise InvalidClientTokenError("Invalid access token passed.") - async def is_server_admin(self, requester: Requester) -> bool: """Check if the given user is a local server admin. Args: - requester: The user making the request, according to the access token. + requester: user to check Returns: True if the user is an admin """ - return await self.store.is_server_admin(requester.user) + raise NotImplementedError() async def check_can_change_room_list( self, room_id: str, requester: Requester @@ -470,8 +190,8 @@ class Auth: published room list. Args: - room_id: The room to check. - requester: The user making the request, according to the access token. + room_id + user """ is_admin = await self.is_server_admin(requester) @@ -518,7 +238,6 @@ class Auth: return bool(query_params) or bool(auth_headers) @staticmethod - @cancellable def get_access_token_from_request(request: Request) -> str: """Extracts the access_token from the request. @@ -556,47 +275,77 @@ class Auth: return query_params[0].decode("ascii") - @trace - async def check_user_in_room_or_world_readable( - self, room_id: str, requester: Requester, allow_departed_users: bool = False - ) -> Tuple[str, Optional[str]]: - """Checks that the user is or was in the room or the room is world - readable. If it isn't then an exception is raised. + @cancellable + async def get_appservice_user( + self, request: Request, access_token: str + ) -> Optional[Requester]: + """ + Given a request, reads the request parameters to determine: + - whether it's an application service that's making this request + - what user the application service should be treated as controlling + (the user_id URI parameter allows an application service to masquerade + any applicable user in its namespace) + - what device the application service should be treated as controlling + (the device_id[^1] URI parameter allows an application service to masquerade + as any device that exists for the relevant user) - Args: - room_id: The room to check. - requester: The user making the request, according to the access token. - allow_departed_users: If True, accept users that were previously - members but have now departed. + [^1] Unstable and provided by MSC3202. + Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. Returns: - Resolves to the current membership of the user in the room and the - membership event ID of the user. If the user is not in the room and - never has been, then `(Membership.JOIN, None)` is returned. + the application service `Requester` of that request + + Postconditions: + - The `app_service` field in the returned `Requester` is set + - The `user_id` field in the returned `Requester` is either the application + service sender or the controlled user set by the `user_id` URI parameter + - The returned application service is permitted to control the returned user ID. + - The returned device ID, if present, has been checked to be a valid device ID + for the returned user ID. """ + DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id" - try: - # check_user_in_room will return the most recent membership - # event for the user if: - # * The user is a non-guest user, and was ever in the room - # * The user is a guest user, and has joined the room - # else it will throw. - return await self.check_user_in_room( - room_id, requester, allow_departed_users=allow_departed_users - ) - except AuthError: - visibility = await self._storage_controllers.state.get_current_state_event( - room_id, EventTypes.RoomHistoryVisibility, "" + app_service = self.store.get_app_service_by_token(access_token) + if app_service is None: + return None + + if app_service.ip_range_whitelist: + ip_address = IPAddress(request.getClientAddress().host) + if ip_address not in app_service.ip_range_whitelist: + return None + + # This will always be set by the time Twisted calls us. + assert request.args is not None + + if b"user_id" in request.args: + effective_user_id = request.args[b"user_id"][0].decode("utf8") + await self.validate_appservice_can_control_user_id( + app_service, effective_user_id ) - if ( - visibility - and visibility.content.get("history_visibility") - == HistoryVisibility.WORLD_READABLE - ): - return Membership.JOIN, None - raise UnstableSpecAuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester.user, room_id), - errcode=Codes.NOT_JOINED, + else: + effective_user_id = app_service.sender + + effective_device_id: Optional[str] = None + + if ( + self.hs.config.experimental.msc3202_device_masquerading_enabled + and DEVICE_ID_ARG_NAME in request.args + ): + effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8") + # We only just set this so it can't be None! + assert effective_device_id is not None + device_opt = await self.store.get_device( + effective_user_id, effective_device_id ) + if device_opt is None: + # For now, use 400 M_EXCLUSIVE if the device doesn't exist. + # This is an open thread of discussion on MSC3202 as of 2021-12-09. + raise AuthError( + 400, + f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})", + Codes.EXCLUSIVE, + ) + + return create_requester( + effective_user_id, app_service=app_service, device_id=effective_device_id + ) diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py new file mode 100644 index 0000000000..36ee9c8b8f --- /dev/null +++ b/synapse/api/auth/internal.py @@ -0,0 +1,304 @@ +# Copyright 2023 The Matrix.org Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +import pymacaroons + +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientTokenError, + MissingClientTokenError, +) +from synapse.http import get_request_user_agent +from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import active_span, force_tracing, start_active_span +from synapse.types import Requester, create_requester +from synapse.util.cancellation import cancellable + +from . import GUEST_DEVICE_ID +from .base import BaseAuth + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class InternalAuth(BaseAuth): + """ + This class contains functions for authenticating users of our client-server API. + """ + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.clock = hs.get_clock() + self._account_validity_handler = hs.get_account_validity_handler() + self._macaroon_generator = hs.get_macaroon_generator() + + self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips + self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips + self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users + + @cancellable + async def get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + allow_locked: bool = False, + ) -> Requester: + """Get a registered user's ID. + + Args: + request: An HTTP request with an access_token query parameter. + allow_guest: If False, will raise an AuthError if the user making the + request is a guest. + allow_expired: If True, allow the request through even if the account + is expired, or session token lifetime has ended. Note that + /login will deliver access tokens regardless of expiration. + + Returns: + Resolves to the requester + Raises: + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. + AuthError if access is denied for the user in the access token + """ + parent_span = active_span() + with start_active_span("get_user_by_req"): + requester = await self._wrapped_get_user_by_req( + request, allow_guest, allow_expired, allow_locked + ) + + if parent_span: + if requester.authenticated_entity in self._force_tracing_for_users: + # request tracing is enabled for this user, so we need to force it + # tracing on for the parent span (which will be the servlet span). + # + # It's too late for the get_user_by_req span to inherit the setting, + # so we also force it on for that. + force_tracing() + force_tracing(parent_span) + parent_span.set_tag( + "authenticated_entity", requester.authenticated_entity + ) + parent_span.set_tag("user_id", requester.user.to_string()) + if requester.device_id is not None: + parent_span.set_tag("device_id", requester.device_id) + if requester.app_service is not None: + parent_span.set_tag("appservice_id", requester.app_service.id) + return requester + + @cancellable + async def _wrapped_get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool, + allow_expired: bool, + allow_locked: bool, + ) -> Requester: + """Helper for get_user_by_req + + Once get_user_by_req has set up the opentracing span, this does the actual work. + """ + try: + ip_addr = request.get_client_ip_if_available() + user_agent = get_request_user_agent(request) + + access_token = self.get_access_token_from_request(request) + + # First check if it could be a request from an appservice + requester = await self.get_appservice_user(request, access_token) + if not requester: + # If not, it should be from a regular user + requester = await self.get_user_by_access_token( + access_token, allow_expired=allow_expired + ) + + # Deny the request if the user account is locked. + if not allow_locked and await self.store.get_user_locked_status( + requester.user.to_string() + ): + raise AuthError( + 401, + "User account has been locked", + errcode=Codes.USER_LOCKED, + additional_fields={"soft_logout": True}, + ) + + # Deny the request if the user account has expired. + # This check is only done for regular users, not appservice ones. + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + requester.user.to_string() + ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired + raise AuthError( + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, + ) + + if ip_addr and ( + not requester.app_service or self._track_appservice_user_ips + ): + # XXX(quenting): I'm 95% confident that we could skip setting the + # device_id to "dummy-device" for appservices, and that the only impact + # would be some rows which whould not deduplicate in the 'user_ips' + # table during the transition + recorded_device_id = ( + "dummy-device" + if requester.device_id is None and requester.app_service is not None + else requester.device_id + ) + await self.store.insert_client_ip( + user_id=requester.authenticated_entity, + access_token=access_token, + ip=ip_addr, + user_agent=user_agent, + device_id=recorded_device_id, + ) + + # Track also the puppeted user client IP if enabled and the user is puppeting + if ( + requester.user.to_string() != requester.authenticated_entity + and self._track_puppeted_user_ips + ): + await self.store.insert_client_ip( + user_id=requester.user.to_string(), + access_token=access_token, + ip=ip_addr, + user_agent=user_agent, + device_id=requester.device_id, + ) + + if requester.is_guest and not allow_guest: + raise AuthError( + 403, + "Guest access not allowed", + errcode=Codes.GUEST_ACCESS_FORBIDDEN, + ) + + request.requester = requester + return requester + except KeyError: + raise MissingClientTokenError() + + async def get_user_by_access_token( + self, + token: str, + allow_expired: bool = False, + ) -> Requester: + """Validate access token and get user_id from it + + Args: + token: The access token to get the user by + allow_expired: If False, raises an InvalidClientTokenError + if the token is expired + + Raises: + InvalidClientTokenError if a user by that token exists, but the token is + expired + InvalidClientCredentialsError if no user by that token exists or the token + is invalid + """ + + # First look in the database to see if the access token is present + # as an opaque token. + user_info = await self.store.get_user_by_access_token(token) + if user_info: + valid_until_ms = user_info.valid_until_ms + if ( + not allow_expired + and valid_until_ms is not None + and valid_until_ms < self.clock.time_msec() + ): + # there was a valid access token, but it has expired. + # soft-logout the user. + raise InvalidClientTokenError( + msg="Access token has expired", soft_logout=True + ) + + # Mark the token as used. This is used to invalidate old refresh + # tokens after some time. + await self.store.mark_access_token_as_used(user_info.token_id) + + requester = create_requester( + user_id=user_info.user_id, + access_token_id=user_info.token_id, + is_guest=user_info.is_guest, + shadow_banned=user_info.shadow_banned, + device_id=user_info.device_id, + authenticated_entity=user_info.token_owner, + ) + + return requester + + # If the token isn't found in the database, then it could still be a + # macaroon for a guest, so we check that here. + try: + user_id = self._macaroon_generator.verify_guest_token(token) + + # Guest access tokens are not stored in the database (there can + # only be one access token per guest, anyway). + # + # In order to prevent guest access tokens being used as regular + # user access tokens (and hence getting around the invalidation + # process), we look up the user id and check that it is indeed + # a guest user. + # + # It would of course be much easier to store guest access + # tokens in the database as well, but that would break existing + # guest tokens. + stored_user = await self.store.get_user_by_id(user_id) + if not stored_user: + raise InvalidClientTokenError("Unknown user_id %s" % user_id) + if not stored_user.is_guest: + raise InvalidClientTokenError( + "Guest access token used for regular user" + ) + + return create_requester( + user_id=user_id, + is_guest=True, + # all guests get the same device id + device_id=GUEST_DEVICE_ID, + authenticated_entity=user_id, + ) + except ( + pymacaroons.exceptions.MacaroonException, + TypeError, + ValueError, + ) as e: + logger.warning( + "Invalid access token in auth: %s %s.", + type(e), + e, + ) + raise InvalidClientTokenError("Invalid access token passed.") + + async def is_server_admin(self, requester: Requester) -> bool: + """Check if the given user is a local server admin. + + Args: + requester: The user making the request, according to the access token. + + Returns: + True if the user is an admin + """ + return await self.store.is_server_admin(requester.user) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py new file mode 100644 index 0000000000..31bb035cc8 --- /dev/null +++ b/synapse/api/auth/msc3861_delegated.py @@ -0,0 +1,374 @@ +# Copyright 2023 The Matrix.org Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional +from urllib.parse import urlencode + +from authlib.oauth2 import ClientAuth +from authlib.oauth2.auth import encode_client_secret_basic, encode_client_secret_post +from authlib.oauth2.rfc7523 import ClientSecretJWT, PrivateKeyJWT, private_key_jwt_sign +from authlib.oauth2.rfc7662 import IntrospectionToken +from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url +from prometheus_client import Histogram + +from twisted.web.client import readBody +from twisted.web.http_headers import Headers + +from synapse.api.auth.base import BaseAuth +from synapse.api.errors import ( + AuthError, + HttpResponseException, + InvalidClientTokenError, + OAuthInsufficientScopeError, + StoreError, + SynapseError, +) +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.types import Requester, UserID, create_requester +from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +introspection_response_timer = Histogram( + "synapse_api_auth_delegated_introspection_response", + "Time taken to get a response for an introspection request", + ["code"], +) + + +# Scope as defined by MSC2967 +# https://github.com/matrix-org/matrix-spec-proposals/pull/2967 +SCOPE_MATRIX_API = "urn:matrix:org.matrix.msc2967.client:api:*" +SCOPE_MATRIX_GUEST = "urn:matrix:org.matrix.msc2967.client:api:guest" +SCOPE_MATRIX_DEVICE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:" + +# Scope which allows access to the Synapse admin API +SCOPE_SYNAPSE_ADMIN = "urn:synapse:admin:*" + + +def scope_to_list(scope: str) -> List[str]: + """Convert a scope string to a list of scope tokens""" + return scope.strip().split(" ") + + +class PrivateKeyJWTWithKid(PrivateKeyJWT): # type: ignore[misc] + """An implementation of the private_key_jwt client auth method that includes a kid header. + + This is needed because some providers (Keycloak) require the kid header to figure + out which key to use to verify the signature. + """ + + def sign(self, auth: Any, token_endpoint: str) -> bytes: + return private_key_jwt_sign( + auth.client_secret, + client_id=auth.client_id, + token_endpoint=token_endpoint, + claims=self.claims, + header={"kid": auth.client_secret["kid"]}, + ) + + +class MSC3861DelegatedAuth(BaseAuth): + AUTH_METHODS = { + "client_secret_post": encode_client_secret_post, + "client_secret_basic": encode_client_secret_basic, + "client_secret_jwt": ClientSecretJWT(), + "private_key_jwt": PrivateKeyJWTWithKid(), + } + + EXTERNAL_ID_PROVIDER = "oauth-delegated" + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._config = hs.config.experimental.msc3861 + auth_method = MSC3861DelegatedAuth.AUTH_METHODS.get( + self._config.client_auth_method.value, None + ) + # Those assertions are already checked when parsing the config + assert self._config.enabled, "OAuth delegation is not enabled" + assert self._config.issuer, "No issuer provided" + assert self._config.client_id, "No client_id provided" + assert auth_method is not None, "Invalid client_auth_method provided" + + self._clock = hs.get_clock() + self._http_client = hs.get_proxied_http_client() + self._hostname = hs.hostname + self._admin_token = self._config.admin_token + + self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + if isinstance(auth_method, PrivateKeyJWTWithKid): + # Use the JWK as the client secret when using the private_key_jwt method + assert self._config.jwk, "No JWK provided" + self._client_auth = ClientAuth( + self._config.client_id, self._config.jwk, auth_method + ) + else: + # Else use the client secret + assert self._config.client_secret, "No client_secret provided" + self._client_auth = ClientAuth( + self._config.client_id, self._config.client_secret, auth_method + ) + + async def _load_metadata(self) -> OpenIDProviderMetadata: + if self._config.issuer_metadata is not None: + return OpenIDProviderMetadata(**self._config.issuer_metadata) + url = get_well_known_url(self._config.issuer, external=True) + response = await self._http_client.get_json(url) + metadata = OpenIDProviderMetadata(**response) + # metadata.validate_introspection_endpoint() + return metadata + + async def _introspect_token(self, token: str) -> IntrospectionToken: + """ + Send a token to the introspection endpoint and returns the introspection response + + Parameters: + token: The token to introspect + + Raises: + HttpResponseException: If the introspection endpoint returns a non-2xx response + ValueError: If the introspection endpoint returns an invalid JSON response + JSONDecodeError: If the introspection endpoint returns a non-JSON response + Exception: If the HTTP request fails + + Returns: + The introspection response + """ + metadata = await self._issuer_metadata.get() + introspection_endpoint = metadata.get("introspection_endpoint") + raw_headers: Dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": str(self._http_client.user_agent, "utf-8"), + "Accept": "application/json", + } + + args = {"token": token, "token_type_hint": "access_token"} + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, raw_headers, body = self._client_auth.prepare( + method="POST", uri=introspection_endpoint, headers=raw_headers, body=body + ) + headers = Headers({k: [v] for (k, v) in raw_headers.items()}) + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code, and we do the body encoding ourselves. + + start_time = self._clock.time() + try: + response = await self._http_client.request( + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, + ) + + resp_body = await make_deferred_yieldable(readBody(response)) + except Exception: + end_time = self._clock.time() + introspection_response_timer.labels("ERR").observe(end_time - start_time) + raise + + end_time = self._clock.time() + introspection_response_timer.labels(response.code).observe( + end_time - start_time + ) + + if response.code < 200 or response.code >= 300: + raise HttpResponseException( + response.code, + response.phrase.decode("ascii", errors="replace"), + resp_body, + ) + + resp = json_decoder.decode(resp_body.decode("utf-8")) + + if not isinstance(resp, dict): + raise ValueError( + "The introspection endpoint returned an invalid JSON response." + ) + + return IntrospectionToken(**resp) + + async def is_server_admin(self, requester: Requester) -> bool: + return "urn:synapse:admin:*" in requester.scope + + async def get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + allow_locked: bool = False, + ) -> Requester: + access_token = self.get_access_token_from_request(request) + + requester = await self.get_appservice_user(request, access_token) + if not requester: + # TODO: we probably want to assert the allow_guest inside this call + # so that we don't provision the user if they don't have enough permission: + requester = await self.get_user_by_access_token(access_token, allow_expired) + + if not allow_guest and requester.is_guest: + raise OAuthInsufficientScopeError([SCOPE_MATRIX_API]) + + request.requester = requester + + return requester + + async def get_user_by_access_token( + self, + token: str, + allow_expired: bool = False, + ) -> Requester: + if self._admin_token is not None and token == self._admin_token: + # XXX: This is a temporary solution so that the admin API can be called by + # the OIDC provider. This will be removed once we have OIDC client + # credentials grant support in matrix-authentication-service. + logging.info("Admin toked used") + # XXX: that user doesn't exist and won't be provisioned. + # This is mostly fine for admin calls, but we should also think about doing + # requesters without a user_id. + admin_user = UserID("__oidc_admin", self._hostname) + return create_requester( + user_id=admin_user, + scope=["urn:synapse:admin:*"], + ) + + try: + introspection_result = await self._introspect_token(token) + except Exception: + logger.exception("Failed to introspect token") + raise SynapseError(503, "Unable to introspect the access token") + + logger.info(f"Introspection result: {introspection_result!r}") + + # TODO: introspection verification should be more extensive, especially: + # - verify the audience + if not introspection_result.get("active"): + raise InvalidClientTokenError("Token is not active") + + # Let's look at the scope + scope: List[str] = scope_to_list(introspection_result.get("scope", "")) + + # Determine type of user based on presence of particular scopes + has_user_scope = SCOPE_MATRIX_API in scope + has_guest_scope = SCOPE_MATRIX_GUEST in scope + + if not has_user_scope and not has_guest_scope: + raise InvalidClientTokenError("No scope in token granting user rights") + + # Match via the sub claim + sub: Optional[str] = introspection_result.get("sub") + if sub is None: + raise InvalidClientTokenError( + "Invalid sub claim in the introspection result" + ) + + user_id_str = await self.store.get_user_by_external_id( + MSC3861DelegatedAuth.EXTERNAL_ID_PROVIDER, sub + ) + if user_id_str is None: + # If we could not find a user via the external_id, it either does not exist, + # or the external_id was never recorded + + # TODO: claim mapping should be configurable + username: Optional[str] = introspection_result.get("username") + if username is None or not isinstance(username, str): + raise AuthError( + 500, + "Invalid username claim in the introspection result", + ) + user_id = UserID(username, self._hostname) + + # First try to find a user from the username claim + user_info = await self.store.get_user_by_id(user_id=user_id.to_string()) + if user_info is None: + # If the user does not exist, we should create it on the fly + # TODO: we could use SCIM to provision users ahead of time and listen + # for SCIM SET events if those ever become standard: + # https://datatracker.ietf.org/doc/html/draft-hunt-scim-notify-00 + + # TODO: claim mapping should be configurable + # If present, use the name claim as the displayname + name: Optional[str] = introspection_result.get("name") + + await self.store.register_user( + user_id=user_id.to_string(), create_profile_with_displayname=name + ) + + # And record the sub as external_id + await self.store.record_user_external_id( + MSC3861DelegatedAuth.EXTERNAL_ID_PROVIDER, sub, user_id.to_string() + ) + else: + user_id = UserID.from_string(user_id_str) + + # Find device_ids in scope + # We only allow a single device_id in the scope, so we find them all in the + # scope list, and raise if there are more than one. The OIDC server should be + # the one enforcing valid scopes, so we raise a 500 if we find an invalid scope. + device_ids = [ + tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :] + for tok in scope + if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX) + ] + + if len(device_ids) > 1: + raise AuthError( + 500, + "Multiple device IDs in scope", + ) + + device_id = device_ids[0] if device_ids else None + if device_id is not None: + # Sanity check the device_id + if len(device_id) > 255 or len(device_id) < 1: + raise AuthError( + 500, + "Invalid device ID in scope", + ) + + # Create the device on the fly if it does not exist + try: + await self.store.get_device( + user_id=user_id.to_string(), device_id=device_id + ) + except StoreError: + await self.store.store_device( + user_id=user_id.to_string(), + device_id=device_id, + initial_device_display_name="OIDC-native client", + ) + + # TODO: there is a few things missing in the requester here, which still need + # to be figured out, like: + # - impersonation, with the `authenticated_entity`, which is used for + # rate-limiting, MAU limits, etc. + # - shadow-banning, with the `shadow_banned` flag + # - a proper solution for appservices, which still needs to be figured out in + # the context of MSC3861 + return create_requester( + user_id=user_id, + device_id=device_id, + scope=scope, + is_guest=(has_guest_scope and not has_user_scope), + ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index cde9a2ecef..bf311b636d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -18,8 +18,7 @@ """Contains constants from the specification.""" import enum - -from typing_extensions import Final +from typing import Final # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -123,10 +122,6 @@ class EventTypes: SpaceChild: Final = "m.space.child" SpaceParent: Final = "m.space.parent" - MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion" - MSC2716_BATCH: Final = "org.matrix.msc2716.batch" - MSC2716_MARKER: Final = "org.matrix.msc2716.marker" - Reaction: Final = "m.reaction" @@ -222,21 +217,11 @@ class EventContentFields: # Used in m.room.guest_access events. GUEST_ACCESS: Final = "guest_access" - # Used on normal messages to indicate they were historically imported after the fact - MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical" - # For "insertion" events to indicate what the next batch ID should be in - # order to connect to it - MSC2716_NEXT_BATCH_ID: Final = "next_batch_id" - # Used on "batch" events to indicate which insertion event it connects to - MSC2716_BATCH_ID: Final = "batch_id" - # For "marker" events - MSC2716_INSERTION_EVENT_REFERENCE: Final = "insertion_event_reference" - # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" # Use for mentioning users. - MSC3952_MENTIONS: Final = "org.matrix.msc3952.mentions" + MENTIONS: Final = "m.mentions" # an unspecced field added to to-device messages to identify them uniquely-ish TO_DEVICE_MSGID: Final = "org.matrix.msgid" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 8c7c94b045..fdb2955be8 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -16,6 +16,7 @@ """Contains exceptions and error codes.""" import logging +import math import typing from enum import Enum from http import HTTPStatus @@ -80,6 +81,8 @@ class Codes(str, Enum): WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + # USER_LOCKED = "M_USER_LOCKED" + USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" # Part of MSC3848 # https://github.com/matrix-org/matrix-spec-proposals/pull/3848 @@ -119,14 +122,20 @@ class Codes(str, Enum): class CodeMessageException(RuntimeError): - """An exception with integer code and message string attributes. + """An exception with integer code, a message string attributes and optional headers. Attributes: code: HTTP error code msg: string describing the error + headers: optional response headers to send """ - def __init__(self, code: Union[int, HTTPStatus], msg: str): + def __init__( + self, + code: Union[int, HTTPStatus], + msg: str, + headers: Optional[Dict[str, str]] = None, + ): super().__init__("%d: %s" % (code, msg)) # Some calls to this method pass instances of http.HTTPStatus for `code`. @@ -137,6 +146,7 @@ class CodeMessageException(RuntimeError): # To eliminate this behaviour, we convert them to their integer equivalents here. self.code = int(code) self.msg = msg + self.headers = headers class RedirectException(CodeMessageException): @@ -182,6 +192,7 @@ class SynapseError(CodeMessageException): msg: str, errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, + headers: Optional[Dict[str, str]] = None, ): """Constructs a synapse error. @@ -190,7 +201,7 @@ class SynapseError(CodeMessageException): msg: The human-readable error message. errcode: The matrix error code e.g 'M_FORBIDDEN' """ - super().__init__(code, msg) + super().__init__(code, msg, headers) self.errcode = errcode if additional_fields is None: self._additional_fields: Dict = {} @@ -200,6 +211,11 @@ class SynapseError(CodeMessageException): def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) + @property + def debug_context(self) -> Optional[str]: + """Override this to add debugging context that shouldn't be sent to clients.""" + return None + class InvalidAPICallError(SynapseError): """You called an existing API endpoint, but fed that endpoint @@ -209,6 +225,13 @@ class InvalidAPICallError(SynapseError): super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) +class InvalidProxyCredentialsError(SynapseError): + """Error raised when the proxy credentials are invalid.""" + + def __init__(self, msg: str, errcode: str = Codes.UNKNOWN): + super().__init__(401, msg, errcode) + + class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. @@ -335,6 +358,20 @@ class AuthError(SynapseError): super().__init__(code, msg, errcode, additional_fields) +class OAuthInsufficientScopeError(SynapseError): + """An error raised when the caller does not have sufficient scope to perform the requested action""" + + def __init__( + self, + required_scopes: List[str], + ): + headers = { + "WWW-Authenticate": 'Bearer error="insufficient_scope", scope="%s"' + % (" ".join(required_scopes)) + } + super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None, headers) + + class UnstableSpecAuthError(AuthError): """An error raised when a new error code is being proposed to replace a previous one. This error will return a "org.matrix.unstable.errcode" property with the new error code, @@ -472,19 +509,31 @@ class InvalidCaptchaError(SynapseError): class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled.""" + include_retry_after_header = False + def __init__( self, + limiter_name: str, code: int = 429, - msg: str = "Too Many Requests", retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super().__init__(code, msg, errcode) + headers = ( + {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + if self.include_retry_after_header and retry_after_ms is not None + else None + ) + super().__init__(code, "Too Many Requests", errcode, headers=headers) self.retry_after_ms = retry_after_ms + self.limiter_name = limiter_name def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) + @property + def debug_context(self) -> Optional[str]: + return self.limiter_name + class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index de7c56bc0f..74ee8e9f3f 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.events import EventBase, relation_from_event -from synapse.types import JsonDict, RoomID, UserID +from synapse.types import JsonDict, JsonMapping, RoomID, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -128,20 +128,7 @@ USER_FILTER_SCHEMA = { "account_data": {"$ref": "#/definitions/filter"}, "room": {"$ref": "#/definitions/room_filter"}, "event_format": {"type": "string", "enum": ["client", "federation"]}, - "event_fields": { - "type": "array", - "items": { - "type": "string", - # Don't allow '\\' in event field filters. This makes matching - # events a lot easier as we can then use a negative lookbehind - # assertion to split '\.' If we allowed \\ then it would - # incorrectly split '\\.' See synapse.events.utils.serialize_event - # - # Note that because this is a regular expression, we have to escape - # each backslash in the pattern. - "pattern": r"^((?!\\\\).)*$", - }, - }, + "event_fields": {"type": "array", "items": {"type": "string"}}, }, "additionalProperties": True, # Allow new fields for forward compatibility } @@ -165,9 +152,9 @@ class Filtering: self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) async def get_user_filter( - self, user_localpart: str, filter_id: Union[int, str] + self, user_id: UserID, filter_id: Union[int, str] ) -> "FilterCollection": - result = await self.store.get_user_filter(user_localpart, filter_id) + result = await self.store.get_user_filter(user_id, filter_id) return FilterCollection(self._hs, result) def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]: @@ -204,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) class FilterCollection: - def __init__(self, hs: "HomeServer", filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) @@ -232,7 +219,7 @@ class FilterCollection: def __repr__(self) -> str: return "<FilterCollection %s>" % (json.dumps(self._filter_json),) - def get_filter_json(self) -> JsonDict: + def get_filter_json(self) -> JsonMapping: return self._filter_json def timeline_limit(self) -> int: @@ -326,7 +313,7 @@ class FilterCollection: class Filter: - def __init__(self, hs: "HomeServer", filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self._hs = hs self._store = hs.get_datastores().main self.filter_json = filter_json diff --git a/synapse/api/presence.py b/synapse/api/presence.py index b80aa83cb3..afef6712e1 100644 --- a/synapse/api/presence.py +++ b/synapse/api/presence.py @@ -20,18 +20,53 @@ from synapse.api.constants import PresenceState from synapse.types import JsonDict +@attr.s(slots=True, auto_attribs=True) +class UserDevicePresenceState: + """ + Represents the current presence state of a user's device. + + user_id: The user ID. + device_id: The user's device ID. + state: The presence state, see PresenceState. + last_active_ts: Time in msec that the device last interacted with server. + last_sync_ts: Time in msec that the device last *completed* a sync + (or event stream). + """ + + user_id: str + device_id: Optional[str] + state: str + last_active_ts: int + last_sync_ts: int + + @classmethod + def default( + cls, user_id: str, device_id: Optional[str] + ) -> "UserDevicePresenceState": + """Returns a default presence state.""" + return cls( + user_id=user_id, + device_id=device_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_sync_ts=0, + ) + + @attr.s(slots=True, frozen=True, auto_attribs=True) class UserPresenceState: """Represents the current presence state of the user. - user_id - last_active: Time in msec that the user last interacted with server. - last_federation_update: Time in msec since either a) we sent a presence + user_id: The user ID. + state: The presence state, see PresenceState. + last_active_ts: Time in msec that the user last interacted with server. + last_federation_update_ts: Time in msec since either a) we sent a presence update to other servers or b) we received a presence update, depending on if is a local user or not. - last_user_sync: Time in msec that the user last *completed* a sync + last_user_sync_ts: Time in msec that the user last *completed* a sync (or event stream). status_msg: User set status message. + currently_active: True if the user is currently syncing. """ user_id: str @@ -45,10 +80,6 @@ class UserPresenceState: def as_dict(self) -> JsonDict: return attr.asdict(self) - @staticmethod - def from_dict(d: JsonDict) -> "UserPresenceState": - return UserPresenceState(**d) - def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState": return attr.evolve(self, **kwargs) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 511790c7c5..02ae45e8b3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -40,7 +40,7 @@ class Ratelimiter: - the cost C of this request in tokens. Then, if there is room in the bucket for C tokens (T + C <= `burst_count`), the request is permitted and `cost` tokens are added to the bucket. - Otherwise the request is denied, and the bucket continues to hold T tokens. + Otherwise, the request is denied, and the bucket continues to hold T tokens. This means that the limiter enforces an average request frequency of `rate_hz`, while accumulating a buffer of up to `burst_count` requests which can be consumed @@ -55,18 +55,23 @@ class Ratelimiter: request. Args: + store: The datastore providing get_ratelimit_for_user. clock: A homeserver clock, for retrieving the current time - rate_hz: The long term number of actions that can be performed in a second. - burst_count: How many actions that can be performed before being limited. + cfg: The ratelimit configuration for this rate limiter including the + allowed rate and burst count. """ def __init__( - self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + self, + store: DataStore, + clock: Clock, + cfg: RatelimitSettings, ): self.clock = clock - self.rate_hz = rate_hz - self.burst_count = burst_count + self.rate_hz = cfg.per_second + self.burst_count = cfg.burst_count self.store = store + self._limiter_name = cfg.key # An ordered dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -305,7 +310,8 @@ class Ratelimiter: if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + limiter_name=self._limiter_name, + retry_after_ms=int(1000 * (time_allowed - time_now_s)), ) @@ -322,7 +328,9 @@ class RequestRatelimiter: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, + clock=self.clock, + cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0), ) self._rc_message = rc_message @@ -332,8 +340,7 @@ class RequestRatelimiter: self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=rc_admin_redaction.per_second, - burst_count=rc_admin_redaction.burst_count, + cfg=rc_admin_redaction, ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 7030b133d3..e7662d5b99 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -78,41 +78,29 @@ class RoomVersion: # MSC2209: Check 'notifications' key while verifying # m.room.power_levels auth rules. limit_notifications_power_levels: bool - # MSC2175: No longer include the creator in m.room.create events. - msc2175_implicit_room_creator: bool - # MSC2174/MSC2176: Apply updated redaction rules algorithm, move redacts to - # content property. - msc2176_redaction_rules: bool - # MSC3083: Support the 'restricted' join_rule. - msc3083_join_rules: bool - # MSC3375: Support for the proper redaction rules for MSC3083. This mustn't - # be enabled if MSC3083 is not. - msc3375_redaction_rules: bool - # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending - # m.room.membership event with membership 'knock'. - msc2403_knocking: bool - # MSC2716: Adds m.room.power_levels -> content.historical field to control - # whether "insertion", "chunk", "marker" events can be sent - msc2716_historical: bool - # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events - msc2716_redactions: bool + # No longer include the creator in m.room.create events. + implicit_room_creator: bool + # Apply updated redaction rules algorithm from room version 11. + updated_redaction_rules: bool + # Support the 'restricted' join rule. + restricted_join_rule: bool + # Support for the proper redaction rules for the restricted join rule. This requires + # restricted_join_rule to be enabled. + restricted_join_rule_fix: bool + # Support the 'knock' join rule. + knock_join_rule: bool # MSC3389: Protect relation information from redaction. msc3389_relation_redactions: bool - # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of - # knocks and restricted join rules into the same join condition. - msc3787_knock_restricted_join_rule: bool - # MSC3667: Enforce integer power levels - msc3667_int_only_power_levels: bool - # MSC3821: Do not redact the third_party_invite content field for membership events. - msc3821_redaction_rules: bool + # Support the 'knock_restricted' join rule. + knock_restricted_join_rule: bool + # Enforce integer power levels + enforce_int_power_levels: bool # MSC3931: Adds a push rule condition for "room version feature flags", making # some push rules room version dependent. Note that adding a flag to this list # is not enough to mark it "supported": the push rule evaluator also needs to # support the flag. Unknown flags are ignored by the evaluator, making conditions # fail if used. msc3931_push_features: Tuple[str, ...] # values from PushRuleRoomFlag - # MSC3989: Redact the origin field. - msc3989_redaction_rules: bool class RoomVersions: @@ -125,19 +113,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V2 = RoomVersion( "2", @@ -148,19 +132,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V3 = RoomVersion( "3", @@ -171,19 +151,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V4 = RoomVersion( "4", @@ -194,19 +170,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V5 = RoomVersion( "5", @@ -217,19 +189,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V6 = RoomVersion( "6", @@ -240,42 +208,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, - ) - MSC2176 = RoomVersion( - "org.matrix.msc2176", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=True, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc2716_historical=False, - msc2716_redactions=False, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, ) V7 = RoomVersion( "7", @@ -286,19 +227,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V8 = RoomVersion( "8", @@ -309,19 +246,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=False, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=False, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V9 = RoomVersion( "9", @@ -332,65 +265,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, - ) - MSC3787 = RoomVersion( - "org.matrix.msc3787", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, - ) - MSC3821 = RoomVersion( - "org.matrix.msc3821.opt1", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=True, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, ) V10 = RoomVersion( "10", @@ -401,42 +284,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(), - msc3989_redaction_rules=False, - ) - MSC2716v4 = RoomVersion( - "org.matrix.msc2716v4", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=True, - msc2716_historical=True, - msc2716_redactions=True, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, ) MSC1767v10 = RoomVersion( # MSC1767 (Extensible Events) based on room version "10" @@ -448,42 +304,34 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,), - msc3989_redaction_rules=False, ) - MSC3989 = RoomVersion( - "org.matrix.msc3989", - RoomDisposition.UNSTABLE, + V11 = RoomVersion( + "11", + RoomDisposition.STABLE, EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc2716_historical=False, - msc2716_redactions=False, + implicit_room_creator=True, # Used by MSC3820 + updated_redaction_rules=True, # Used by MSC3820 + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(), - msc3989_redaction_rules=True, ) @@ -496,14 +344,11 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V4, RoomVersions.V5, RoomVersions.V6, - RoomVersions.MSC2176, RoomVersions.V7, RoomVersions.V8, RoomVersions.V9, - RoomVersions.MSC3787, RoomVersions.V10, - RoomVersions.MSC2716v4, - RoomVersions.MSC3989, + RoomVersions.V11, ) } @@ -532,12 +377,12 @@ MSC3244_CAPABILITIES = { RoomVersionCapability( "knock", RoomVersions.V7, - lambda room_version: room_version.msc2403_knocking, + lambda room_version: room_version.knock_join_rule, ), RoomVersionCapability( "restricted", RoomVersions.V9, - lambda room_version: room_version.msc3083_join_rules, + lambda room_version: room_version.restricted_join_rule, ), ) } diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 936b1b0430..9ac7e4313e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -27,9 +27,7 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, - Iterable, List, NoReturn, Optional, @@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( load_legacy_third_party_event_rules, ) -from synapse.types import ISynapseReactor +from synapse.types import ISynapseReactor, StrCollection from synapse.util import SYNAPSE_VERSION from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process @@ -278,7 +276,7 @@ def register_start( reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: +def listen_metrics(bind_addresses: StrCollection, port: int) -> None: """ Start Prometheus metrics server. """ @@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: def listen_manhole( - bind_addresses: Collection[str], + bind_addresses: StrCollection, port: int, manhole_settings: ManholeConfig, manhole_globals: dict, @@ -339,7 +337,7 @@ def listen_manhole( def listen_tcp( - bind_addresses: Collection[str], + bind_addresses: StrCollection, port: int, factory: ServerFactory, reactor: IReactorTCP = reactor, @@ -386,6 +384,7 @@ def listen_unix( def listen_http( + hs: "HomeServer", listener_config: ListenerConfig, root_resource: Resource, version_string: str, @@ -406,6 +405,7 @@ def listen_http( version_string, max_request_body_size=max_request_body_size, reactor=reactor, + hs=hs, ) if isinstance(listener_config, TCPListenerConfig): @@ -446,7 +446,7 @@ def listen_http( def listen_ssl( - bind_addresses: Collection[str], + bind_addresses: StrCollection, port: int, factory: ServerFactory, context_factory: IOpenSSLContextFactory, diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index f9aada269a..aa24f7da6c 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,7 +17,7 @@ import logging import os import sys import tempfile -from typing import List, Mapping, Optional +from typing import List, Mapping, Optional, Sequence from twisted.internet import defer, task @@ -57,7 +57,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import JsonDict, StateMap +from synapse.types import JsonMapping, StateMap from synapse.util import SYNAPSE_VERSION from synapse.util.logcontext import LoggingContext @@ -198,7 +198,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): for event in state.values(): json.dump(event, fp=f) - def write_profile(self, profile: JsonDict) -> None: + def write_profile(self, profile: JsonMapping) -> None: user_directory = os.path.join(self.base_directory, "user_data") os.makedirs(user_directory, exist_ok=True) profile_file = os.path.join(user_directory, "profile") @@ -206,7 +206,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(profile_file, "a") as f: json.dump(profile, fp=f) - def write_devices(self, devices: List[JsonDict]) -> None: + def write_devices(self, devices: Sequence[JsonMapping]) -> None: user_directory = os.path.join(self.base_directory, "user_data") os.makedirs(user_directory, exist_ok=True) device_file = os.path.join(user_directory, "devices") @@ -215,7 +215,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(device_file, "a") as f: json.dump(device, fp=f) - def write_connections(self, connections: List[JsonDict]) -> None: + def write_connections(self, connections: Sequence[JsonMapping]) -> None: user_directory = os.path.join(self.base_directory, "user_data") os.makedirs(user_directory, exist_ok=True) connection_file = os.path.join(user_directory, "connections") @@ -225,7 +225,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): json.dump(connection, fp=f) def write_account_data( - self, file_name: str, account_data: Mapping[str, JsonDict] + self, file_name: str, account_data: Mapping[str, JsonMapping] ) -> None: account_data_directory = os.path.join( self.base_directory, "user_data", "account_data" @@ -237,7 +237,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(account_data_file, "a") as f: json.dump(account_data, fp=f) - def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None: file_directory = os.path.join(self.base_directory, "media_ids") os.makedirs(file_directory, exist_ok=True) media_id_file = os.path.join(file_directory, media_id) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 909ebccf78..f7c80eee21 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -77,13 +77,13 @@ from synapse.storage.databases.main.monthly_active_users import ( ) from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.purge_events import PurgeEventsStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore -from synapse.storage.databases.main.room_batch import RoomBatchStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore @@ -92,6 +92,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore +from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore @@ -120,7 +121,6 @@ class GenericWorkerStore( # the races it creates aren't too bad. KeyStore, RoomWorkerStore, - RoomBatchStore, DirectoryWorkerStore, PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, @@ -135,6 +135,7 @@ class GenericWorkerStore( RelationsWorkerStore, EventFederationWorkerStore, EventPushActionsWorkerStore, + PurgeEventsStore, StateGroupWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, @@ -146,6 +147,7 @@ class GenericWorkerStore( TransactionWorkerStore, LockStore, SessionStore, + TaskSchedulerWorkerStore, ): # Properties that multiple storage classes define. Tell mypy what the # expected type is. @@ -223,6 +225,7 @@ class GenericWorkerServer(HomeServer): root_resource = create_resource_tree(resources, OptionsResource()) _base.listen_http( + self, listener_config, root_resource, self.version_string, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 84236ac299..f188c7265a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -139,6 +139,7 @@ class SynapseHomeServer(HomeServer): root_resource = OptionsResource() ports = listen_http( + self, listener_config, create_resource_tree(resources, root_resource), self.version_string, diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 897dd3edac..09988670da 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -127,10 +127,6 @@ async def phone_stats_home( daily_sent_messages = await store.count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - r30_results = await store.count_r30_users() - for name, count in r30_results.items(): - stats["r30_users_" + name] = count - r30v2_results = await store.count_r30v2_users() for name, count in r30v2_results.items(): stats["r30v2_users_" + name] = count diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 35c330a3c4..6f4aa53c93 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -23,7 +23,7 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import DeviceListUpdates, JsonDict, UserID +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -86,6 +86,7 @@ class ApplicationService: url.rstrip("/") if isinstance(url, str) else None ) # url must not end with a slash self.hs_token = hs_token + # The full Matrix ID for this application service's sender. self.sender = sender self.namespaces = self._check_namespaces(namespaces) self.id = id @@ -212,7 +213,7 @@ class ApplicationService: True if the application service is interested in the user, False if not. """ return ( - # User is the appservice's sender_localpart user + # User is the appservice's configured sender_localpart user user_id == self.sender # User is in the appservice's user namespace or self.is_user_in_namespace(user_id) @@ -378,8 +379,8 @@ class AppServiceTransaction: service: ApplicationService, id: int, events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 5fb3d5083d..c42e1f11aa 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -16,9 +16,6 @@ import logging import urllib.parse from typing import ( TYPE_CHECKING, - Any, - Awaitable, - Callable, Dict, Iterable, List, @@ -27,10 +24,11 @@ from typing import ( Sequence, Tuple, TypeVar, + Union, ) from prometheus_client import Counter -from typing_extensions import Concatenate, ParamSpec, TypeGuard +from typing_extensions import ParamSpec, TypeGuard from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException, HttpResponseException @@ -42,7 +40,8 @@ from synapse.appservice import ( from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient, is_unknown_endpoint -from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID +from synapse.logging import opentracing +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -80,9 +79,7 @@ sent_todevice_counter = Counter( HOUR_IN_MS = 60 * 60 * 1000 - APP_SERVICE_PREFIX = "/_matrix/app/v1" -APP_SERVICE_UNSTABLE_PREFIX = "/_matrix/app/unstable" P = ParamSpec("P") R = TypeVar("R") @@ -123,51 +120,22 @@ class ApplicationServiceApi(SimpleHttpClient): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() + self.config = hs.config.appservice self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - async def _send_with_fallbacks( - self, - service: "ApplicationService", - prefixes: List[str], - path: str, - func: Callable[Concatenate[str, P], Awaitable[R]], - *args: P.args, - **kwargs: P.kwargs, - ) -> R: - """ - Attempt to call an application service with multiple paths, falling back - until one succeeds. + def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]: + """This makes sure we have always the auth header and opentracing headers set.""" - Args: - service: The appliacation service, this provides the base URL. - prefixes: A last of paths to try in order for the requests. - path: A suffix to append to each prefix. - func: The function to call, the first argument will be the full - endpoint to fetch. Other arguments are provided by args/kwargs. + # This is also ensured before in the functions. However this is needed to please + # the typechecks. + assert service.hs_token is not None - Returns: - The return value of func. - """ - for i, prefix in enumerate(prefixes, start=1): - uri = f"{service.url}{prefix}{path}" - try: - return await func(uri, *args, **kwargs) - except HttpResponseException as e: - # If an error is received that is due to an unrecognised path, - # fallback to next path (if one exists). Otherwise, consider it - # a legitimate error and raise. - if i < len(prefixes) and is_unknown_endpoint(e): - continue - raise - except Exception: - # Unexpected exceptions get sent to the caller. - raise - - # The function should always exit via the return or raise above this. - raise RuntimeError("Unexpected fallback behaviour. This should never be seen.") + headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]} + opentracing.inject_header_dict(headers, check_destination=False) + return headers async def query_user(self, service: "ApplicationService", user_id: str) -> bool: if service.url is None: @@ -177,13 +145,14 @@ class ApplicationServiceApi(SimpleHttpClient): assert service.hs_token is not None try: - response = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, ""], - f"/users/{urllib.parse.quote(user_id)}", - self.get_json, - {"access_token": service.hs_token}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + args = None + if self.config.use_appservice_legacy_authorization: + args = {"access_token": service.hs_token} + + response = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}", + args, + headers=self._get_headers(service), ) if response is not None: # just an empty json object return True @@ -203,13 +172,14 @@ class ApplicationServiceApi(SimpleHttpClient): assert service.hs_token is not None try: - response = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, ""], - f"/rooms/{urllib.parse.quote(alias)}", - self.get_json, - {"access_token": service.hs_token}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + args = None + if self.config.use_appservice_legacy_authorization: + args = {"access_token": service.hs_token} + + response = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}", + args, + headers=self._get_headers(service), ) if response is not None: # just an empty json object return True @@ -241,17 +211,17 @@ class ApplicationServiceApi(SimpleHttpClient): assert service.hs_token is not None try: - args: Mapping[Any, Any] = { - **fields, - b"access_token": service.hs_token, - } - response = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, APP_SERVICE_UNSTABLE_PREFIX], - f"/thirdparty/{kind}/{urllib.parse.quote(protocol)}", - self.get_json, + args: Mapping[bytes, Union[List[bytes], str]] = fields + if self.config.use_appservice_legacy_authorization: + args = { + **fields, + b"access_token": service.hs_token, + } + + response = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}", args=args, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) if not isinstance(response, list): logger.warning( @@ -285,13 +255,14 @@ class ApplicationServiceApi(SimpleHttpClient): # This is required by the configuration. assert service.hs_token is not None try: - info = await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, APP_SERVICE_UNSTABLE_PREFIX], - f"/thirdparty/protocol/{urllib.parse.quote(protocol)}", - self.get_json, - {"access_token": service.hs_token}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + args = None + if self.config.use_appservice_legacy_authorization: + args = {"access_token": service.hs_token} + + info = await self.get_json( + f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}", + args, + headers=self._get_headers(service), ) if not _is_valid_3pe_metadata(info): @@ -328,15 +299,15 @@ class ApplicationServiceApi(SimpleHttpClient): await self.post_json_get_json( uri=f"{service.url}{APP_SERVICE_PREFIX}/ping", post_json={"transaction_id": txn_id}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) async def push_bulk( self, service: "ApplicationService", events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, @@ -401,14 +372,15 @@ class ApplicationServiceApi(SimpleHttpClient): } try: - await self._send_with_fallbacks( - service, - [APP_SERVICE_PREFIX, ""], - f"/transactions/{urllib.parse.quote(str(txn_id))}", - self.put_json, + args = None + if self.config.use_appservice_legacy_authorization: + args = {"access_token": service.hs_token} + + await self.put_json( + f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}", json_body=body, - args={"access_token": service.hs_token}, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + args=args, + headers=self._get_headers(service), ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -481,7 +453,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.post_json_get_json( uri, body, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) except HttpResponseException as e: # The appservice doesn't support this endpoint. @@ -542,7 +514,7 @@ class ApplicationServiceApi(SimpleHttpClient): response = await self.post_json_get_json( uri, query, - headers={"Authorization": [f"Bearer {service.hs_token}"]}, + headers=self._get_headers(service), ) except HttpResponseException as e: # The appservice doesn't support this endpoint. diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 3a319b0d42..18a30bc376 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -73,7 +73,7 @@ from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import DeviceListUpdates, JsonDict +from synapse.types import DeviceListUpdates, JsonMapping from synapse.util import Clock if TYPE_CHECKING: @@ -121,8 +121,8 @@ class ApplicationServiceScheduler: self, appservice: ApplicationService, events: Optional[Collection[EventBase]] = None, - ephemeral: Optional[Collection[JsonDict]] = None, - to_device_messages: Optional[Collection[JsonDict]] = None, + ephemeral: Optional[Collection[JsonMapping]] = None, + to_device_messages: Optional[Collection[JsonMapping]] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ @@ -180,9 +180,9 @@ class _ServiceQueuer: # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} - self.queued_ephemeral: Dict[str, List[JsonDict]] = {} + self.queued_ephemeral: Dict[str, List[JsonMapping]] = {} # dict of {service_id: [to_device_message_json]} - self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {} # dict of {service_id: [device_list_summary]} self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} @@ -200,9 +200,7 @@ class _ServiceQueuer: if service.id in self.requests_in_flight: return - run_as_background_process( - "as-sender-%s" % (service.id,), self._send_request, service - ) + run_as_background_process("as-sender", self._send_request, service) async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender @@ -295,8 +293,8 @@ class _ServiceQueuer: self, service: ApplicationService, events: Iterable[EventBase], - ephemerals: Iterable[JsonDict], - to_device_messages: Iterable[JsonDict], + ephemerals: Iterable[JsonMapping], + to_device_messages: Iterable[JsonMapping], ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, @@ -366,8 +364,8 @@ class _TransactionController: self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: Optional[List[JsonDict]] = None, - to_device_messages: Optional[List[JsonDict]] = None, + ephemeral: Optional[List[JsonMapping]] = None, + to_device_messages: Optional[List[JsonMapping]] = None, one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, @@ -478,14 +476,11 @@ class _Recoverer: self.backoff_counter = 1 def recover(self) -> None: - def _retry() -> None: - run_as_background_process( - "as-recoverer-%s" % (self.service.id,), self.retry - ) - delay = 2**self.backoff_counter logger.info("Scheduling retries on %s in %fs", self.service.id, delay) - self.clock.call_later(delay, _retry) + self.clock.call_later( + delay, run_as_background_process, "as-recoverer", self.retry + ) def _backoff(self) -> None: # cap the backoff to be around 8.5min => (2^9) = 512 secs diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1d268a1817..c5816105f4 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -26,7 +26,6 @@ from textwrap import dedent from typing import ( Any, ClassVar, - Collection, Dict, Iterable, Iterator, @@ -179,17 +178,18 @@ class Config: If an integer is provided it is treated as bytes and is unchanged. - String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and - mebibytes respectively. No suffix is understood as a plain byte count. + String byte sizes can have a suffix of 'K', `M`, `G` or `T`, + representing kibibytes, mebibytes, gibibytes and tebibytes respectively. + No suffix is understood as a plain byte count. Raises: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: + if type(value) is int: # noqa: E721 return value - elif type(value) is str: - sizes = {"K": 1024, "M": 1024 * 1024} + elif isinstance(value, str): + sizes = {"K": 1024, "M": 1024 * 1024, "G": 1024**3, "T": 1024**4} size = 1 suffix = value[-1] if suffix in sizes: @@ -218,9 +218,9 @@ class Config: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: + if type(value) is int: # noqa: E721 return value - elif type(value) is str: + elif isinstance(value, str): second = 1000 minute = 60 * second hour = 60 * minute @@ -383,7 +383,7 @@ class RootConfig: config_classes: List[Type[Config]] = [] - def __init__(self, config_files: Collection[str] = ()): + def __init__(self, config_files: StrSequence = ()): # Capture absolute paths here, so we can reload config after we daemonize. self.config_files = [os.path.abspath(path) for path in config_files] diff --git a/synapse/config/_util.py b/synapse/config/_util.py index acccca413b..746838eee3 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -11,10 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar import jsonschema -from pydantic import BaseModel, ValidationError, parse_obj_as + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, ValidationError, parse_obj_as +else: + from pydantic import BaseModel, ValidationError, parse_obj_as from synapse.config._base import ConfigError from synapse.types import JsonDict, StrSequence diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index c2710fdf04..a70dfbf41f 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -34,7 +34,7 @@ class AppServiceConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.app_service_config_files = config.get("app_service_config_files", []) if not isinstance(self.app_service_config_files, list) or not all( - type(x) is str for x in self.app_service_config_files + isinstance(x, str) for x in self.app_service_config_files ): raise ConfigError( "Expected '%s' to be a list of AS config files:" @@ -43,6 +43,14 @@ class AppServiceConfig(Config): ) self.track_appservice_user_ips = config.get("track_appservice_user_ips", False) + self.use_appservice_legacy_authorization = config.get( + "use_appservice_legacy_authorization", False + ) + if self.use_appservice_legacy_authorization: + logger.warning( + "The use of appservice legacy authorization via query params is deprecated" + " and should be considered insecure." + ) def load_appservices( diff --git a/synapse/config/auth.py b/synapse/config/auth.py index 35774962c0..3b4c77f572 100644 --- a/synapse/config/auth.py +++ b/synapse/config/auth.py @@ -29,7 +29,14 @@ class AuthConfig(Config): if password_config is None: password_config = {} - passwords_enabled = password_config.get("enabled", True) + # The default value of password_config.enabled is True, unless msc3861 is enabled. + msc3861_enabled = ( + (config.get("experimental_features") or {}) + .get("msc3861", {}) + .get("enabled", False) + ) + passwords_enabled = password_config.get("enabled", not msc3861_enabled) + # 'only_for_reauth' allows users who have previously set a password to use it, # even though passwords would otherwise be disabled. passwords_for_reauth_only = passwords_enabled == "only_for_reauth" @@ -53,3 +60,13 @@ class AuthConfig(Config): self.ui_auth_session_timeout = self.parse_duration( ui_auth.get("session_timeout", 0) ) + + # Logging in with an existing session. + login_via_existing = config.get("login_via_existing_session", {}) + self.login_via_existing_enabled = login_via_existing.get("enabled", False) + self.login_via_existing_require_ui_auth = login_via_existing.get( + "require_ui_auth", True + ) + self.login_via_existing_token_timeout = self.parse_duration( + login_via_existing.get("token_timeout", "5m") + ) diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 9152c06bd6..bbc8f43073 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -18,7 +18,7 @@ from typing import Any, List from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict -from ._base import Config +from ._base import Config, ConfigError from ._util import validate_config @@ -41,17 +41,35 @@ class CasConfig(Config): public_baseurl = self.root.server.public_baseurl self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" + self.cas_protocol_version = cas_config.get("protocol_version") + if ( + self.cas_protocol_version is not None + and self.cas_protocol_version not in [1, 2, 3] + ): + raise ConfigError( + "Unsupported CAS protocol version %s (only versions 1, 2, 3 are supported)" + % (self.cas_protocol_version,), + ("cas_config", "protocol_version"), + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") required_attributes = cas_config.get("required_attributes") or {} self.cas_required_attributes = _parsed_required_attributes_def( required_attributes ) + self.cas_enable_registration = cas_config.get("enable_registration", True) + + self.idp_name = cas_config.get("idp_name", "CAS") + self.idp_icon = cas_config.get("idp_icon") + self.idp_brand = cas_config.get("idp_brand") + else: self.cas_server_url = None self.cas_service_url = None + self.cas_protocol_version = None self.cas_displayname_attribute = None self.cas_required_attributes = [] + self.cas_enable_registration = False # CAS uses a legacy required attributes mapping, not the one provided by diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d769b7f668..9f830e7094 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,15 +12,224 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +import enum +from typing import TYPE_CHECKING, Any, Optional import attr +import attr.validators +from synapse.api.errors import LimitExceededError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config import ConfigError -from synapse.config._base import Config +from synapse.config._base import Config, RootConfig from synapse.types import JsonDict +# Determine whether authlib is installed. +try: + import authlib # noqa: F401 + + HAS_AUTHLIB = True +except ImportError: + HAS_AUTHLIB = False + +if TYPE_CHECKING: + # Only import this if we're type checking, as it might not be installed at runtime. + from authlib.jose.rfc7517 import JsonWebKey + + +class ClientAuthMethod(enum.Enum): + """List of supported client auth methods.""" + + CLIENT_SECRET_POST = "client_secret_post" + CLIENT_SECRET_BASIC = "client_secret_basic" + CLIENT_SECRET_JWT = "client_secret_jwt" + PRIVATE_KEY_JWT = "private_key_jwt" + + +def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]: + """A helper function to parse a JWK dict into a JsonWebKey.""" + + if jwks is None: + return None + + from authlib.jose.rfc7517 import JsonWebKey + + return JsonWebKey.import_key(jwks) + + +@attr.s(slots=True, frozen=True) +class MSC3861: + """Configuration for MSC3861: Matrix architecture change to delegate authentication via OIDC""" + + enabled: bool = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + """Whether to enable MSC3861 auth delegation.""" + + @enabled.validator + def _check_enabled(self, attribute: attr.Attribute, value: bool) -> None: + # Only allow enabling MSC3861 if authlib is installed + if value and not HAS_AUTHLIB: + raise ConfigError( + "MSC3861 is enabled but authlib is not installed. " + "Please install authlib to use MSC3861.", + ("experimental", "msc3861", "enabled"), + ) + + issuer: str = attr.ib(default="", validator=attr.validators.instance_of(str)) + """The URL of the OIDC Provider.""" + + issuer_metadata: Optional[JsonDict] = attr.ib(default=None) + """The issuer metadata to use, otherwise discovered from /.well-known/openid-configuration as per MSC2965.""" + + client_id: str = attr.ib( + default="", + validator=attr.validators.instance_of(str), + ) + """The client ID to use when calling the introspection endpoint.""" + + client_auth_method: ClientAuthMethod = attr.ib( + default=ClientAuthMethod.CLIENT_SECRET_POST, converter=ClientAuthMethod + ) + """The auth method used when calling the introspection endpoint.""" + + client_secret: Optional[str] = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + """ + The client secret to use when calling the introspection endpoint, + when using any of the client_secret_* client auth methods. + """ + + jwk: Optional["JsonWebKey"] = attr.ib(default=None, converter=_parse_jwks) + """ + The JWKS to use when calling the introspection endpoint, + when using the private_key_jwt client auth method. + """ + + @client_auth_method.validator + def _check_client_auth_method( + self, attribute: attr.Attribute, value: ClientAuthMethod + ) -> None: + # Check that the right client credentials are provided for the client auth method. + if not self.enabled: + return + + if value == ClientAuthMethod.PRIVATE_KEY_JWT and self.jwk is None: + raise ConfigError( + "A JWKS must be provided when using the private_key_jwt client auth method", + ("experimental", "msc3861", "client_auth_method"), + ) + + if ( + value + in ( + ClientAuthMethod.CLIENT_SECRET_POST, + ClientAuthMethod.CLIENT_SECRET_BASIC, + ClientAuthMethod.CLIENT_SECRET_JWT, + ) + and self.client_secret is None + ): + raise ConfigError( + f"A client secret must be provided when using the {value} client auth method", + ("experimental", "msc3861", "client_auth_method"), + ) + + account_management_url: Optional[str] = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + """The URL of the My Account page on the OIDC Provider as per MSC2965.""" + + admin_token: Optional[str] = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + """ + A token that should be considered as an admin token. + This is used by the OIDC provider, to make admin calls to Synapse. + """ + + def check_config_conflicts(self, root: RootConfig) -> None: + """Checks for any configuration conflicts with other parts of Synapse. + + Raises: + ConfigError: If there are any configuration conflicts. + """ + + if not self.enabled: + return + + if ( + root.auth.password_enabled_for_reauth + or root.auth.password_enabled_for_login + ): + raise ConfigError( + "Password auth cannot be enabled when OAuth delegation is enabled", + ("password_config", "enabled"), + ) + + if root.registration.enable_registration: + raise ConfigError( + "Registration cannot be enabled when OAuth delegation is enabled", + ("enable_registration",), + ) + + # We only need to test the user consent version, as if it must be set if the user_consent section was present in the config + if root.consent.user_consent_version is not None: + raise ConfigError( + "User consent cannot be enabled when OAuth delegation is enabled", + ("user_consent",), + ) + + if ( + root.oidc.oidc_enabled + or root.saml2.saml2_enabled + or root.cas.cas_enabled + or root.jwt.jwt_enabled + ): + raise ConfigError("SSO cannot be enabled when OAuth delegation is enabled") + + if bool(root.authproviders.password_providers): + raise ConfigError( + "Password auth providers cannot be enabled when OAuth delegation is enabled" + ) + + if root.captcha.enable_registration_captcha: + raise ConfigError( + "CAPTCHA cannot be enabled when OAuth delegation is enabled", + ("captcha", "enable_registration_captcha"), + ) + + if root.auth.login_via_existing_enabled: + raise ConfigError( + "Login via existing session cannot be enabled when OAuth delegation is enabled", + ("login_via_existing_session", "enabled"), + ) + + if root.registration.refresh_token_lifetime: + raise ConfigError( + "refresh_token_lifetime cannot be set when OAuth delegation is enabled", + ("refresh_token_lifetime",), + ) + + if root.registration.nonrefreshable_access_token_lifetime: + raise ConfigError( + "nonrefreshable_access_token_lifetime cannot be set when OAuth delegation is enabled", + ("nonrefreshable_access_token_lifetime",), + ) + + if root.registration.session_lifetime: + raise ConfigError( + "session_lifetime cannot be set when OAuth delegation is enabled", + ("session_lifetime",), + ) + + if root.registration.enable_3pid_changes: + raise ConfigError( + "enable_3pid_changes cannot be enabled when OAuth delegation is enabled", + ("enable_3pid_changes",), + ) + @attr.s(auto_attribs=True, frozen=True, slots=True) class MSC3866Config: @@ -46,8 +255,26 @@ class ExperimentalConfig(Config): # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) - # MSC2716 (importing historical messages) - self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) + # MSC2697 (device dehydration) + # Enabled by default since this option was added after adding the feature. + # It is not recommended that both MSC2697 and MSC3814 both be enabled at + # once. + self.msc2697_enabled: bool = experimental.get("msc2697_enabled", True) + + # MSC3814 (dehydrated devices with SSSS) + # This is an alternative method to achieve the same goals as MSC2697. + # It is not recommended that both MSC2697 and MSC3814 both be enabled at + # once. + self.msc3814_enabled: bool = experimental.get("msc3814_enabled", False) + + if self.msc2697_enabled and self.msc3814_enabled: + raise ConfigError( + "MSC2697 and MSC3814 should not both be enabled.", + ( + "experimental_features", + "msc3814_enabled", + ), + ) # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) @@ -118,13 +345,6 @@ class ExperimentalConfig(Config): # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) - # MSC3882: Allow an existing session to sign in a new session - self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False) - self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True) - self.msc3882_token_timeout = self.parse_duration( - experimental.get("msc3882_token_timeout", "5m") - ) - # MSC3874: Filtering /messages with rel_types / not_rel_types. self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) @@ -164,16 +384,6 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3952: Intentional mentions, this depends on MSC3966. - self.msc3952_intentional_mentions = experimental.get( - "msc3952_intentional_mentions", False - ) - - # MSC3959: Do not generate notifications for edits. - self.msc3958_supress_edit_notifs = experimental.get( - "msc3958_supress_edit_notifs", False - ) - # MSC3967: Do not require UIA when first uploading cross signing keys self.msc3967_enabled = experimental.get("msc3967_enabled", False) @@ -182,13 +392,30 @@ class ExperimentalConfig(Config): "msc3981_recurse_relations", False ) - # MSC3970: Scope transaction IDs to devices - self.msc3970_enabled = experimental.get("msc3970_enabled", False) + # MSC3861: Matrix architecture change to delegate authentication via OIDC + try: + self.msc3861 = MSC3861(**experimental.get("msc3861", {})) + except ValueError as exc: + raise ConfigError( + "Invalid MSC3861 configuration", ("experimental", "msc3861") + ) from exc - # MSC4009: E.164 Matrix IDs - self.msc4009_e164_mxids = experimental.get("msc4009_e164_mxids", False) + # Check that none of the other config options conflict with MSC3861 when enabled + self.msc3861.check_config_conflicts(self.root) # MSC4010: Do not allow setting m.push_rules account data. self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) + + # MSC4041: Use HTTP header Retry-After to enable library-assisted retry handling + # + # This is a bit hacky, but the most reasonable way to *alway* include the + # headers. + LimitExceededError.include_retry_after_header = experimental.get( + "msc4041_enabled", False + ) + + self.msc4028_push_encrypted_events = experimental.get( + "msc4028_push_encrypted_events", False + ) diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 336fca578a..97636039b8 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -22,6 +22,8 @@ class FederationConfig(Config): section = "federation" def read_config(self, config: JsonDict, **kwargs: Any) -> None: + federation_config = config.setdefault("federation", {}) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist: Optional[dict] = None federation_domain_whitelist = config.get("federation_domain_whitelist", None) @@ -49,5 +51,37 @@ class FederationConfig(Config): "allow_device_name_lookup_over_federation", False ) + # Allow for the configuration of timeout, max request retries + # and min/max retry delays in the matrix federation client. + self.client_timeout_ms = Config.parse_duration( + federation_config.get("client_timeout", "60s") + ) + self.max_long_retry_delay_ms = Config.parse_duration( + federation_config.get("max_long_retry_delay", "60s") + ) + self.max_short_retry_delay_ms = Config.parse_duration( + federation_config.get("max_short_retry_delay", "2s") + ) + self.max_long_retries = federation_config.get("max_long_retries", 10) + self.max_short_retries = federation_config.get("max_short_retries", 3) + + # Allow for the configuration of the backoff algorithm used + # when trying to reach an unavailable destination. + # Unlike previous configuration those values applies across + # multiple requests and the state of the backoff is stored on DB. + self.destination_min_retry_interval_ms = Config.parse_duration( + federation_config.get("destination_min_retry_interval", "10m") + ) + self.destination_retry_multiplier = federation_config.get( + "destination_retry_multiplier", 2 + ) + self.destination_max_retry_interval_ms = min( + Config.parse_duration( + federation_config.get("destination_max_retry_interval", "7d") + ), + # Set a hard-limit to not overflow the database column. + 2**62, + ) + _METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}} diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 56db875b25..1e080133dc 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -117,9 +117,7 @@ root: # Write logs to the `buffer` handler, which will buffer them together in memory, # then write them to a file. # - # Replace "buffer" with "console" to log to stderr instead. (Note that you'll - # also need to update the configuration for the `twisted` logger above, in - # this case.) + # Replace "buffer" with "console" to log to stderr instead. # handlers: [buffer] diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py index d7959639ee..59bc0b55f4 100644 --- a/synapse/config/oembed.py +++ b/synapse/config/oembed.py @@ -30,7 +30,7 @@ class OEmbedEndpointConfig: # The API endpoint to fetch. api_endpoint: str # The patterns to match. - url_patterns: List[Pattern] + url_patterns: List[Pattern[str]] # The supported formats. formats: Optional[List[str]] diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 77c1d1dc8e..574d6afb95 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -280,6 +280,20 @@ def _parse_oidc_config_dict( for x in oidc_config.get("attribute_requirements", []) ] + # Read from either `client_secret_path` or `client_secret`. If both exist, error. + client_secret = oidc_config.get("client_secret") + client_secret_path = oidc_config.get("client_secret_path") + if client_secret_path is not None: + if client_secret is None: + client_secret = read_file( + client_secret_path, config_path + ("client_secret_path",) + ).rstrip("\n") + else: + raise ConfigError( + "Cannot specify both client_secret and client_secret_path", + config_path + ("client_secret",), + ) + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), @@ -288,7 +302,7 @@ def _parse_oidc_config_dict( discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], - client_secret=oidc_config.get("client_secret"), + client_secret=client_secret, client_secret_jwt_key=client_secret_jwt_key, client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), pkce_method=oidc_config.get("pkce_method", "auto"), diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index a5514e70a2..4efbaeac0d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast import attr @@ -21,16 +21,47 @@ from synapse.types import JsonDict from ._base import Config +@attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitSettings: - def __init__( - self, - config: Dict[str, float], + key: str + per_second: float + burst_count: int + + @classmethod + def parse( + cls, + config: Dict[str, Any], + key: str, defaults: Optional[Dict[str, float]] = None, - ): + ) -> "RatelimitSettings": + """Parse config[key] as a new-style rate limiter config. + + The key may refer to a nested dictionary using a full stop (.) to separate + each nested key. For example, use the key "a.b.c" to parse the following: + + a: + b: + c: + per_second: 10 + burst_count: 200 + + If this lookup fails, we'll fallback to the defaults. + """ defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} - self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = int(config.get("burst_count", defaults["burst_count"])) + rl_config = config + for part in key.split("."): + rl_config = rl_config.get(part, {}) + + # By this point we should have hit the rate limiter parameters. + # We don't actually check this though! + rl_config = cast(Dict[str, float], rl_config) + + return cls( + key=key, + per_second=rl_config.get("per_second", defaults["per_second"]), + burst_count=int(rl_config.get("burst_count", defaults["burst_count"])), + ) @attr.s(auto_attribs=True) @@ -49,15 +80,14 @@ class RatelimitConfig(Config): # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: - self.rc_message = RatelimitSettings( - config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} + self.rc_message = RatelimitSettings.parse( + config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0} ) else: self.rc_message = RatelimitSettings( - { - "per_second": config.get("rc_messages_per_second", 0.2), - "burst_count": config.get("rc_message_burst_count", 10.0), - } + key="rc_messages", + per_second=config.get("rc_messages_per_second", 0.2), + burst_count=config.get("rc_message_burst_count", 10.0), ) # Load the new-style federation config, if it exists. Otherwise, fall @@ -79,51 +109,59 @@ class RatelimitConfig(Config): } ) - self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) + self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {}) - self.rc_registration_token_validity = RatelimitSettings( - config.get("rc_registration_token_validity", {}), + self.rc_registration_token_validity = RatelimitSettings.parse( + config, + "rc_registration_token_validity", defaults={"per_second": 0.1, "burst_count": 5}, ) # It is reasonable to login with a bunch of devices at once (i.e. when # setting up an account), but it is *not* valid to continually be # logging into new devices. - rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings( - rc_login_config.get("address", {}), + self.rc_login_address = RatelimitSettings.parse( + config, + "rc_login.address", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_account = RatelimitSettings( - rc_login_config.get("account", {}), + self.rc_login_account = RatelimitSettings.parse( + config, + "rc_login.account", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_failed_attempts = RatelimitSettings( - rc_login_config.get("failed_attempts", {}) + self.rc_login_failed_attempts = RatelimitSettings.parse( + config, + "rc_login.failed_attempts", + {}, ) self.federation_rr_transactions_per_room_per_second = config.get( "federation_rr_transactions_per_room_per_second", 50 ) - rc_admin_redaction = config.get("rc_admin_redaction") self.rc_admin_redaction = None - if rc_admin_redaction: - self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) + if "rc_admin_redaction" in config: + self.rc_admin_redaction = RatelimitSettings.parse( + config, "rc_admin_redaction", {} + ) - self.rc_joins_local = RatelimitSettings( - config.get("rc_joins", {}).get("local", {}), + self.rc_joins_local = RatelimitSettings.parse( + config, + "rc_joins.local", defaults={"per_second": 0.1, "burst_count": 10}, ) - self.rc_joins_remote = RatelimitSettings( - config.get("rc_joins", {}).get("remote", {}), + self.rc_joins_remote = RatelimitSettings.parse( + config, + "rc_joins.remote", defaults={"per_second": 0.01, "burst_count": 10}, ) # Track the rate of joins to a given room. If there are too many, temporarily # prevent local joins and remote joins via this server. - self.rc_joins_per_room = RatelimitSettings( - config.get("rc_joins_per_room", {}), + self.rc_joins_per_room = RatelimitSettings.parse( + config, + "rc_joins_per_room", defaults={"per_second": 1, "burst_count": 10}, ) @@ -132,31 +170,37 @@ class RatelimitConfig(Config): # * For requests received over federation this is keyed by the origin. # # Note that this isn't exposed in the configuration as it is obscure. - self.rc_key_requests = RatelimitSettings( - config.get("rc_key_requests", {}), + self.rc_key_requests = RatelimitSettings.parse( + config, + "rc_key_requests", defaults={"per_second": 20, "burst_count": 100}, ) - self.rc_3pid_validation = RatelimitSettings( - config.get("rc_3pid_validation") or {}, + self.rc_3pid_validation = RatelimitSettings.parse( + config, + "rc_3pid_validation", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_room = RatelimitSettings( - config.get("rc_invites", {}).get("per_room", {}), + self.rc_invites_per_room = RatelimitSettings.parse( + config, + "rc_invites.per_room", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_invites_per_user = RatelimitSettings( - config.get("rc_invites", {}).get("per_user", {}), + self.rc_invites_per_user = RatelimitSettings.parse( + config, + "rc_invites.per_user", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_issuer = RatelimitSettings( - config.get("rc_invites", {}).get("per_issuer", {}), + self.rc_invites_per_issuer = RatelimitSettings.parse( + config, + "rc_invites.per_issuer", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_third_party_invite = RatelimitSettings( - config.get("rc_third_party_invite", {}), + self.rc_third_party_invite = RatelimitSettings.parse( + config, + "rc_third_party_invite", defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 636cb450b8..3c4c499e22 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -33,6 +33,7 @@ class RedisConfig(Config): self.redis_host = redis_config.get("host", "localhost") self.redis_port = redis_config.get("port", 6379) + self.redis_path = redis_config.get("path", None) self.redis_dbid = redis_config.get("dbid", None) self.redis_password = redis_config.get("password") diff --git a/synapse/config/registration.py b/synapse/config/registration.py index df1d83dfaa..b8ad6fbc06 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -133,7 +133,16 @@ class RegistrationConfig(Config): self.enable_set_displayname = config.get("enable_set_displayname", True) self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) - self.enable_3pid_changes = config.get("enable_3pid_changes", True) + + # The default value of enable_3pid_changes is True, unless msc3861 is enabled. + msc3861_enabled = ( + (config.get("experimental_features") or {}) + .get("msc3861", {}) + .get("enabled", False) + ) + self.enable_3pid_changes = config.get( + "enable_3pid_changes", not msc3861_enabled + ) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 49ca663dde..c69e24cf26 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -89,8 +89,14 @@ class SAML2Config(Config): "grandfathered_mxid_source_attribute", "uid" ) + # refers to a SAML IdP entity ID self.saml2_idp_entityid = saml2_config.get("idp_entityid", None) + # IdP properties for Matrix clients + self.idp_name = saml2_config.get("idp_name", "SAML") + self.idp_icon = saml2_config.get("idp_icon") + self.idp_brand = saml2_config.get("idp_brand") + # user_mapping_provider may be None if the key is present but has no value ump_dict = saml2_config.get("user_mapping_provider") or {} diff --git a/synapse/config/server.py b/synapse/config/server.py index b46fa51593..72d30da300 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -486,6 +486,17 @@ class ServerConfig(Config): else: self.redaction_retention_period = None + # How long to keep locally forgotten rooms before purging them from the DB. + forgotten_room_retention_period = config.get( + "forgotten_room_retention_period", None + ) + if forgotten_room_retention_period is not None: + self.forgotten_room_retention_period: Optional[int] = self.parse_duration( + forgotten_room_retention_period + ) + else: + self.forgotten_room_retention_period = None + # How long to keep entries in the `users_ips` table. user_ips_max_age = config.get("user_ips_max_age", "28d") if user_ips_max_age is not None: diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index c9e18b91e9..f60ec2ea66 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -35,3 +35,4 @@ class UserDirectoryConfig(Config): self.user_directory_search_prefer_local_users = user_directory_config.get( "prefer_local_users", False ) + self.show_locked_users = user_directory_config.get("show_locked_users", False) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index d2311cc857..f1766088fc 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -15,10 +15,16 @@ import argparse import logging -from typing import Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import attr -from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, Extra, StrictBool, StrictInt, StrictStr +else: + from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr from synapse.config._base import ( Config, @@ -41,11 +47,17 @@ Synapse version. Please use ``%s: name_of_worker`` instead. _MISSING_MAIN_PROCESS_INSTANCE_MAP_DATA = """ Missing data for a worker to connect to main process. Please include '%s' in the -`instance_map` declared in your shared yaml configuration, or optionally(as a deprecated -solution) in every worker's yaml as various `worker_replication_*` settings as defined -in workers documentation here: +`instance_map` declared in your shared yaml configuration as defined in configuration +documentation here: +`https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#instance_map` +""" + +WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE = """ +'%s' is no longer a supported worker setting, please place '%s' onto your shared +configuration under `main` inside the `instance_map`. See workers documentation here: `https://matrix-org.github.io/synapse/latest/workers.html#worker-configuration` """ + # This allows for a handy knob when it's time to change from 'master' to # something with less 'history' MAIN_PROCESS_INSTANCE_NAME = "master" @@ -88,7 +100,7 @@ class ConfigModel(BaseModel): allow_mutation = False -class InstanceLocationConfig(ConfigModel): +class InstanceTcpLocationConfig(ConfigModel): """The host and port to talk to an instance via HTTP replication.""" host: StrictStr @@ -104,6 +116,23 @@ class InstanceLocationConfig(ConfigModel): return f"{self.host}:{self.port}" +class InstanceUnixLocationConfig(ConfigModel): + """The socket file to talk to an instance via HTTP replication.""" + + path: StrictStr + + def scheme(self) -> str: + """Hardcode a retrievable scheme""" + return "unix" + + def netloc(self) -> str: + """Nicely format the address location data""" + return f"{self.path}" + + +InstanceLocationConfig = Union[InstanceTcpLocationConfig, InstanceUnixLocationConfig] + + @attr.s class WriterLocations: """Specifies the instances that write various streams. @@ -148,6 +177,27 @@ class WriterLocations: ) +@attr.s(auto_attribs=True) +class OutboundFederationRestrictedTo: + """Whether we limit outbound federation to a certain set of instances. + + Attributes: + instances: optional list of instances that can make outbound federation + requests. If None then all instances can make federation requests. + locations: list of instance locations to connect to proxy via. + """ + + instances: Optional[List[str]] + locations: List[InstanceLocationConfig] = attr.Factory(list) + + def __contains__(self, instance: str) -> bool: + # It feels a bit dirty to return `True` if `instances` is `None`, but it makes + # sense in downstream usage in the sense that if + # `outbound_federation_restricted_to` is not configured, then any instance can + # talk to federation (no restrictions so always return `True`). + return self.instances is None or instance in self.instances + + class WorkerConfig(Config): """The workers are processes run separately to the main synapse process. They have their own pid_file and listener configuration. They use the @@ -216,22 +266,37 @@ class WorkerConfig(Config): ) # A map from instance name to host/port of their HTTP replication endpoint. - # Check if the main process is declared. Inject it into the map if it's not, - # based first on if a 'main' block is declared then on 'worker_replication_*' - # data. If both are available, default to instance_map. The main process - # itself doesn't need this data as it would never have to talk to itself. + # Check if the main process is declared. The main process itself doesn't need + # this data as it would never have to talk to itself. instance_map: Dict[str, Any] = config.get("instance_map", {}) - if instance_map and self.instance_name is not MAIN_PROCESS_INSTANCE_NAME: + if self.instance_name is not MAIN_PROCESS_INSTANCE_NAME: + # TODO: The next 3 condition blocks can be deleted after some time has + # passed and we're ready to stop checking for these settings. # The host used to connect to the main synapse main_host = config.get("worker_replication_host", None) + if main_host: + raise ConfigError( + WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE + % ("worker_replication_host", main_host) + ) # The port on the main synapse for HTTP replication endpoint main_port = config.get("worker_replication_http_port") + if main_port: + raise ConfigError( + WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE + % ("worker_replication_http_port", main_port) + ) # The tls mode on the main synapse for HTTP replication endpoint. # For backward compatibility this defaults to False. main_tls = config.get("worker_replication_http_tls", False) + if main_tls: + raise ConfigError( + WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE + % ("worker_replication_http_tls", main_tls) + ) # For now, accept 'main' in the instance_map, but the replication system # expects 'master', force that into being until it's changed later. @@ -241,30 +306,20 @@ class WorkerConfig(Config): ] del instance_map[MAIN_PROCESS_INSTANCE_MAP_NAME] - # This is the backwards compatibility bit that handles the - # worker_replication_* bits using setdefault() to not overwrite anything. - elif main_host is not None and main_port is not None: - instance_map.setdefault( - MAIN_PROCESS_INSTANCE_NAME, - { - "host": main_host, - "port": main_port, - "tls": main_tls, - }, - ) - else: # If we've gotten here, it means that the main process is not on the - # instance_map and that not enough worker_replication_* variables - # were declared in the worker's yaml. + # instance_map. raise ConfigError( _MISSING_MAIN_PROCESS_INSTANCE_MAP_DATA % MAIN_PROCESS_INSTANCE_MAP_NAME ) + # type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently self.instance_map: Dict[ str, InstanceLocationConfig - ] = parse_and_validate_mapping(instance_map, InstanceLocationConfig) + ] = parse_and_validate_mapping( + instance_map, InstanceLocationConfig # type: ignore[arg-type] + ) # Map from type of streams to source, c.f. WriterLocations. writers = config.get("stream_writers") or {} @@ -357,6 +412,28 @@ class WorkerConfig(Config): new_option_name="update_user_directory_from_worker", ) + outbound_federation_restricted_to = config.get( + "outbound_federation_restricted_to", None + ) + self.outbound_federation_restricted_to = OutboundFederationRestrictedTo( + outbound_federation_restricted_to + ) + if outbound_federation_restricted_to: + if not self.worker_replication_secret: + raise ConfigError( + "`worker_replication_secret` must be configured when using `outbound_federation_restricted_to`." + ) + + for instance in outbound_federation_restricted_to: + if instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured in 'outbound_federation_restricted_to' but does not appear in `instance_map` config." + % (instance,) + ) + self.outbound_federation_restricted_to.locations.append( + self.instance_map[instance] + ) + def _should_this_worker_perform_duty( self, config: Dict[str, Any], diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 260aab3241..fe86f54d80 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -23,12 +23,7 @@ from signedjson.key import ( get_verify_key, is_signing_algorithm_supported, ) -from signedjson.sign import ( - SignatureVerifyException, - encode_canonical_json, - signature_ids, - verify_signed_json, -) +from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json from signedjson.types import VerifyKey from unpaddedbase64 import decode_base64 @@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher): verify_key=verify_key, valid_until_ts=key_data["expired_ts"] ) - key_json_bytes = encode_canonical_json(response_json) - - await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_added_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=key_json_bytes, - ) - for key_id in verify_keys - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + await self.store.store_server_keys_response( + server_name=server_name, + from_server=from_server, + ts_added_ms=time_added_ms, + verify_keys=verify_keys, + response_json=response_json, ) return verify_keys @@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): keys.setdefault(server_name, {}).update(processed_response) - await self.store.store_server_signature_keys( - perspective_name, time_now_ms, added_keys - ) - return keys def _validate_perspectives_response( diff --git a/synapse/event_auth.py b/synapse/event_auth.py index b4b43ec4d7..2ac9f8b309 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -126,7 +126,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: raise AuthError(403, "Event not signed by sending server") is_invite_via_allow_rule = ( - event.room_version.msc3083_join_rules + event.room_version.restricted_join_rule and event.type == EventTypes.Member and event.membership == Membership.JOIN and EventContentFields.AUTHORISING_USER in event.content @@ -339,13 +339,6 @@ def check_state_dependent_auth_rules( if event.type == EventTypes.Redaction: check_redaction(event.room_version, event, auth_dict) - if ( - event.type == EventTypes.MSC2716_INSERTION - or event.type == EventTypes.MSC2716_BATCH - or event.type == EventTypes.MSC2716_MARKER - ): - check_historical(event.room_version, event, auth_dict) - logger.debug("Allowing! %s", event) @@ -359,13 +352,10 @@ LENIENT_EVENT_BYTE_LIMITS_ROOM_VERSIONS = { RoomVersions.V4, RoomVersions.V5, RoomVersions.V6, - RoomVersions.MSC2176, RoomVersions.V7, RoomVersions.V8, RoomVersions.V9, - RoomVersions.MSC3787, RoomVersions.V10, - RoomVersions.MSC2716v4, RoomVersions.MSC1767v10, } @@ -457,7 +447,7 @@ def _check_create(event: "EventBase") -> None: # 1.4 If content has no creator field, reject if the room version requires it. if ( - not event.room_version.msc2175_implicit_room_creator + not event.room_version.implicit_room_creator and EventContentFields.ROOM_CREATOR not in event.content ): raise AuthError(403, "Create event lacks a 'creator' property") @@ -494,7 +484,7 @@ def _is_membership_change_allowed( key = (EventTypes.Create, "") create = auth_events.get(key) if create and event.prev_event_ids()[0] == create.event_id: - if room_version.msc2175_implicit_room_creator: + if room_version.implicit_room_creator: creator = create.sender else: creator = create.content[EventContentFields.ROOM_CREATOR] @@ -517,7 +507,7 @@ def _is_membership_change_allowed( caller_invited = caller and caller.membership == Membership.INVITE caller_knocked = ( caller - and room_version.msc2403_knocking + and room_version.knock_join_rule and caller.membership == Membership.KNOCK ) @@ -617,9 +607,9 @@ def _is_membership_change_allowed( elif join_rule == JoinRules.PUBLIC: pass elif ( - room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED + room_version.restricted_join_rule and join_rule == JoinRules.RESTRICTED ) or ( - room_version.msc3787_knock_restricted_join_rule + room_version.knock_restricted_join_rule and join_rule == JoinRules.KNOCK_RESTRICTED ): # This is the same as public, but the event must contain a reference @@ -649,9 +639,9 @@ def _is_membership_change_allowed( elif ( join_rule == JoinRules.INVITE - or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK) + or (room_version.knock_join_rule and join_rule == JoinRules.KNOCK) or ( - room_version.msc3787_knock_restricted_join_rule + room_version.knock_restricted_join_rule and join_rule == JoinRules.KNOCK_RESTRICTED ) ): @@ -679,15 +669,21 @@ def _is_membership_change_allowed( errcode=Codes.INSUFFICIENT_POWER, ) elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: + if user_level < ban_level: raise UnstableSpecAuthError( 403, "You don't have permission to ban", errcode=Codes.INSUFFICIENT_POWER, ) - elif room_version.msc2403_knocking and Membership.KNOCK == membership: + elif user_level <= target_level: + raise UnstableSpecAuthError( + 403, + "You don't have permission to ban this user", + errcode=Codes.INSUFFICIENT_POWER, + ) + elif room_version.knock_join_rule and Membership.KNOCK == membership: if join_rule != JoinRules.KNOCK and ( - not room_version.msc3787_knock_restricted_join_rule + not room_version.knock_restricted_join_rule or join_rule != JoinRules.KNOCK_RESTRICTED ): raise AuthError(403, "You don't have permission to knock") @@ -823,38 +819,6 @@ def check_redaction( raise AuthError(403, "You don't have permission to redact events") -def check_historical( - room_version_obj: RoomVersion, - event: "EventBase", - auth_events: StateMap["EventBase"], -) -> None: - """Check whether the event sender is allowed to send historical related - events like "insertion", "batch", and "marker". - - Returns: - None - - Raises: - AuthError if the event sender is not allowed to send historical related events - ("insertion", "batch", and "marker"). - """ - # Ignore the auth checks in room versions that do not support historical - # events - if not room_version_obj.msc2716_historical: - return - - user_level = get_user_power_level(event.user_id, auth_events) - - historical_level = get_named_level(auth_events, "historical", 100) - - if user_level < historical_level: - raise UnstableSpecAuthError( - 403, - 'You don\'t have permission to send send historical related events ("insertion", "batch", and "marker")', - errcode=Codes.INSUFFICIENT_POWER, - ) - - def _check_power_levels( room_version_obj: RoomVersion, event: "EventBase", @@ -876,7 +840,7 @@ def _check_power_levels( # Reject events with stringy power levels if required by room version if ( event.type == EventTypes.PowerLevels - and room_version_obj.msc3667_int_only_power_levels + and room_version_obj.enforce_int_power_levels ): for k, v in event.content.items(): if k in { @@ -888,11 +852,11 @@ def _check_power_levels( "kick", "invite", }: - if type(v) is not int: + if type(v) is not int: # noqa: E721 raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - type(v) is int for v in v.values() + type(v) is int for v in v.values() # noqa: E721 ): raise SynapseError( 400, @@ -1012,7 +976,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> in key = (EventTypes.Create, "") create_event = auth_events.get(key) if create_event is not None: - if create_event.room_version.msc2175_implicit_room_creator: + if create_event.room_version.implicit_room_creator: creator = create_event.sender else: creator = create_event.content[EventContentFields.ROOM_CREATOR] @@ -1150,7 +1114,7 @@ def auth_types_for_event( ) auth_types.add(key) - if room_version.msc3083_join_rules and membership == Membership.JOIN: + if room_version.restricted_join_rule and membership == Membership.JOIN: if EventContentFields.AUTHORISING_USER in event.content: key = ( EventTypes.Member, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index de7e5be42b..3c1777b7ec 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -25,7 +25,6 @@ from typing import ( Iterable, List, Optional, - Sequence, Tuple, Type, TypeVar, @@ -198,7 +197,6 @@ class _EventInternalMetadata: soft_failed: DictProperty[bool] = DictProperty("soft_failed") proactively_send: DictProperty[bool] = DictProperty("proactively_send") redacted: DictProperty[bool] = DictProperty("redacted") - historical: DictProperty[bool] = DictProperty("historical") txn_id: DictProperty[str] = DictProperty("txn_id") """The transaction ID, if it was set when the event was created.""" @@ -288,14 +286,6 @@ class _EventInternalMetadata: """ return self._dict.get("redacted", False) - def is_historical(self) -> bool: - """Whether this is a historical message. - This is used by the batchsend historical message endpoint and - is needed to and mark the event as backfilled and skip some checks - like push notifications. - """ - return self._dict.get("historical", False) - def is_notifiable(self) -> bool: """Whether this event can trigger a push notification""" return not self.is_outlier() or self.is_out_of_band_membership() @@ -355,7 +345,7 @@ class EventBase(metaclass=abc.ABCMeta): @property def redacts(self) -> Optional[str]: """MSC2176 moved the redacts field into the content.""" - if self.room_version.msc2176_redaction_rules: + if self.room_version.updated_redaction_rules: return self.content.get("redacts") return self.get("redacts") @@ -417,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta): def keys(self) -> Iterable[str]: return self._dict.keys() - def prev_event_ids(self) -> Sequence[str]: + def prev_event_ids(self) -> List[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. @@ -562,7 +552,7 @@ class FrozenEventV2(EventBase): self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) return self._event_id - def prev_event_ids(self) -> Sequence[str]: + def prev_event_ids(self) -> List[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a254548c6c..43469b170f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey @@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.types import EventID, JsonDict +from synapse.types import EventID, JsonDict, StrCollection from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.stringutils import random_string @@ -103,7 +103,7 @@ class EventBuilder: async def build( self, - prev_event_ids: Collection[str], + prev_event_ids: List[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: @@ -136,7 +136,7 @@ class EventBuilder: format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] + prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) @@ -175,7 +175,7 @@ class EventBuilder: # MSC2174 moves the redacts property to the content, it is invalid to # provide it as a top-level property. - if self._redacts is not None and not self.room_version.msc2176_redaction_rules: + if self._redacts is not None and not self.room_version.updated_redaction_rules: event_dict["redacts"] = self._redacts if self._origin_server_ts is not None: diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 9b4d692cf4..5bdfa3a8ac 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr from immutabledict import immutabledict from synapse.appservice import ApplicationService from synapse.events import EventBase +from synapse.logging.opentracing import tag_args, trace from synapse.types import JsonDict, StateMap if TYPE_CHECKING: @@ -54,7 +55,6 @@ class UnpersistedEventContextBase(ABC): A method to convert an UnpersistedEventContext to an EventContext, suitable for sending to the database with the associated event. """ - pass @abstractmethod async def get_prev_state_ids( @@ -68,7 +68,6 @@ class UnpersistedEventContextBase(ABC): state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules """ - pass @attr.s(slots=True, auto_attribs=True) @@ -106,33 +105,32 @@ class EventContext(UnpersistedEventContextBase): state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None then this is the delta of the state between the two groups. - prev_group: If it is known, ``state_group``'s prev_group. Note that this being - None does not necessarily mean that ``state_group`` does not have - a prev_group! + state_group_deltas: If not empty, this is a dict collecting a mapping of the state + difference between state groups. - If the event is a state event, this is normally the same as - ``state_group_before_event``. + The keys are a tuple of two integers: the initial group and final state group. + The corresponding value is a state map representing the state delta between + these state groups. - If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` - will always also be ``None``. + The dictionary is expected to have at most two entries with state groups of: - Note that this *not* (necessarily) the state group associated with - ``_prev_state_ids``. + 1. The state group before the event and after the event. + 2. The state group preceding the state group before the event and the + state group before the event. - delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` - and ``state_group``. + This information is collected and stored as part of an optimization for persisting + events. partial_state: if True, we may be storing this event with a temporary, incomplete state. """ _storage: "StorageControllers" + state_group_deltas: Dict[Tuple[int, int], StateMap[str]] rejected: Optional[str] = None _state_group: Optional[int] = None state_group_before_event: Optional[int] = None _state_delta_due_to_event: Optional[StateMap[str]] = None - prev_group: Optional[int] = None - delta_ids: Optional[StateMap[str]] = None app_service: Optional[ApplicationService] = None partial_state: bool = False @@ -144,16 +142,14 @@ class EventContext(UnpersistedEventContextBase): state_group_before_event: Optional[int], state_delta_due_to_event: Optional[StateMap[str]], partial_state: bool, - prev_group: Optional[int] = None, - delta_ids: Optional[StateMap[str]] = None, + state_group_deltas: Dict[Tuple[int, int], StateMap[str]], ) -> "EventContext": return EventContext( storage=storage, state_group=state_group, state_group_before_event=state_group_before_event, state_delta_due_to_event=state_delta_due_to_event, - prev_group=prev_group, - delta_ids=delta_ids, + state_group_deltas=state_group_deltas, partial_state=partial_state, ) @@ -162,7 +158,7 @@ class EventContext(UnpersistedEventContextBase): storage: "StorageControllers", ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" - return EventContext(storage=storage) + return EventContext(storage=storage, state_group_deltas={}) async def persist(self, event: EventBase) -> "EventContext": return self @@ -182,11 +178,10 @@ class EventContext(UnpersistedEventContextBase): "state_group": self._state_group, "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, - "prev_group": self.prev_group, + "state_group_deltas": _encode_state_group_delta(self.state_group_deltas), "state_delta_due_to_event": _encode_state_dict( self._state_delta_due_to_event ), - "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, "partial_state": self.partial_state, } @@ -203,17 +198,17 @@ class EventContext(UnpersistedEventContextBase): Returns: The event context. """ + context = EventContext( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. storage=storage, state_group=input["state_group"], state_group_before_event=input["state_group_before_event"], - prev_group=input["prev_group"], + state_group_deltas=_decode_state_group_delta(input["state_group_deltas"]), state_delta_due_to_event=_decode_state_dict( input["state_delta_due_to_event"] ), - delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], partial_state=input.get("partial_state", False), ) @@ -242,6 +237,8 @@ class EventContext(UnpersistedEventContextBase): return self._state_group + @trace + @tag_args async def get_current_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> Optional[StateMap[str]]: @@ -275,6 +272,8 @@ class EventContext(UnpersistedEventContextBase): return prev_state_ids + @trace + @tag_args async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: @@ -344,7 +343,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase): _storage: "StorageControllers" state_group_before_event: Optional[int] state_group_after_event: Optional[int] - state_delta_due_to_event: Optional[dict] + state_delta_due_to_event: Optional[StateMap[str]] prev_group_for_state_group_before_event: Optional[int] delta_ids_to_state_group_before_event: Optional[StateMap[str]] partial_state: bool @@ -375,26 +374,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase): events_and_persisted_context = [] for event, unpersisted_context in amended_events_and_context: - if event.is_state(): - context = EventContext( - storage=unpersisted_context._storage, - state_group=unpersisted_context.state_group_after_event, - state_group_before_event=unpersisted_context.state_group_before_event, - state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, - partial_state=unpersisted_context.partial_state, - prev_group=unpersisted_context.state_group_before_event, - delta_ids=unpersisted_context.state_delta_due_to_event, - ) - else: - context = EventContext( - storage=unpersisted_context._storage, - state_group=unpersisted_context.state_group_after_event, - state_group_before_event=unpersisted_context.state_group_before_event, - state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, - partial_state=unpersisted_context.partial_state, - prev_group=unpersisted_context.prev_group_for_state_group_before_event, - delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, - ) + state_group_deltas = unpersisted_context._build_state_group_deltas() + + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + state_group_deltas=state_group_deltas, + ) events_and_persisted_context.append((event, context)) return events_and_persisted_context @@ -447,11 +436,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase): # if the event isn't a state event the state group doesn't change if not self.state_delta_due_to_event: - state_group_after_event = self.state_group_before_event + self.state_group_after_event = self.state_group_before_event # otherwise if it is a state event we need to get a state group for it else: - state_group_after_event = await self._storage.state.store_state_group( + self.state_group_after_event = await self._storage.state.store_state_group( event.event_id, event.room_id, prev_group=self.state_group_before_event, @@ -459,16 +448,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase): current_state_ids=None, ) + state_group_deltas = self._build_state_group_deltas() + return EventContext.with_state( storage=self._storage, - state_group=state_group_after_event, + state_group=self.state_group_after_event, state_group_before_event=self.state_group_before_event, state_delta_due_to_event=self.state_delta_due_to_event, + state_group_deltas=state_group_deltas, partial_state=self.partial_state, - prev_group=self.state_group_before_event, - delta_ids=self.state_delta_due_to_event, ) + def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]: + """ + Collect deltas between the state groups associated with this context + """ + state_group_deltas = {} + + # if we know the state group before the event and after the event, add them and the + # state delta between them to state_group_deltas + if self.state_group_before_event and self.state_group_after_event: + # if we have the state groups we should have the delta + assert self.state_delta_due_to_event is not None + state_group_deltas[ + ( + self.state_group_before_event, + self.state_group_after_event, + ) + ] = self.state_delta_due_to_event + + # the state group before the event may also have a state group which precedes it, if + # we have that and the state group before the event, add them and the state + # delta between them to state_group_deltas + if ( + self.prev_group_for_state_group_before_event + and self.state_group_before_event + ): + # if we have both state groups we should have the delta between them + assert self.delta_ids_to_state_group_before_event is not None + state_group_deltas[ + ( + self.prev_group_for_state_group_before_event, + self.state_group_before_event, + ) + ] = self.delta_ids_to_state_group_before_event + + return state_group_deltas + + +def _encode_state_group_delta( + state_group_delta: Dict[Tuple[int, int], StateMap[str]] +) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]: + if not state_group_delta: + return [] + + state_group_delta_encoded = [] + for key, value in state_group_delta.items(): + state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value))) + + return state_group_delta_encoded + + +def _decode_state_group_delta( + input: List[Tuple[int, int, List[Tuple[str, str, str]]]] +) -> Dict[Tuple[int, int], StateMap[str]]: + if not input: + return {} + + state_group_deltas = {} + for state_group_1, state_group_2, state_dict in input: + state_map = _decode_state_dict(state_dict) + assert state_map is not None + state_group_deltas[(state_group_1, state_group_2)] = state_map + + return state_group_deltas + def _encode_state_dict( state_dict: Optional[StateMap[str]], diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e6d040176b..53af423a5a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -22,6 +22,7 @@ from typing import ( Iterable, List, Mapping, + Match, MutableMapping, Optional, Union, @@ -46,12 +47,10 @@ if TYPE_CHECKING: from synapse.handlers.relations import BundledAggregations -# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' -# (?<!stuff) matches if the current position in the string is not preceded -# by a match for 'stuff'. -# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as -# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" -SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.") +# Split strings on "." but not "\." (or "\\\."). +SPLIT_FIELD_REGEX = re.compile(r"\\*\.") +# Find escaped characters, e.g. those with a \ in front of them. +ESCAPE_SEQUENCE_PATTERN = re.compile(r"\\(.)") CANONICALJSON_MAX_INT = (2**53) - 1 CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT @@ -109,13 +108,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic "origin_server_ts", ] - # Room versions from before MSC2176 had additional allowed keys. - if not room_version.msc2176_redaction_rules: - allowed_keys.extend(["prev_state", "membership"]) - - # Room versions before MSC3989 kept the origin field. - if not room_version.msc3989_redaction_rules: - allowed_keys.append("origin") + # Earlier room versions from had additional allowed keys. + if not room_version.updated_redaction_rules: + allowed_keys.extend(["prev_state", "membership", "origin"]) event_type = event_dict["type"] @@ -128,9 +123,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic if event_type == EventTypes.Member: add_fields("membership") - if room_version.msc3375_redaction_rules: + if room_version.restricted_join_rule_fix: add_fields(EventContentFields.AUTHORISING_USER) - if room_version.msc3821_redaction_rules: + if room_version.updated_redaction_rules: # Preserve the signed field under third_party_invite. third_party_invite = event_dict["content"].get("third_party_invite") if isinstance(third_party_invite, collections.abc.Mapping): @@ -141,14 +136,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic ] elif event_type == EventTypes.Create: - # MSC2176 rules state that create events cannot be redacted. - if room_version.msc2176_redaction_rules: - return event_dict + if room_version.updated_redaction_rules: + # MSC2176 rules state that create events cannot have their `content` redacted. + new_content = event_dict["content"] + elif not room_version.implicit_room_creator: + # Some room versions give meaning to `creator` + add_fields("creator") - add_fields("creator") elif event_type == EventTypes.JoinRules: add_fields("join_rule") - if room_version.msc3083_join_rules: + if room_version.restricted_join_rule: add_fields("allow") elif event_type == EventTypes.PowerLevels: add_fields( @@ -162,24 +159,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic "redact", ) - if room_version.msc2176_redaction_rules: + if room_version.updated_redaction_rules: add_fields("invite") - if room_version.msc2716_historical: - add_fields("historical") - elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") - elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules: + elif event_type == EventTypes.Redaction and room_version.updated_redaction_rules: add_fields("redacts") - elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION: - add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID) - elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH: - add_fields(EventContentFields.MSC2716_BATCH_ID) - elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER: - add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE) # Protect the rel_type and event_id fields under the m.relates_to field. if room_version.msc3389_relation_redactions: @@ -253,6 +241,57 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None: sub_out_dict[key_to_move] = sub_dict[key_to_move] +def _escape_slash(m: Match[str]) -> str: + """ + Replacement function; replace a backslash-backslash or backslash-dot with the + second character. Leaves any other string alone. + """ + if m.group(1) in ("\\", "."): + return m.group(1) + return m.group(0) + + +def _split_field(field: str) -> List[str]: + """ + Splits strings on unescaped dots and removes escaping. + + Args: + field: A string representing a path to a field. + + Returns: + A list of nested fields to traverse. + """ + + # Convert the field and remove escaping: + # + # 1. "content.body.thing\.with\.dots" + # 2. ["content", "body", "thing\.with\.dots"] + # 3. ["content", "body", "thing.with.dots"] + + # Find all dots (and their preceding backslashes). If the dot is unescaped + # then emit a new field part. + result = [] + prev_start = 0 + for match in SPLIT_FIELD_REGEX.finditer(field): + # If the match is an *even* number of characters than the dot was escaped. + if len(match.group()) % 2 == 0: + continue + + # Add a new part (up to the dot, exclusive) after escaping. + result.append( + ESCAPE_SEQUENCE_PATTERN.sub( + _escape_slash, field[prev_start : match.end() - 1] + ) + ) + prev_start = match.end() + + # Add any part of the field after the last unescaped dot. (Note that if the + # character is a dot this correctly adds a blank string.) + result.append(re.sub(r"\\(.)", _escape_slash, field[prev_start:])) + + return result + + def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict: """Return a new dict with only the fields in 'dictionary' which are present in 'fields'. @@ -260,7 +299,7 @@ def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict: If there are no event fields specified then all fields are included. The entries may include '.' characters to indicate sub-fields. So ['content.body'] will include the 'body' field of the 'content' object. - A literal '.' character in a field name may be escaped using a '\'. + A literal '.' or '\' character in a field name may be escaped using a '\'. Args: dictionary: The dictionary to read from. @@ -275,13 +314,7 @@ def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict: # for each field, convert it: # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] - split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] - - # for each element of the output array of arrays: - # remove escaping so we can use the right key names. - split_fields[:] = [ - [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields - ] + split_fields = [_split_field(f) for f in fields] output: JsonDict = {} for field_array in split_fields: @@ -361,7 +394,6 @@ def serialize_event( time_now_ms: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, - msc3970_enabled: bool = False, ) -> JsonDict: """Serialize event for clients @@ -369,8 +401,6 @@ def serialize_event( e time_now_ms config: Event serialization config - msc3970_enabled: Whether MSC3970 is enabled. It changes whether we should - include the `transaction_id` in the event's `unsigned` section. Returns: The serialized event dictionary. @@ -396,38 +426,46 @@ def serialize_event( e.unsigned["redacted_because"], time_now_ms, config=config, - msc3970_enabled=msc3970_enabled, ) # If we have a txn_id saved in the internal_metadata, we should include it in the # unsigned section of the event if it was sent by the same session as the one # requesting the event. txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None) - if txn_id is not None and config.requester is not None: - # For the MSC3970 rules to be applied, we *need* to have the device ID in the - # event internal metadata. Since we were not recording them before, if it hasn't - # been recorded, we fallback to the old behaviour. + if ( + txn_id is not None + and config.requester is not None + and config.requester.user.to_string() == e.sender + ): + # Some events do not have the device ID stored in the internal metadata, + # this includes old events as well as those created by appservice, guests, + # or with tokens minted with the admin API. For those events, fallback + # to using the access token instead. event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None) - if msc3970_enabled and event_device_id is not None: + if event_device_id is not None: if event_device_id == config.requester.device_id: d["unsigned"]["transaction_id"] = txn_id else: - # The pre-MSC3970 behaviour is to only include the transaction ID if the - # event was sent from the same access token. For regular users, we can use - # the access token ID to determine this. For guests, we can't, but since - # each guest only has one access token, we can just check that the event was - # sent by the same user as the one requesting the event. + # Fallback behaviour: only include the transaction ID if the event + # was sent from the same access token. + # + # For regular users, the access token ID can be used to determine this. + # This includes access tokens minted with the admin API. + # + # For guests and appservice users, we can't check the access token ID + # so assume it is the same session. event_token_id: Optional[int] = getattr( e.internal_metadata, "token_id", None ) - if config.requester.user.to_string() == e.sender and ( + if ( ( event_token_id is not None and config.requester.access_token_id is not None and event_token_id == config.requester.access_token_id ) or config.requester.is_guest + or config.requester.app_service ): d["unsigned"]["transaction_id"] = txn_id @@ -442,6 +480,17 @@ def serialize_event( if config.as_client_event: d = config.event_format(d) + # If the event is a redaction, the field with the redacted event ID appears + # in a different location depending on the room version. e.redacts handles + # fetching from the proper location; copy it to the other location for forwards- + # and backwards-compatibility with clients. + if e.type == EventTypes.Redaction and e.redacts is not None: + if e.room_version.updated_redaction_rules: + d["redacts"] = e.redacts + else: + d["content"] = dict(d["content"]) + d["content"]["redacts"] = e.redacts + only_event_fields = config.only_event_fields if only_event_fields: if not isinstance(only_event_fields, list) or not all( @@ -460,9 +509,6 @@ class EventClientSerializer: clients. """ - def __init__(self, *, msc3970_enabled: bool = False): - self._msc3970_enabled = msc3970_enabled - def serialize_event( self, event: Union[JsonDict, EventBase], @@ -487,9 +533,7 @@ class EventClientSerializer: if not isinstance(event, EventBase): return event - serialized_event = serialize_event( - event, time_now, config=config, msc3970_enabled=self._msc3970_enabled - ) + serialized_event = serialize_event(event, time_now, config=config) # Check if there are any bundled aggregations to include with the event. if bundle_aggregations: @@ -658,7 +702,7 @@ def _copy_power_level_value_as_integer( :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. """ - if type(old_value) is int: + if type(old_value) is int: # noqa: E721 power_levels[key] = old_value return @@ -686,7 +730,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if type(value) is int: + if type(value) is int: # noqa: E721 if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 47203209db..83d9fb5813 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections.abc -from typing import Iterable, List, Type, Union, cast +from typing import TYPE_CHECKING, List, Type, Union, cast import jsonschema -from pydantic import Field, StrictBool, StrictStr + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Field, StrictBool, StrictStr +else: + from pydantic import Field, StrictBool, StrictStr from synapse.api.constants import ( MAX_ALIAS_LENGTH, @@ -33,10 +39,10 @@ from synapse.events.utils import ( CANONICALJSON_MIN_INT, validate_canonicaljson, ) -from synapse.federation.federation_server import server_matches_acl_event from synapse.http.servlet import validate_json_object from synapse.rest.models import RequestBodyModel -from synapse.types import EventID, JsonDict, RoomID, UserID +from synapse.storage.controllers.state import server_acl_evaluator_from_event +from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID class EventValidator: @@ -100,7 +106,10 @@ class EventValidator: self._validate_retention(event) elif event.type == EventTypes.ServerACL: - if not server_matches_acl_event(config.server.server_name, event): + server_acl_evaluator = server_acl_evaluator_from_event(event) + if not server_acl_evaluator.server_matches_acl_event( + config.server.server_name + ): raise SynapseError( 400, "Can't create an ACL event that denies the local server" ) @@ -134,13 +143,8 @@ class EventValidator: ) # If the event contains a mentions key, validate it. - if ( - EventContentFields.MSC3952_MENTIONS in event.content - and config.experimental.msc3952_intentional_mentions - ): - validate_json_object( - event.content[EventContentFields.MSC3952_MENTIONS], Mentions - ) + if EventContentFields.MENTIONS in event.content: + validate_json_object(event.content[EventContentFields.MENTIONS], Mentions) def _validate_retention(self, event: EventBase) -> None: """Checks that an event that defines the retention policy for a room respects the @@ -156,7 +160,7 @@ class EventValidator: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if type(min_lifetime) is not int: + if type(min_lifetime) is not int: # noqa: E721 raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -164,7 +168,7 @@ class EventValidator: ) if max_lifetime is not None: - if type(max_lifetime) is not int: + if type(max_lifetime) is not int: # noqa: E721 raise SynapseError( code=400, msg="'max_lifetime' must be an integer", @@ -230,7 +234,7 @@ class EventValidator: self._ensure_state_event(event) - def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: + def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None: for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index b77022b406..d4e7dd45a9 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -231,7 +231,7 @@ async def _check_sigs_on_pdu( # If this is a join event for a restricted room it may have been authorised # via a different server from the sending server. Check those signatures. if ( - room_version.msc3083_join_rules + room_version.restricted_join_rule and pdu.type == EventTypes.Member and pdu.membership == Membership.JOIN and EventContentFields.AUTHORISING_USER in pdu.content @@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if type(depth) is not int: + if type(depth) is not int: # noqa: E721 raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 076b9287c6..1a7fa175ec 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse from synapse.http.client import is_unknown_endpoint from synapse.http.types import QueryParams from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -236,6 +236,7 @@ class FederationClient(FederationBase): async def claim_client_keys( self, + user: UserID, destination: str, query: Dict[str, Dict[str, Dict[str, int]]], timeout: Optional[int], @@ -243,6 +244,7 @@ class FederationClient(FederationBase): """Claims one-time keys for a device hosted on a remote server. Args: + user: The user id of the requesting user destination: Domain name of the remote homeserver content: The query content. @@ -258,7 +260,9 @@ class FederationClient(FederationBase): use_unstable = False for user_id, one_time_keys in query.items(): for device_id, algorithms in one_time_keys.items(): - if any(count > 1 for count in algorithms.values()): + # If more than one algorithm is requested, attempt to use the unstable + # endpoint. + if sum(algorithms.values()) > 1: use_unstable = True if algorithms: # For the stable query, choose only the first algorithm. @@ -279,7 +283,7 @@ class FederationClient(FederationBase): if use_unstable: try: return await self.transport_layer.claim_client_keys_unstable( - destination, unstable_content, timeout + user, destination, unstable_content, timeout ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, @@ -294,8 +298,9 @@ class FederationClient(FederationBase): else: logger.debug("Skipping unstable claim client keys API") + # TODO Potentially attempt multiple queries and combine the results? return await self.transport_layer.claim_client_keys( - destination, content, timeout + user, destination, content, timeout ) @trace @@ -978,7 +983,7 @@ class FederationClient(FederationBase): if not room_version: raise UnsupportedRoomVersionError() - if not room_version.msc2403_knocking and membership == Membership.KNOCK: + if not room_version.knock_join_rule and membership == Membership.KNOCK: raise SynapseError( 400, "This room version does not support knocking", @@ -1064,7 +1069,7 @@ class FederationClient(FederationBase): # * Ensure the signatures are good. # # Otherwise, fallback to the provided event. - if room_version.msc3083_join_rules and response.event: + if room_version.restricted_join_rule and response.event: event = response.event valid_pdu = await self._check_sigs_and_hash_and_fetch_one( @@ -1190,7 +1195,7 @@ class FederationClient(FederationBase): # MSC3083 defines additional error codes for room joins. failover_errcodes = None - if room_version.msc3083_join_rules: + if room_version.restricted_join_rule: failover_errcodes = ( Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN, @@ -1397,7 +1402,7 @@ class FederationClient(FederationBase): The remote homeserver return some state from the room. The response dictionary is in the form: - {"knock_state_events": [<state event dict>, ...]} + {"knock_room_state": [<state event dict>, ...]} The list of state events may be empty. @@ -1424,7 +1429,7 @@ class FederationClient(FederationBase): The remote homeserver can optionally return some state from the room. The response dictionary is in the form: - {"knock_state_events": [<state event dict>, ...]} + {"knock_room_state": [<state event dict>, ...]} The list of state events may be empty. """ @@ -1699,7 +1704,7 @@ class FederationClient(FederationBase): async def timestamp_to_event( self, *, - destinations: List[str], + destinations: StrCollection, room_id: str, timestamp: int, direction: Direction, @@ -1886,7 +1891,7 @@ class TimestampToEventResponse: ) origin_server_ts = d.get("origin_server_ts") - if type(origin_server_ts) is not int: + if type(origin_server_ts) is not int: # noqa: E721 raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index f4ca70a698..6ac8d16095 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -29,10 +29,8 @@ from typing import ( Union, ) -from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram -from twisted.internet.abstract import isIPAddress from twisted.python import failure from synapse.api.constants import ( @@ -63,6 +61,7 @@ from synapse.federation.federation_base import ( ) from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -137,6 +136,7 @@ class FederationServer(FederationBase): self._event_auth_handler = hs.get_event_auth_handler() self._room_member_handler = hs.get_room_member_handler() self._e2e_keys_handler = hs.get_e2e_keys_handler() + self._worker_lock_handler = hs.get_worker_locks_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -515,7 +515,7 @@ class FederationServer(FederationBase): logger.error( "Failed to handle PDU %s", event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), ) return {"error": str(e)} @@ -806,7 +806,7 @@ class FederationServer(FederationBase): raise IncompatibleRoomVersionError(room_version=room_version.identifier) # Check that this room supports knocking as defined by its room version - if not room_version.msc2403_knocking: + if not room_version.knock_join_rule: raise SynapseError( 403, "This room version does not support knocking", @@ -850,14 +850,7 @@ class FederationServer(FederationBase): context, self._room_prejoin_state_types ) ) - return { - "knock_room_state": stripped_room_state, - # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. - # Thus, we also populate a 'knock_state_events' with the same content to - # support old instances. - # See https://github.com/matrix-org/synapse/issues/14088. - "knock_state_events": stripped_room_state, - } + return {"knock_room_state": stripped_room_state} async def _on_send_membership_event( self, origin: str, content: JsonDict, membership_type: str, room_id: str @@ -909,7 +902,7 @@ class FederationServer(FederationBase): errcode=Codes.NOT_FOUND, ) - if membership_type == Membership.KNOCK and not room_version.msc2403_knocking: + if membership_type == Membership.KNOCK and not room_version.knock_join_rule: raise SynapseError( 403, "This room version does not support knocking", @@ -933,7 +926,7 @@ class FederationServer(FederationBase): # the event is valid to be sent into the room. Currently this is only done # if the user is being joined via restricted join rules. if ( - room_version.msc3083_join_rules + room_version.restricted_join_rule and event.membership == Membership.JOIN and EventContentFields.AUTHORISING_USER in event.content ): @@ -944,7 +937,7 @@ class FederationServer(FederationBase): if not self._is_mine_server_name(authorising_server): raise SynapseError( 400, - f"Cannot authorise request from resident server: {authorising_server}", + f"Cannot authorise membership event for {authorising_server}. We can only authorise requests from our own homeserver", ) event.signatures.update( @@ -1016,7 +1009,9 @@ class FederationServer(FederationBase): for user_id, device_keys in result.items(): for device_id, keys in device_keys.items(): for key_id, key in keys.items(): - json_result.setdefault(user_id, {})[device_id] = {key_id: key} + json_result.setdefault(user_id, {}).setdefault(device_id, {})[ + key_id + ] = key logger.info( "Claimed one-time-keys: %s", @@ -1234,9 +1229,18 @@ class FederationServer(FederationBase): logger.info("handling received PDU in room %s: %s", room_id, event) try: with nested_logging_context(event.event_id): - await self._federation_event_handler.on_receive_pdu( - origin, event - ) + # We're taking out a lock within a lock, which could + # lead to deadlocks if we're not careful. However, it is + # safe on this occasion as we only ever take a write + # lock when deleting a room, which we would never do + # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME` + # lock. + async with self._worker_lock_handler.acquire_read_write_lock( + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False + ): + await self._federation_event_handler.on_receive_pdu( + origin, event + ) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction @@ -1247,7 +1251,7 @@ class FederationServer(FederationBase): logger.error( "Failed to handle PDU %s", event.event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + exc_info=(f.type, f.value, f.getTracebackObject()), ) received_ts = await self.store.remove_received_event_from_staging( @@ -1291,9 +1295,6 @@ class FederationServer(FederationBase): return lock = new_lock - def __str__(self) -> str: - return "<ReplicationLayer(%s)>" % self.server_name - async def exchange_third_party_invite( self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict ) -> None: @@ -1314,75 +1315,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - acl_event = await self._storage_controllers.state.get_current_state_event( - room_id, EventTypes.ServerACL, "" - ) - if not acl_event or server_matches_acl_event(server_name, acl_event): - return - - raise AuthError(code=403, msg="Server is banned from room") - - -def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: - """Check if the given server is allowed by the ACL event - - Args: - server_name: name of server, without any port part - acl_event: m.room.server_acl event - - Returns: - True if this server is allowed by the ACLs - """ - logger.debug("Checking %s against acl %s", server_name, acl_event.content) - - # first of all, check if literal IPs are blocked, and if so, whether the - # server name is a literal IP - allow_ip_literals = acl_event.content.get("allow_ip_literals", True) - if not isinstance(allow_ip_literals, bool): - logger.warning("Ignoring non-bool allow_ip_literals flag") - allow_ip_literals = True - if not allow_ip_literals: - # check for ipv6 literals. These start with '['. - if server_name[0] == "[": - return False - - # check for ipv4 literals. We can just lift the routine from twisted. - if isIPAddress(server_name): - return False - - # next, check the deny list - deny = acl_event.content.get("deny", []) - if not isinstance(deny, (list, tuple)): - logger.warning("Ignoring non-list deny ACL %s", deny) - deny = [] - for e in deny: - if _acl_entry_matches(server_name, e): - # logger.info("%s matched deny rule %s", server_name, e) - return False - - # then the allow list. - allow = acl_event.content.get("allow", []) - if not isinstance(allow, (list, tuple)): - logger.warning("Ignoring non-list allow ACL %s", allow) - allow = [] - for e in allow: - if _acl_entry_matches(server_name, e): - # logger.info("%s matched allow rule %s", server_name, e) - return True - - # everything else should be rejected. - # logger.info("%s fell through", server_name) - return False - - -def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool: - if not isinstance(acl_entry, str): - logger.warning( - "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) + server_acl_evaluator = ( + await self._storage_controllers.state.get_server_acl_for_room(room_id) ) - return False - regex = glob_to_regex(acl_entry) - return bool(regex.match(server_name)) + if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event( + server_name + ): + raise AuthError(code=403, msg="Server is banned from room") class FederationHandlerRegistry: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index fb448f2155..525968bcba 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -49,7 +49,7 @@ from synapse.api.presence import UserPresenceState from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.metrics import LaterGauge from synapse.replication.tcp.streams.federation import FederationStream -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util.metrics import Measure from .units import Edu @@ -229,7 +229,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): """ # nothing to do here: the replication listener will handle it. - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """As per FederationSender @@ -245,7 +245,9 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.notifier.on_new_replication_data() - def send_device_messages(self, destination: str, immediate: bool = True) -> None: + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. @@ -393,7 +395,7 @@ class PresenceDestinationsRow(BaseFederationRow): @staticmethod def from_data(data: JsonDict) -> "PresenceDestinationsRow": return PresenceDestinationsRow( - state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] + state=UserPresenceState(**data["state"]), destinations=data["dests"] ) def to_data(self) -> JsonDict: @@ -463,7 +465,7 @@ class ParsedFederationStreamData: edus: Dict[str, List[Edu]] -def process_rows_for_federation( +async def process_rows_for_federation( transaction_queue: FederationSender, rows: List[FederationStream.FederationStreamRow], ) -> None: @@ -496,7 +498,7 @@ def process_rows_for_federation( parsed_row.add_to_buffer(buff) for state, destinations in buff.presence_destinations: - transaction_queue.send_presence_to_destinations( + await transaction_queue.send_presence_to_destinations( states=[state], destinations=destinations ) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index f3bdc5a4d2..7b6b1da090 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -67,7 +67,7 @@ The loop continues so long as there is anything to send. At each iteration of th When the `PerDestinationQueue` has the catch-up flag set, the *Catch-Up Transmission Loop* (`_catch_up_transmission_loop`) is used in lieu of the regular `_transaction_transmission_loop`. -(Only once the catch-up mode has been exited can the regular tranaction transmission behaviour +(Only once the catch-up mode has been exited can the regular transaction transmission behaviour be resumed.) *Catch-Up Mode*, entered upon Synapse startup or once a homeserver has fallen behind due to @@ -109,10 +109,8 @@ was enabled*, Catch-Up Mode is exited and we return to `_transaction_transmissio If a remote server is unreachable over federation, we back off from that server, with an exponentially-increasing retry interval. -Whilst we don't automatically retry after the interval, we prevent making new attempts -until such time as the back-off has cleared. -Once the back-off is cleared and a new PDU or EDU arrives for transmission, the transmission -loop resumes and empties the queue by making federation requests. +We automatically retry after the retry interval expires (roughly, the logic to do so +being triggered every minute). If the backoff grows too large (> 1 hour), the in-memory queue is emptied (to prevent unbounded growth) and Catch-Up Mode is entered. @@ -145,12 +143,14 @@ from prometheus_client import Counter from typing_extensions import Literal from twisted.internet import defer -from twisted.internet.interfaces import IDelayedCall import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase -from synapse.federation.sender.per_destination_queue import PerDestinationQueue +from synapse.federation.sender.per_destination_queue import ( + CATCHUP_RETRY_INTERVAL, + PerDestinationQueue, +) from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -164,9 +164,10 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util import Clock from synapse.util.metrics import Measure +from synapse.util.retryutils import filter_destinations_by_retry_limiter if TYPE_CHECKING: from synapse.events.presence_router import PresenceRouter @@ -184,14 +185,18 @@ sent_pdus_destination_dist_total = Counter( "Total number of PDUs queued for sending across all destinations", ) -# Time (in s) after Synapse's startup that we will begin to wake up destinations -# that have catch-up outstanding. -CATCH_UP_STARTUP_DELAY_SEC = 15 +# Time (in s) to wait before trying to wake up destinations that have +# catch-up outstanding. This will also be the delay applied at startup +# before trying the same. +# Please note that rate limiting still applies, so while the loop is +# executed every X seconds the destinations may not be wake up because +# they are being rate limited following previous attempt failures. +WAKEUP_RETRY_PERIOD_SEC = 60 # Time (in s) to wait in between waking up each destination, i.e. one destination -# will be woken up every <x> seconds after Synapse's startup until we have woken -# every destination has outstanding catch-up. -CATCH_UP_STARTUP_INTERVAL_SEC = 5 +# will be woken up every <x> seconds until we have woken every destination +# has outstanding catch-up. +WAKEUP_INTERVAL_BETWEEN_DESTINATIONS_SEC = 5 class AbstractFederationSender(metaclass=abc.ABCMeta): @@ -212,7 +217,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. @@ -241,9 +246,11 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_device_messages(self, destination: str, immediate: bool = True) -> None: + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: """Tells the sender that a new device message is ready to be sent to the - destination. The `immediate` flag specifies whether the messages should + destinations. The `immediate` flag specifies whether the messages should be tried to be sent immediately, or whether it can be delayed for a short while (to aid performance). """ @@ -415,12 +422,10 @@ class FederationSender(AbstractFederationSender): / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second ) - # wake up destinations that have outstanding PDUs to be caught up - self._catchup_after_startup_timer: Optional[ - IDelayedCall - ] = self.clock.call_later( - CATCH_UP_STARTUP_DELAY_SEC, + # Regularly wake up destinations that have outstanding PDUs to be caught up + self.clock.looping_call( run_as_background_process, + WAKEUP_RETRY_PERIOD_SEC * 1000.0, "wake_destinations_needing_catchup", self._wake_destinations_needing_catchup, ) @@ -717,6 +722,13 @@ class FederationSender(AbstractFederationSender): pdu.internal_metadata.stream_ordering, ) + destinations = await filter_destinations_by_retry_limiter( + destinations, + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu) @@ -764,12 +776,20 @@ class FederationSender(AbstractFederationSender): domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) - domains = [ + domains: StrCollection = [ d for d in domains_set if not self.is_mine_server_name(d) and self._federation_shard_config.should_handle(self._instance_name, d) ] + + domains = await filter_destinations_by_retry_limiter( + domains, + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + if not domains: return @@ -817,7 +837,7 @@ class FederationSender(AbstractFederationSender): for queue in queues: queue.flush_read_receipts_for_room(room_id) - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. @@ -832,13 +852,20 @@ class FederationSender(AbstractFederationSender): for state in states: assert self.is_mine_id(state.user_id) + destinations = await filter_destinations_by_retry_limiter( + [ + d + for d in destinations + if self._federation_shard_config.should_handle(self._instance_name, d) + ], + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + for destination in destinations: if self.is_mine_server_name(destination): continue - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - continue self._get_per_destination_queue(destination).send_presence( states, start_loop=False @@ -897,21 +924,29 @@ class FederationSender(AbstractFederationSender): else: queue.send_edu(edu) - def send_device_messages(self, destination: str, immediate: bool = True) -> None: - if self.is_mine_server_name(destination): - logger.warning("Not sending device update to ourselves") - return - - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - return + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: + destinations = await filter_destinations_by_retry_limiter( + [ + destination + for destination in destinations + if self._federation_shard_config.should_handle( + self._instance_name, destination + ) + and not self.is_mine_server_name(destination) + ], + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) - if immediate: - self._get_per_destination_queue(destination).attempt_new_transaction() - else: - self._get_per_destination_queue(destination).mark_new_data() - self._destination_wakeup_queue.add_to_queue(destination) + for destination in destinations: + if immediate: + self._get_per_destination_queue(destination).attempt_new_transaction() + else: + self._get_per_destination_queue(destination).mark_new_data() + self._destination_wakeup_queue.add_to_queue(destination) def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. @@ -966,7 +1001,6 @@ class FederationSender(AbstractFederationSender): if not destinations_to_wake: # finished waking all destinations! - self._catchup_after_startup_timer = None break last_processed = destinations_to_wake[-1] @@ -983,4 +1017,4 @@ class FederationSender(AbstractFederationSender): last_processed, ) self.wake_destination(destination) - await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC) + await self.clock.sleep(WAKEUP_INTERVAL_BETWEEN_DESTINATIONS_SEC) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 31c5c2b7de..9105ba664c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -59,6 +59,10 @@ sent_edus_by_type = Counter( ) +# If the retry interval is larger than this then we enter "catchup" mode +CATCHUP_RETRY_INTERVAL = 60 * 60 * 1000 + + class PerDestinationQueue: """ Manages the per-destination transmission queues. @@ -370,7 +374,7 @@ class PerDestinationQueue: ), ) - if e.retry_interval > 60 * 60 * 1000: + if e.retry_interval > CATCHUP_RETRY_INTERVAL: # we won't retry for another hour! # (this suggests a significant outage) # We drop pending EDUs because otherwise they will diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1cfc4446c4..fab4800717 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -45,7 +45,7 @@ from synapse.events import EventBase, make_event_from_dict from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser from synapse.http.types import QueryParams -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import ExceptionBundle if TYPE_CHECKING: @@ -249,8 +249,10 @@ class TransportLayerClient: data=json_data, json_data_callback=json_data_callback, long_retries=True, - backoff_on_404=True, # If we get a 404 the other side has gone try_trailing_slash_on_400=True, + # Sending a transaction should always succeed, if it doesn't + # then something is wrong and we should backoff. + backoff_on_all_error_codes=True, ) async def make_query( @@ -429,7 +431,7 @@ class TransportLayerClient: The remote homeserver can optionally return some state from the room. The response dictionary is in the form: - {"knock_state_events": [<state event dict>, ...]} + {"knock_room_state": [<state event dict>, ...]} The list of state events may be empty. """ @@ -475,13 +477,11 @@ class TransportLayerClient: See synapse.federation.federation_client.FederationClient.get_public_rooms for more information. """ + path = _create_v1_path("/publicRooms") + if search_filter: # this uses MSC2197 (Search Filtering over Federation) - path = _create_v1_path("/publicRooms") - - data: Dict[str, Any] = { - "include_all_networks": "true" if include_all_networks else "false" - } + data: Dict[str, Any] = {"include_all_networks": include_all_networks} if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -505,17 +505,15 @@ class TransportLayerClient: ) raise else: - path = _create_v1_path("/publicRooms") - args: Dict[str, Union[str, Iterable[str]]] = { "include_all_networks": "true" if include_all_networks else "false" } if third_party_instance_id: - args["third_party_instance_id"] = (third_party_instance_id,) + args["third_party_instance_id"] = third_party_instance_id if limit: - args["limit"] = [str(limit)] + args["limit"] = str(limit) if since_token: - args["since"] = [since_token] + args["since"] = since_token try: response = await self.client.get_json( @@ -630,7 +628,11 @@ class TransportLayerClient: ) async def claim_client_keys( - self, destination: str, query_content: JsonDict, timeout: Optional[int] + self, + user: UserID, + destination: str, + query_content: JsonDict, + timeout: Optional[int], ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. @@ -655,6 +657,7 @@ class TransportLayerClient: } Args: + user: the user_id of the requesting user destination: The server to query. query_content: The user ids to query. Returns: @@ -671,7 +674,11 @@ class TransportLayerClient: ) async def claim_client_keys_unstable( - self, destination: str, query_content: JsonDict, timeout: Optional[int] + self, + user: UserID, + destination: str, + query_content: JsonDict, + timeout: Optional[int], ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. @@ -696,6 +703,7 @@ class TransportLayerClient: } Args: + user: the user_id of the requesting user destination: The server to query. query_content: The user ids to query. Returns: diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 3a744e25be..3248953b48 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -432,15 +432,6 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - async def on_PUT( self, origin: str, diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index c05a14304c..fa043cca86 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -102,7 +102,7 @@ class AccountHandler: """ status = {"exists": False} - userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) + userinfo = await self._main_store.get_user_by_id(user_id.to_string()) if userinfo is not None: status = { diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 4aa4ebf7e4..f1a7a05df6 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -164,7 +164,7 @@ class AccountValidityHandler: try: user_display_name = await self.store.get_profile_displayname( - UserID.from_string(user_id).localpart + UserID.from_string(user_id) ) if user_display_name is None: user_display_name = user_id diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index b06f25b03c..97fd1fd427 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,11 +14,11 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set from synapse.api.constants import Direction, Membership from synapse.events import EventBase -from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID +from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -35,7 +35,7 @@ class AdminHandler: self._state_storage_controller = self._storage_controllers.state self._msc3866_enabled = hs.config.experimental.msc3866.enabled - async def get_whois(self, user: UserID) -> JsonDict: + async def get_whois(self, user: UserID) -> JsonMapping: connections = [] sessions = await self._store.get_user_ip_and_agents(user) @@ -55,41 +55,35 @@ class AdminHandler: return ret - async def get_user(self, user: UserID) -> Optional[JsonDict]: + async def get_user(self, user: UserID) -> Optional[JsonMapping]: """Function to get user details""" - user_info_dict = await self._store.get_user_by_id(user.to_string()) - if user_info_dict is None: + user_info: Optional[UserInfo] = await self._store.get_user_by_id( + user.to_string() + ) + if user_info is None: return None - # Restrict returned information to a known set of fields. This prevents additional - # fields added to get_user_by_id from modifying Synapse's external API surface. - user_info_to_return = { - "name", - "admin", - "deactivated", - "shadow_banned", - "creation_ts", - "appservice_id", - "consent_server_notice_sent", - "consent_version", - "consent_ts", - "user_type", - "is_guest", + user_info_dict = { + "name": user.to_string(), + "admin": user_info.is_admin, + "deactivated": user_info.is_deactivated, + "locked": user_info.locked, + "shadow_banned": user_info.is_shadow_banned, + "creation_ts": user_info.creation_ts, + "appservice_id": user_info.appservice_id, + "consent_server_notice_sent": user_info.consent_server_notice_sent, + "consent_version": user_info.consent_version, + "consent_ts": user_info.consent_ts, + "user_type": user_info.user_type, + "is_guest": user_info.is_guest, } if self._msc3866_enabled: # Only include the approved flag if support for MSC3866 is enabled. - user_info_to_return.add("approved") - - # Restrict returned keys to a known set. - user_info_dict = { - key: value - for key, value in user_info_dict.items() - if key in user_info_to_return - } + user_info_dict["approved"] = user_info.approved # Add additional user metadata - profile = await self._store.get_profileinfo(user.localpart) + profile = await self._store.get_profileinfo(user) threepids = await self._store.user_get_threepids(user.to_string()) external_ids = [ ({"auth_provider": auth_provider, "external_id": external_id}) @@ -103,6 +97,9 @@ class AdminHandler: user_info_dict["external_ids"] = external_ids user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) + last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string()) + user_info_dict["last_seen_ts"] = last_seen_ts + return user_info_dict async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: @@ -174,8 +171,8 @@ class AdminHandler: else: stream_ordering = room.stream_ordering - from_key = RoomStreamToken(0, 0) - to_key = RoomStreamToken(None, stream_ordering) + from_key = RoomStreamToken(topological=0, stream=0) + to_key = RoomStreamToken(stream=stream_ordering) # Events that we've processed in this room written_events: Set[str] = set() @@ -347,7 +344,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def write_profile(self, profile: JsonDict) -> None: + def write_profile(self, profile: JsonMapping) -> None: """Write the profile of a user. Args: @@ -356,7 +353,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def write_devices(self, devices: List[JsonDict]) -> None: + def write_devices(self, devices: Sequence[JsonMapping]) -> None: """Write the devices of a user. Args: @@ -365,7 +362,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def write_connections(self, connections: List[JsonDict]) -> None: + def write_connections(self, connections: Sequence[JsonMapping]) -> None: """Write the connections of a user. Args: @@ -375,7 +372,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): @abc.abstractmethod def write_account_data( - self, file_name: str, account_data: Mapping[str, JsonDict] + self, file_name: str, account_data: Mapping[str, JsonMapping] ) -> None: """Write the account data of a user. @@ -386,7 +383,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None: """Write the media's metadata of a user. Exports only the metadata, as this can be fetched from the database via read only. In order to access the files, a connection to the correct diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 6429545c98..c200a45f3a 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -46,6 +46,7 @@ from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import ( DeviceListUpdates, JsonDict, + JsonMapping, RoomAlias, RoomStreamToken, StreamKeyType, @@ -215,7 +216,7 @@ class ApplicationServicesHandler: def notify_interested_services_ephemeral( self, - stream_key: str, + stream_key: StreamKeyType, new_token: Union[int, RoomStreamToken], users: Collection[Union[str, UserID]], ) -> None: @@ -325,7 +326,7 @@ class ApplicationServicesHandler: async def _notify_interested_services_ephemeral( self, services: List[ApplicationService], - stream_key: str, + stream_key: StreamKeyType, new_token: int, users: Collection[Union[str, UserID]], ) -> None: @@ -397,7 +398,7 @@ class ApplicationServicesHandler: async def _handle_typing( self, service: ApplicationService, new_token: int - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the typing events since the given stream token that the given application service should receive. @@ -432,7 +433,7 @@ class ApplicationServicesHandler: async def _handle_receipts( self, service: ApplicationService, new_token: int - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the latest read receipts that the given application service should receive. @@ -471,7 +472,7 @@ class ApplicationServicesHandler: service: ApplicationService, users: Collection[Union[str, UserID]], new_token: Optional[int], - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the latest presence updates that the given application service should receive. @@ -491,7 +492,7 @@ class ApplicationServicesHandler: A list of json dictionaries containing data derived from the presence events that should be sent to the given application service. """ - events: List[JsonDict] = [] + events: List[JsonMapping] = [] presence_source = self.event_sources.sources.presence from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 59e340974d..2b0c505130 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -52,7 +52,6 @@ from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, - UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.ui_auth import ( @@ -219,19 +218,17 @@ class AuthHandler: self._failed_uia_attempts_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) # The number of seconds to keep a UI auth session active. self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout - # Ratelimitier for failed /login attempts + # Ratelimiter for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) self._clock = self.hs.get_clock() @@ -275,6 +272,8 @@ class AuthHandler: # response. self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {} + self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled + async def validate_user_via_ui_auth( self, requester: Requester, @@ -323,8 +322,12 @@ class AuthHandler: LimitExceededError if the ratelimiter's failed request count for this user is too high to proceed - """ + if self.msc3861_oauth_delegation_enabled: + raise SynapseError( + HTTPStatus.INTERNAL_SERVER_ERROR, "UIA shouldn't be used with MSC3861" + ) + if not requester.access_token_id: raise ValueError("Cannot validate a user without an access token") if can_skip_ui_auth and self._ui_auth_session_timeout: @@ -1419,12 +1422,6 @@ class AuthHandler: return None (user_id, password_hash) = lookupres - # If the password hash is None, the account has likely been deactivated - if not password_hash: - deactivated = await self.store.get_user_deactivated_status(user_id) - if deactivated: - raise UserDeactivatedError("This account has been deactivated") - result = await self.validate_hash(password, password_hash) if not result: logger.warning("Failed password login for user %s", user_id) @@ -1749,15 +1746,18 @@ class AuthHandler: registered. auth_provider_session_id: The session ID from the SSO IdP received during login. """ - # If the account has been deactivated, do not proceed with the login - # flow. + # If the account has been deactivated, do not proceed with the login. + # + # This gets checked again when the token is submitted but this lets us + # provide an HTML error page to the user (instead of issuing a token and + # having it error later). deactivated = await self.store.get_user_deactivated_status(registered_user_id) if deactivated: respond_with_html(request, 403, self._sso_account_deactivated_template) return user_profile_data = await self.store.get_profileinfo( - UserID.from_string(registered_user_id).localpart + UserID.from_string(registered_user_id) ) # Store any extra attributes which will be passed in the login response. diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index fc467bc7c1..b5b8b9bd35 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -67,8 +67,10 @@ class CasHandler: self._cas_server_url = hs.config.cas.cas_server_url self._cas_service_url = hs.config.cas.cas_service_url + self._cas_protocol_version = hs.config.cas.cas_protocol_version self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute self._cas_required_attributes = hs.config.cas.cas_required_attributes + self._cas_enable_registration = hs.config.cas.cas_enable_registration self._http_client = hs.get_proxied_http_client() @@ -76,12 +78,13 @@ class CasHandler: self.idp_id = "cas" # user-facing name of this auth provider - self.idp_name = "CAS" + self.idp_name = hs.config.cas.idp_name - # we do not currently support brands/icons for CAS auth, but this is required by - # the SsoIdentityProvider protocol type. - self.idp_icon = None - self.idp_brand = None + # MXC URI for icon for this auth provider + self.idp_icon = hs.config.cas.idp_icon + + # optional brand identifier for this auth provider + self.idp_brand = hs.config.cas.idp_brand self._sso_handler = hs.get_sso_handler() @@ -120,7 +123,10 @@ class CasHandler: Returns: The parsed CAS response. """ - uri = self._cas_server_url + "/proxyValidate" + if self._cas_protocol_version == 3: + uri = self._cas_server_url + "/p3/proxyValidate" + else: + uri = self._cas_server_url + "/proxyValidate" args = { "ticket": ticket, "service": self._build_service_param(service_args), @@ -390,4 +396,5 @@ class CasHandler: client_redirect_url, cas_response_to_user_attributes, grandfather_existing_users, + registration_enabled=self._cas_enable_registration, ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index f299b89a1b..67adeae6a7 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -297,5 +297,5 @@ class DeactivateAccountHandler: # Add the user to the directory, if necessary. Note that # this must be done after the user is re-activated, because # deactivated users are excluded from the user directory. - profile = await self.store.get_profileinfo(user.localpart) + profile = await self.store.get_profileinfo(user) await self.user_directory_handler.handle_local_profile_change(user_id, profile) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5d12a39e26..86ad96d030 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + JsonMapping, + ScheduledTask, StrCollection, StreamKeyType, StreamToken, + TaskStatus, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -55,13 +58,17 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable from synapse.util.metrics import measure_func -from synapse.util.retryutils import NotRetryingDestination +from synapse.util.retryutils import ( + NotRetryingDestination, + filter_destinations_by_retry_limiter, +) if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) +DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages" MAX_DEVICE_DISPLAY_NAME_LEN = 100 DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 @@ -78,14 +85,20 @@ class DeviceWorkerHandler: self._appservice_handler = hs.get_application_service_handler() self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() + self._event_sources = hs.get_event_sources() self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled self._query_appservices_for_keys = ( hs.config.experimental.msc3984_appservice_key_query ) + self._task_scheduler = hs.get_task_scheduler() self.device_list_updater = DeviceListWorkerUpdater(hs) + self._task_scheduler.register_action( + self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME + ) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -375,6 +388,33 @@ class DeviceWorkerHandler: "Trying handling device list state for partial join: not supported on workers." ) + DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000 + DEVICE_MSGS_DELETE_SLEEP_MS = 1000 + + async def _delete_device_messages( + self, + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`.""" + assert task.params is not None + user_id = task.params["user_id"] + device_id = task.params["device_id"] + up_to_stream_id = task.params["up_to_stream_id"] + + # Delete the messages in batches to avoid too much DB load. + while True: + res = await self.store.delete_messages_for_device( + user_id=user_id, + device_id=device_id, + up_to_stream_id=up_to_stream_id, + limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, + ) + + if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: + return TaskStatus.COMPLETE, None, None + + await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0) + class DeviceHandler(DeviceWorkerHandler): device_list_updater: "DeviceListUpdater" @@ -385,6 +425,7 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender = hs.get_federation_sender() self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() + self.db_pool = hs.get_datastores().main.db_pool self.device_list_updater = DeviceListUpdater(hs, self) @@ -529,6 +570,7 @@ class DeviceHandler(DeviceWorkerHandler): user_id: The user to delete devices from. device_ids: The list of device IDs to delete """ + to_device_stream_id = self._event_sources.get_current_token().to_device_key try: await self.store.delete_devices(user_id, device_ids) @@ -558,6 +600,17 @@ class DeviceHandler(DeviceWorkerHandler): f"org.matrix.msc3890.local_notification_settings.{device_id}", ) + # Delete device messages asynchronously and in batches using the task scheduler + await self._task_scheduler.schedule_task( + DELETE_DEVICE_MSGS_TASK_NAME, + resource_id=device_id, + params={ + "user_id": user_id, + "device_id": device_id, + "up_to_stream_id": to_device_stream_id, + }, + ) + # Pushers are deleted after `delete_access_tokens_for_user` is called so that # modules using `on_logged_out` hook can use them if needed. await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) @@ -653,29 +706,38 @@ class DeviceHandler(DeviceWorkerHandler): async def store_dehydrated_device( self, user_id: str, + device_id: Optional[str], device_data: JsonDict, initial_device_display_name: Optional[str] = None, + keys_for_device: Optional[JsonDict] = None, ) -> str: - """Store a dehydrated device for a user. If the user had a previous - dehydrated device, it is removed. + """Store a dehydrated device for a user, optionally storing the keys associated with + it as well. If the user had a previous dehydrated device, it is removed. Args: user_id: the user that we are storing the device for + device_id: device id supplied by client device_data: the dehydrated device information initial_device_display_name: The display name to use for the device + keys_for_device: keys for the dehydrated device Returns: device id of the dehydrated device """ device_id = await self.check_device_registered( user_id, - None, + device_id, initial_device_display_name, ) + + time_now = self.clock.time_msec() + old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data + user_id, device_id, device_data, time_now, keys_for_device ) + if old_device_id is not None: await self.delete_devices(user_id, [old_device_id]) + return device_id async def rehydrate_device( @@ -697,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler): # If the dehydrated device was successfully deleted (the device ID # matched the stored dehydrated device), then modify the access - # token to use the dehydrated device's ID and copy the old device - # display name to the dehydrated device, and destroy the old device - # ID + # token and refresh token to use the dehydrated device's ID and + # copy the old device display name to the dehydrated device, + # and destroy the old device ID old_device_id = await self.store.set_device_for_access_token( access_token, device_id ) + await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id) old_device = await self.store.get_device(user_id, old_device_id) if old_device is None: raise errors.NotFoundError() @@ -720,6 +783,22 @@ class DeviceHandler(DeviceWorkerHandler): return {"success": True} + async def delete_dehydrated_device(self, user_id: str, device_id: str) -> None: + """ + Delete a stored dehydrated device. + + Args: + user_id: the user_id to delete the device from + device_id: id of the dehydrated device to delete + """ + success = await self.store.remove_dehydrated_device(user_id, device_id) + + if not success: + raise errors.NotFoundError() + + await self.delete_devices(user_id, [device_id]) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + @wrap_as_background_process("_handle_new_device_update_async") async def _handle_new_device_update_async(self) -> None: """Called when we have a new local device list update that we need to @@ -810,17 +889,16 @@ class DeviceHandler(DeviceWorkerHandler): user_id, hosts, ) - for host in hosts: - self.federation_sender.send_device_messages( - host, immediate=False - ) - # TODO: when called, this isn't in a logging context. - # This leads to log spam, sentry event spam, and massive - # memory usage. - # See https://github.com/matrix-org/synapse/issues/12552. - # log_kv( - # {"message": "sent device update to host", "host": host} - # ) + await self.federation_sender.send_device_messages( + hosts, immediate=False + ) + # TODO: when called, this isn't in a logging context. + # This leads to log spam, sentry event spam, and massive + # memory usage. + # See https://github.com/matrix-org/synapse/issues/12552. + # log_kv( + # {"message": "sent device update to host", "host": host} + # ) if current_stream_id != stream_id: # Clear the set of hosts we've already sent to as we're @@ -925,8 +1003,9 @@ class DeviceHandler(DeviceWorkerHandler): # Notify things that device lists need to be sent out. self.notifier.notify_replication() - for host in potentially_changed_hosts: - self.federation_sender.send_device_messages(host, immediate=False) + await self.federation_sender.send_device_messages( + potentially_changed_hosts, immediate=False + ) def _update_device_from_client_ips( @@ -956,7 +1035,7 @@ class DeviceListWorkerUpdater: async def multi_user_device_resync( self, user_ids: List[str], mark_failed_as_stale: bool = True - ) -> Dict[str, Optional[JsonDict]]: + ) -> Dict[str, Optional[JsonMapping]]: """ Like `user_device_resync` but operates on multiple users **from the same origin** at once. @@ -985,6 +1064,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self._notifier = hs.get_notifier() self._remote_edu_linearizer = Linearizer(name="remote_device_list") + self._resync_linearizer = Linearizer(name="remote_device_resync") # user_id -> list of updates waiting to be handled. self._pending_updates: Dict[ @@ -1124,7 +1204,14 @@ class DeviceListUpdater(DeviceListWorkerUpdater): ) if resync: - await self.multi_user_device_resync([user_id]) + # We mark as stale up front in case we get restarted. + await self.store.mark_remote_users_device_caches_as_stale([user_id]) + run_as_background_process( + "_maybe_retry_device_resync", + self.multi_user_device_resync, + [user_id], + False, + ) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) @@ -1187,8 +1274,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self._resync_retry_in_progress = True # Get all of the users that need resyncing. need_resync = await self.store.get_user_ids_requiring_device_list_resync() + + # Filter out users whose host is marked as "down" up front. + hosts = await filter_destinations_by_retry_limiter( + {get_domain_from_id(u) for u in need_resync}, self.clock, self.store + ) + hosts = set(hosts) + # Iterate over the set of user IDs. for user_id in need_resync: + if get_domain_from_id(user_id) not in hosts: + continue + try: # Try to resync the current user's devices list. result = (await self.multi_user_device_resync([user_id], False))[ @@ -1220,7 +1317,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): async def multi_user_device_resync( self, user_ids: List[str], mark_failed_as_stale: bool = True - ) -> Dict[str, Optional[JsonDict]]: + ) -> Dict[str, Optional[JsonMapping]]: """ Like `user_device_resync` but operates on multiple users **from the same origin** at once. @@ -1240,9 +1337,11 @@ class DeviceListUpdater(DeviceListWorkerUpdater): failed = set() # TODO(Perf): Actually batch these up for user_id in user_ids: - user_result, user_failed = await self._user_device_resync_returning_failed( - user_id - ) + async with self._resync_linearizer.queue(user_id): + ( + user_result, + user_failed, + ) = await self._user_device_resync_returning_failed(user_id) result[user_id] = user_result if user_failed: failed.add(user_id) @@ -1254,7 +1353,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): async def _user_device_resync_returning_failed( self, user_id: str - ) -> Tuple[Optional[JsonDict], bool]: + ) -> Tuple[Optional[JsonMapping], bool]: """Fetches all devices for a user and updates the device cache with them. Args: @@ -1267,6 +1366,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater): e.g. due to a connection problem. - True iff the resync failed and the device list should be marked as stale. """ + # Check that we haven't gone and fetched the devices since we last + # checked if we needed to resync these device lists. + if await self.store.get_users_whose_devices_are_cached([user_id]): + cached = await self.store.get_cached_devices_for_user(user_id) + return cached, False + logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) # Fetch all devices for the user. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 3caf9b31cc..1c79f7a61e 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -13,10 +13,11 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Dict, Optional from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -48,6 +49,9 @@ class DeviceMessageHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine + if hs.config.experimental.msc3814_enabled: + self.event_sources = hs.get_event_sources() + self.device_handler = hs.get_device_handler() # We only need to poke the federation sender explicitly if its on the # same instance. Other federation sender instances will get notified by @@ -86,8 +90,7 @@ class DeviceMessageHandler: self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, - burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, + cfg=hs.config.ratelimiting.rc_key_requests, ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: @@ -299,7 +302,93 @@ class DeviceMessageHandler: ) if self.federation_sender: - for destination in remote_messages.keys(): - # Enqueue a new federation transaction to send the new - # device messages to each remote destination. - self.federation_sender.send_device_messages(destination) + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + await self.federation_sender.send_device_messages(remote_messages.keys()) + + async def get_events_for_dehydrated_device( + self, + requester: Requester, + device_id: str, + since_token: Optional[str], + limit: int, + ) -> JsonDict: + """Fetches up to `limit` events sent to `device_id` starting from `since_token` + and returns the new since token. If there are no more messages, returns an empty + array. + + Args: + requester: the user requesting the messages + device_id: ID of the dehydrated device + since_token: stream id to start from when fetching messages + limit: the number of messages to fetch + Returns: + A dict containing the to-device messages, as well as a token that the client + can provide in the next call to fetch the next batch of messages + """ + + user_id = requester.user.to_string() + + # only allow fetching messages for the dehydrated device id currently associated + # with the user + dehydrated_device = await self.device_handler.get_dehydrated_device(user_id) + if dehydrated_device is None: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "No dehydrated device exists", + Codes.FORBIDDEN, + ) + + dehydrated_device_id, _ = dehydrated_device + if device_id != dehydrated_device_id: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "You may only fetch messages for your dehydrated device", + Codes.FORBIDDEN, + ) + + since_stream_id = 0 + if since_token: + if not since_token.startswith("d"): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "from parameter %r has an invalid format" % (since_token,), + errcode=Codes.INVALID_PARAM, + ) + + try: + since_stream_id = int(since_token[1:]) + except Exception: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "from parameter %r has an invalid format" % (since_token,), + errcode=Codes.INVALID_PARAM, + ) + + to_token = self.event_sources.get_current_token().to_device_key + + messages, stream_id = await self.store.get_messages_for_device( + user_id, device_id, since_stream_id, to_token, limit + ) + + for message in messages: + # Remove the message id before sending to client + message_id = message.pop("message_id", None) + if message_id: + set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id) + + logger.debug( + "Returning %d to-device messages between %d and %d (current token: %d) for " + "dehydrated device %s, user_id %s", + len(messages), + since_stream_id, + stream_id, + to_token, + device_id, + user_id, + ) + + return { + "events": messages, + "next_batch": f"d{stream_id}", + } diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1e0623c7f8..623a4e7b1d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -277,7 +277,9 @@ class DirectoryHandler: except RequestSendFailed: raise SynapseError(502, "Failed to fetch alias") except CodeMessageException as e: - logging.warning("Error retrieving alias") + logging.warning( + "Error retrieving alias %s -> %s %s", room_alias, e.code, e.msg + ) if e.code == 404: fed_result = None else: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 24741b667b..8c6432035d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.types import ( JsonDict, + JsonMapping, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -272,11 +273,7 @@ class E2eKeysHandler: delay_cancellation=True, ) - ret = {"device_keys": results, "failures": failures} - - ret.update(cross_signing_keys) - - return ret + return {"device_keys": results, "failures": failures, **cross_signing_keys} @trace async def _query_devices_for_destination( @@ -408,7 +405,7 @@ class E2eKeysHandler: @cancellable async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: """Get cross-signing keys for users from the database Args: @@ -551,16 +548,13 @@ class E2eKeysHandler: self.config.federation.allow_device_name_lookup_over_federation ), ) - ret = {"device_keys": res} # add in the cross-signing keys cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) - ret.update(cross_signing_keys) - - return ret + return {"device_keys": res, **cross_signing_keys} async def claim_local_one_time_keys( self, @@ -661,6 +655,7 @@ class E2eKeysHandler: async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, int]]], + user: UserID, timeout: Optional[int], always_include_fallback_keys: bool, ) -> JsonDict: @@ -703,7 +698,7 @@ class E2eKeysHandler: device_keys = remote_queries[destination] try: remote_result = await self.federation.claim_client_keys( - destination, device_keys, timeout=timeout + user, destination, device_keys, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: @@ -1126,7 +1121,7 @@ class E2eKeysHandler: user_id: str, master_key_id: str, signed_master_key: JsonDict, - stored_master_key: JsonDict, + stored_master_key: JsonMapping, devices: Dict[str, Dict[str, JsonDict]], ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. @@ -1277,7 +1272,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Tuple[JsonDict, str, VerifyKey]: + ) -> Tuple[JsonMapping, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -1332,7 +1327,7 @@ class E2eKeysHandler: self, user: UserID, desired_key_type: str, - ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: + ) -> Optional[Tuple[JsonMapping, str, VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1473,7 +1468,7 @@ def _check_device_signature( user_id: str, verify_key: VerifyKey, signed_device: JsonDict, - stored_device: JsonDict, + stored_device: JsonMapping, ) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 3e37c0cbe2..82a7617a08 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -277,7 +277,7 @@ class EventAuthHandler: True if the proper room version and join rules are set for restricted access. """ # This only applies to room versions which support the new join rule. - if not room_version.msc3083_join_rules: + if not room_version.restricted_join_rule: return False # If there's no join rule, then it defaults to invite (so this doesn't apply). @@ -292,7 +292,7 @@ class EventAuthHandler: return True # also check for MSC3787 behaviour - if room_version.msc3787_knock_restricted_join_rule: + if room_version.knock_restricted_join_rule: return content_join_rule == JoinRules.KNOCK_RESTRICTED return False diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 33359f6ed7..d12803bf0f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -67,6 +67,7 @@ class EventStreamHandler: context = await presence_handler.user_syncing( requester.user.to_string(), + requester.device_id, affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2eb28d55ac..807a0867cc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -60,6 +60,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers.pagination import PURGE_PAGINATION_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace @@ -105,14 +106,12 @@ backfill_processing_before_timer = Histogram( ) +# TODO: We can refactor this away now that there is only one backfill point again class _BackfillPointType(Enum): # a regular backwards extremity (ie, an event which we don't yet have, but which # is referred to by other events in the DAG) BACKWARDS_EXTREMITY = enum.auto() - # an MSC2716 "insertion event" - INSERTION_PONT = enum.auto() - @attr.s(slots=True, auto_attribs=True, frozen=True) class _BackfillPoint: @@ -154,6 +153,7 @@ class FederationHandler: self._device_handler = hs.get_device_handler() self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self._notifier = hs.get_notifier() + self._worker_locks = hs.get_worker_locks_handler() self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs @@ -200,8 +200,9 @@ class FederationHandler: ) @trace + @tag_args async def maybe_backfill( - self, room_id: str, current_depth: int, limit: int + self, room_id: str, current_depth: int, limit: int, record_time: bool = True ) -> bool: """Checks the database to see if we should backfill before paginating, and if so do. @@ -214,19 +215,28 @@ class FederationHandler: limit: The number of events that the pagination request will return. This is used as part of the heuristic to decide if we should back paginate. + record_time: Whether to record the time it takes to backfill. + + Returns: + True if we actually tried to backfill something, otherwise False. """ # Starting the processing time here so we can include the room backfill # linearizer lock queue in the timing - processing_start_time = self.clock.time_msec() + processing_start_time = self.clock.time_msec() if record_time else 0 async with self._room_backfill.queue(room_id): - return await self._maybe_backfill_inner( - room_id, - current_depth, - limit, - processing_start_time=processing_start_time, - ) + async with self._worker_locks.acquire_read_write_lock( + PURGE_PAGINATION_LOCK_NAME, room_id, write=False + ): + return await self._maybe_backfill_inner( + room_id, + current_depth, + limit, + processing_start_time=processing_start_time, + ) + @trace + @tag_args async def _maybe_backfill_inner( self, room_id: str, @@ -247,6 +257,9 @@ class FederationHandler: limit: The max number of events to request from the remote federated server. processing_start_time: The time when `maybe_backfill` started processing. Only used for timing. If `None`, no timing observation will be made. + + Returns: + True if we actually tried to backfill something, otherwise False. """ backwards_extremities = [ _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) @@ -264,32 +277,10 @@ class FederationHandler: ) ] - insertion_events_to_be_backfilled: List[_BackfillPoint] = [] - if self.hs.config.experimental.msc2716_enabled: - insertion_events_to_be_backfilled = [ - _BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT) - for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room( - room_id=room_id, - current_depth=current_depth, - # We only need to end up with 5 extremities combined with - # the backfill points to make the `/backfill` request ... - # (see the other comment above for more context). - limit=50, - ) - ] - logger.debug( - "_maybe_backfill_inner: backwards_extremities=%s insertion_events_to_be_backfilled=%s", - backwards_extremities, - insertion_events_to_be_backfilled, - ) - # we now have a list of potential places to backpaginate from. We prefer to # start with the most recent (ie, max depth), so let's sort the list. sorted_backfill_points: List[_BackfillPoint] = sorted( - itertools.chain( - backwards_extremities, - insertion_events_to_be_backfilled, - ), + backwards_extremities, key=lambda e: -int(e.depth), ) @@ -302,15 +293,39 @@ class FederationHandler: len(sorted_backfill_points), sorted_backfill_points, ) + set_tag( + SynapseTags.RESULT_PREFIX + "sorted_backfill_points", + str(sorted_backfill_points), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "sorted_backfill_points.length", + str(len(sorted_backfill_points)), + ) - # If we have no backfill points lower than the `current_depth` then - # either we can a) bail or b) still attempt to backfill. We opt to try - # backfilling anyway just in case we do get relevant events. + # If we have no backfill points lower than the `current_depth` then either we + # can a) bail or b) still attempt to backfill. We opt to try backfilling anyway + # just in case we do get relevant events. This is good for eventual consistency + # sake but we don't need to block the client for something that is just as + # likely not to return anything relevant so we backfill in the background. The + # only way, this could return something relevant is if we discover a new branch + # of history that extends all the way back to where we are currently paginating + # and it's within the 100 events that are returned from `/backfill`. if not sorted_backfill_points and current_depth != MAX_DEPTH: + # Check that we actually have later backfill points, if not just return. + have_later_backfill_points = await self.store.get_backfill_points_in_room( + room_id=room_id, + current_depth=MAX_DEPTH, + limit=1, + ) + if not have_later_backfill_points: + return False + logger.debug( "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points." ) - return await self._maybe_backfill_inner( + run_as_background_process( + "_maybe_backfill_inner_anyway_with_max_depth", + self.maybe_backfill, room_id=room_id, # We use `MAX_DEPTH` so that we find all backfill points next # time (all events are below the `MAX_DEPTH`) @@ -319,8 +334,11 @@ class FederationHandler: # We don't want to start another timing observation from this # nested recursive call. The top-most call can record the time # overall otherwise the smaller one will throw off the results. - processing_start_time=None, + record_time=False, ) + # We return `False` because we're backfilling in the background and there is + # no new events immediately for the caller to know about yet. + return False # Even after recursing with `MAX_DEPTH`, we didn't find any # backward extremities to backfill from. @@ -384,10 +402,7 @@ class FederationHandler: # event but not anything before it. This would require looking at the # state *before* the event, ignoring the special casing certain event # types have. - if bp.type == _BackfillPointType.INSERTION_PONT: - event_ids_to_check = [bp.event_id] - else: - event_ids_to_check = await self.store.get_successor_events(bp.event_id) + event_ids_to_check = await self.store.get_successor_events(bp.event_id) events_to_check = await self.store.get_events_as_list( event_ids_to_check, @@ -853,19 +868,10 @@ class FederationHandler: # This is a bit of a hack and is cribbing off of invites. Basically we # store the room state here and retrieve it again when this event appears # in the invitee's sync stream. It is stripped out for all other local users. - stripped_room_state = ( - knock_response.get("knock_room_state") - # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. - # Thus, we also check for a 'knock_state_events' to support old instances. - # See https://github.com/matrix-org/synapse/issues/14088. - or knock_response.get("knock_state_events") - ) + stripped_room_state = knock_response.get("knock_room_state") if stripped_room_state is None: - raise KeyError( - "Missing 'knock_room_state' (or legacy 'knock_state_events') field in " - "send_knock response" - ) + raise KeyError("Missing 'knock_room_state' field in send_knock response") event.unsigned["knock_room_state"] = stripped_room_state @@ -957,7 +963,7 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. prev_event_ids = None - if room_version.msc3083_join_rules: + if room_version.restricted_join_rule: # Note that the room's state can change out from under us and render our # nice join rules-conformant event non-conformant by the time we build the # event. When this happens, our validation at the end fails and we respond @@ -1581,9 +1587,7 @@ class FederationHandler: event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)]) - ) + prev_state_ids = await context.get_prev_state_ids(StateFilter.from_types([key])) original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = await self.store.get_event( @@ -1636,7 +1640,7 @@ class FederationHandler: token = signed["token"] prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)]) + StateFilter.from_types([(EventTypes.ThirdPartyInvite, token)]) ) invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06343d40e4..0cc8e990d9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -88,7 +88,7 @@ from synapse.types import ( ) from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter +from synapse.util.iterutils import batch_iter, partition from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -601,18 +601,6 @@ class FederationEventHandler: room_id, [(event, context)] ) - # If we're joining the room again, check if there is new marker - # state indicating that there is new history imported somewhere in - # the DAG. Multiple markers can exist in the current state with - # unique state_keys. - # - # Do this after the state from the remote join was persisted (via - # `persist_events_and_notify`). Otherwise we can run into a - # situation where the create event doesn't exist yet in the - # `current_state_events` - for e in state: - await self._handle_marker_event(origin, e) - return stream_id_after_persist async def update_state_for_partial_state_event( @@ -735,12 +723,11 @@ class FederationEventHandler: if not prevs - seen: return - latest_list = await self._store.get_latest_event_ids_in_room(room_id) + latest_frozen = await self._store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest_list) - latest |= seen + latest = seen | latest_frozen logger.info( "Requesting missing events between %s and %s", @@ -865,7 +852,7 @@ class FederationEventHandler: [event.event_id for event in events] ) - new_events = [] + new_events: List[EventBase] = [] for event in events: event_id = event.event_id @@ -890,12 +877,64 @@ class FederationEventHandler: # Continue on with the events that are new to us. new_events.append(event) - # We want to sort these by depth so we process them and - # tell clients about them in order. - sorted_events = sorted(new_events, key=lambda x: x.depth) - for ev in sorted_events: - with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev, backfilled=backfilled) + set_tag( + SynapseTags.RESULT_PREFIX + "new_events.length", + str(len(new_events)), + ) + + @trace + async def _process_new_pulled_events(new_events: Collection[EventBase]) -> None: + # We want to sort these by depth so we process them and tell clients about + # them in order. It's also more efficient to backfill this way (`depth` + # ascending) because one backfill event is likely to be the `prev_event` of + # the next event we're going to process. + sorted_events = sorted(new_events, key=lambda x: x.depth) + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev, backfilled=backfilled) + + # Check if we've already tried to process these events at some point in the + # past. We aren't concerned with the expontntial backoff here, just whether it + # has failed to be processed before. + event_ids_with_failed_pull_attempts = ( + await self._store.get_event_ids_with_failed_pull_attempts( + [event.event_id for event in new_events] + ) + ) + + events_with_failed_pull_attempts, fresh_events = partition( + new_events, lambda e: e.event_id in event_ids_with_failed_pull_attempts + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "events_with_failed_pull_attempts", + str(event_ids_with_failed_pull_attempts), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "events_with_failed_pull_attempts.length", + str(len(events_with_failed_pull_attempts)), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "fresh_events", + str([event.event_id for event in fresh_events]), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "fresh_events.length", + str(len(fresh_events)), + ) + + # Process previously failed backfill events in the background to not waste + # time on something that is likely to fail again. + if len(events_with_failed_pull_attempts) > 0: + run_as_background_process( + "_process_new_pulled_events_with_failed_pull_attempts", + _process_new_pulled_events, + events_with_failed_pull_attempts, + ) + + # We can optimistically try to process and wait for the event to be fully + # persisted if we've never tried before. + if len(fresh_events) > 0: + await _process_new_pulled_events(fresh_events) @trace @tag_args @@ -1401,8 +1440,6 @@ class FederationEventHandler: await self._run_push_actions_and_persist_event(event, context, backfilled) - await self._handle_marker_event(origin, event) - if backfilled or context.rejected: return @@ -1500,96 +1537,8 @@ class FederationEventHandler: except Exception: logger.exception("Failed to resync device for %s", sender) - @trace - async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None: - """Handles backfilling the insertion event when we receive a marker - event that points to one. - - Args: - origin: Origin of the event. Will be called to get the insertion event - marker_event: The event to process - """ - - if marker_event.type != EventTypes.MSC2716_MARKER: - # Not a marker event - return - - if marker_event.rejected_reason is not None: - # Rejected event - return - - # Skip processing a marker event if the room version doesn't - # support it or the event is not from the room creator. - room_version = await self._store.get_room_version(marker_event.room_id) - create_event = await self._store.get_create_event_for_room(marker_event.room_id) - if not room_version.msc2175_implicit_room_creator: - room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) - else: - room_creator = create_event.sender - if not room_version.msc2716_historical and ( - not self._config.experimental.msc2716_enabled - or marker_event.sender != room_creator - ): - return - - logger.debug("_handle_marker_event: received %s", marker_event) - - insertion_event_id = marker_event.content.get( - EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE - ) - - if insertion_event_id is None: - # Nothing to retrieve then (invalid marker) - return - - already_seen_insertion_event = await self._store.have_seen_event( - marker_event.room_id, insertion_event_id - ) - if already_seen_insertion_event: - # No need to process a marker again if we have already seen the - # insertion event that it was pointing to - return - - logger.debug( - "_handle_marker_event: backfilling insertion event %s", insertion_event_id - ) - - await self._get_events_and_persist( - origin, - marker_event.room_id, - [insertion_event_id], - ) - - insertion_event = await self._store.get_event( - insertion_event_id, allow_none=True - ) - if insertion_event is None: - logger.warning( - "_handle_marker_event: server %s didn't return insertion event %s for marker %s", - origin, - insertion_event_id, - marker_event.event_id, - ) - return - - logger.debug( - "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", - insertion_event, - marker_event, - ) - - await self._store.insert_insertion_extremity( - insertion_event_id, marker_event.room_id - ) - - logger.debug( - "_handle_marker_event: insertion extremity added for %s from marker event %s", - insertion_event, - marker_event, - ) - async def backfill_event_id( - self, destinations: List[str], room_id: str, event_id: str + self, destinations: StrCollection, room_id: str, event_id: str ) -> PulledPduInfo: """Backfill a single event and persist it as a non-outlier which means we also pull in all of the state and auth events necessary for it. @@ -2026,8 +1975,7 @@ class FederationEventHandler: # partial and full state and may not be accurate. return - extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids_list) + extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id) prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: @@ -2394,6 +2342,12 @@ class FederationEventHandler: # TODO retrieve the previous state, and exclude join -> join transitions self._notifier.notify_user_joined_room(event.event_id, event.room_id) + # If this is a server ACL event, clear the cache in the storage controller. + if event.type == EventTypes.ServerACL: + self._state_storage_controller.get_server_acl_for_room.invalidate( + (event.room_id,) + ) + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 3031384d25..472879c964 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -66,14 +66,12 @@ class IdentityHandler: self._3pid_validation_ratelimiter_ip = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) self._3pid_validation_ratelimiter_address = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) async def ratelimit_request_token_requests( diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index b3be7a86f0..c34bd7db95 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.constants import ( AccountDataTypes, @@ -23,7 +23,6 @@ from synapse.api.constants import ( Membership, ) from synapse.api.errors import SynapseError -from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state @@ -33,9 +32,9 @@ from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + JsonMapping, Requester, RoomStreamToken, - StateMap, StreamKeyType, StreamToken, UserID, @@ -193,15 +192,12 @@ class InitialSyncHandler: ) elif event.membership == Membership.LEAVE: room_end_token = RoomStreamToken( - None, - event.stream_ordering, + stream=event.stream_ordering, ) deferred_room_state = run_in_background( self._state_storage_controller.get_state_for_events, [event.event_id], - ).addCallback( - lambda states: cast(StateMap[EventBase], states[event.event_id]) - ) + ).addCallback(lambda states: states[event.event_id]) (messages, token), current_state = await make_deferred_yieldable( gather_results( @@ -458,7 +454,7 @@ class InitialSyncHandler: for s in states ] - async def get_receipts() -> List[JsonDict]: + async def get_receipts() -> List[JsonMapping]: receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) diff --git a/synapse/handlers/jwt.py b/synapse/handlers/jwt.py index 5fddc0e315..740bf9b3c4 100644 --- a/synapse/handlers/jwt.py +++ b/synapse/handlers/jwt.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError -from synapse.api.errors import Codes, LoginError, StoreError, UserDeactivatedError +from synapse.api.errors import Codes, LoginError from synapse.types import JsonDict, UserID if TYPE_CHECKING: @@ -26,7 +26,6 @@ if TYPE_CHECKING: class JwtHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self._main_store = hs.get_datastores().main self.jwt_secret = hs.config.jwt.jwt_secret self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim @@ -34,7 +33,7 @@ class JwtHandler: self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences - async def validate_login(self, login_submission: JsonDict) -> str: + def validate_login(self, login_submission: JsonDict) -> str: """ Authenticates the user for the /login API @@ -103,16 +102,4 @@ class JwtHandler: if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) - user_id = UserID(user, self.hs.hostname).to_string() - - # If the account has been deactivated, do not proceed with the login - # flow. - try: - deactivated = await self._main_store.get_user_deactivated_status(user_id) - except StoreError: - # JWT lazily creates users, so they may not exist in the database yet. - deactivated = False - if deactivated: - raise UserDeactivatedError("This account has been deactivated") - - return user_id + return UserID(user, self.hs.hostname).to_string() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0b61c2272b..d0d4626ed6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -60,7 +61,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( - MutableStateMap, PersistedEventPosition, Requester, RoomAlias, @@ -379,7 +379,7 @@ class MessageHandler: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is not int or event.is_state(): + if type(expiry_ts) is not int or event.is_state(): # noqa: E721 return # _schedule_expiry_for_event won't actually schedule anything if there's already @@ -486,6 +486,7 @@ class EventCreationHandler: self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() + self._worker_lock_handler = hs.get_worker_locks_handler() self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state @@ -560,8 +561,6 @@ class EventCreationHandler: expiry_ms=30 * 60 * 1000, ) - self._msc3970_enabled = hs.config.experimental.msc3970_enabled - async def create_event( self, requester: Requester, @@ -573,7 +572,6 @@ class EventCreationHandler: state_event_ids: Optional[List[str]] = None, require_consent: bool = True, outlier: bool = False, - historical: bool = False, depth: Optional[int] = None, state_map: Optional[StateMap[str]] = None, for_batch: bool = False, @@ -599,7 +597,7 @@ class EventCreationHandler: allow_no_prev_events: Whether to allow this event to be created an empty list of prev_events. Normally this is prohibited just because most events should have a prev_event and we should only use this in special - cases like MSC2716. + cases (previously useful for MSC2716). prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -614,13 +612,10 @@ class EventCreationHandler: If non-None, prev_event_ids must also be provided. state_event_ids: - The full state at a given event. This is used particularly by the MSC2716 - /batch_send endpoint. One use case is with insertion events which float at - the beginning of a historical batch and don't have any `prev_events` to - derive from; we add all of these state events as the explicit state so the - rest of the historical batch can inherit the same state and state_group. - This should normally be left as None, which will cause the auth_event_ids - to be calculated based on the room state at the prev_events. + The full state at a given event. This was previously used particularly + by the MSC2716 /batch_send endpoint. This should normally be left as + None, which will cause the auth_event_ids to be calculated based on the + room state at the prev_events. require_consent: Whether to check if the requester has consented to the privacy policy. @@ -629,10 +624,6 @@ class EventCreationHandler: it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. - historical: Indicates whether the message is being inserted - back in time around some existing events. This is used to skip - a few checks and mark the event as backfilled. - depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -702,13 +693,9 @@ class EventCreationHandler: if require_consent and not is_exempt: await self.assert_accepted_privacy_policy(requester) - # Save the access token ID, the device ID and the transaction ID in the event - # internal metadata. This is useful to determine if we should echo the - # transaction_id in events. + # Save the the device ID and the transaction ID in the event internal metadata. + # This is useful to determine if we should echo the transaction_id in events. # See `synapse.events.utils.EventClientSerializer.serialize_event` - if requester.access_token_id is not None: - builder.internal_metadata.token_id = requester.access_token_id - if requester.device_id is not None: builder.internal_metadata.device_id = requester.device_id @@ -717,8 +704,6 @@ class EventCreationHandler: builder.internal_metadata.outlier = outlier - builder.internal_metadata.historical = historical - event, unpersisted_context = await self.create_new_client_event( builder=builder, requester=requester, @@ -749,7 +734,7 @@ class EventCreationHandler: prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: prev_state_ids = await unpersisted_context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) + StateFilter.from_types([(EventTypes.Member, event.sender)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( @@ -839,13 +824,13 @@ class EventCreationHandler: u = await self.store.get_user_by_id(user_id) assert u is not None - if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): + if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users are not required to consent return - if u["appservice_id"] is not None: + if u.appservice_id is not None: # users registered by an appservice are exempt return - if u["consent_version"] == self.config.consent.user_consent_version: + if u.consent_version == self.config.consent.user_consent_version: return consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) @@ -871,7 +856,7 @@ class EventCreationHandler: return None prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(event.type, None)]) + StateFilter.from_types([(event.type, event.state_key)]) ) prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: @@ -887,14 +872,13 @@ class EventCreationHandler: return prev_event return None - async def get_event_from_transaction( + async def get_event_id_from_transaction( self, requester: Requester, txn_id: str, room_id: str, - ) -> Optional[EventBase]: - """For the given transaction ID and room ID, check if there is a matching event. - If so, fetch it and return it. + ) -> Optional[str]: + """For the given transaction ID and room ID, check if there is a matching event ID. Args: requester: The requester making the request in the context of which we want @@ -903,12 +887,12 @@ class EventCreationHandler: room_id: The room ID. Returns: - An event if one could be found, None otherwise. + An event ID if one could be found, None otherwise. """ + existing_event_id = None - if self._msc3970_enabled and requester.device_id: - # When MSC3970 is enabled, we lookup for events sent by the same device first, - # and fallback to the old behaviour if none were found. + # According to the spec, transactions are scoped to a user's device ID. + if requester.device_id: existing_event_id = ( await self.store.get_event_id_from_transaction_id_and_device_id( room_id, @@ -918,22 +902,33 @@ class EventCreationHandler: ) ) if existing_event_id: - return await self.store.get_event(existing_event_id) + return existing_event_id - # Pre-MSC3970, we looked up for events that were sent by the same session by - # using the access token ID. - if requester.access_token_id: - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_token_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, - ) - ) - if existing_event_id: - return await self.store.get_event(existing_event_id) + return existing_event_id + + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + Returns: + An event if one could be found, None otherwise. + """ + existing_event_id = await self.get_event_id_from_transaction( + requester, txn_id, room_id + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) return None async def create_and_send_nonmember_event( @@ -947,7 +942,6 @@ class EventCreationHandler: txn_id: Optional[str] = None, ignore_shadow_ban: bool = False, outlier: bool = False, - historical: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, int]: """ @@ -961,19 +955,16 @@ class EventCreationHandler: allow_no_prev_events: Whether to allow this event to be created an empty list of prev_events. Normally this is prohibited just because most events should have a prev_event and we should only use this in special - cases like MSC2716. + cases (previously useful for MSC2716). prev_event_ids: The event IDs to use as the prev events. Should normally be left as None to automatically request them from the database. state_event_ids: - The full state at a given event. This is used particularly by the MSC2716 - /batch_send endpoint. One use case is with insertion events which float at - the beginning of a historical batch and don't have any `prev_events` to - derive from; we add all of these state events as the explicit state so the - rest of the historical batch can inherit the same state and state_group. - This should normally be left as None, which will cause the auth_event_ids - to be calculated based on the room state at the prev_events. + The full state at a given event. This was previously used particularly + by the MSC2716 /batch_send endpoint. This should normally be left as + None, which will cause the auth_event_ids to be calculated based on the + room state at the prev_events. ratelimit: Whether to rate limit this send. txn_id: The transaction ID. ignore_shadow_ban: True if shadow-banned users should be allowed to @@ -981,9 +972,6 @@ class EventCreationHandler: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. - historical: Indicates whether the message is being inserted - back in time around some existing events. This is used to skip - a few checks and mark the event as backfilled. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -1028,6 +1016,37 @@ class EventCreationHandler: event.internal_metadata.stream_ordering, ) + async with self._worker_lock_handler.acquire_read_write_lock( + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False + ): + return await self._create_and_send_nonmember_event_locked( + requester=requester, + event_dict=event_dict, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + ratelimit=ratelimit, + txn_id=txn_id, + ignore_shadow_ban=ignore_shadow_ban, + outlier=outlier, + depth=depth, + ) + + async def _create_and_send_nonmember_event_locked( + self, + requester: Requester, + event_dict: dict, + allow_no_prev_events: bool = False, + prev_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, + ratelimit: bool = True, + txn_id: Optional[str] = None, + ignore_shadow_ban: bool = False, + outlier: bool = False, + depth: Optional[int] = None, + ) -> Tuple[EventBase, int]: + room_id = event_dict["room_id"] + # If we don't have any prev event IDs specified then we need to # check that the host is in the room (as otherwise populating the # prev events will fail), at which point we may as well check the @@ -1053,7 +1072,6 @@ class EventCreationHandler: prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, outlier=outlier, - historical=historical, depth=depth, ) context = await unpersisted_context.persist(event) @@ -1145,7 +1163,7 @@ class EventCreationHandler: allow_no_prev_events: Whether to allow this event to be created an empty list of prev_events. Normally this is prohibited just because most events should have a prev_event and we should only use this in special - cases like MSC2716. + cases (previously useful for MSC2716). prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -1158,13 +1176,10 @@ class EventCreationHandler: based on the room state at the prev_events. state_event_ids: - The full state at a given event. This is used particularly by the MSC2716 - /batch_send endpoint. One use case is with insertion events which float at - the beginning of a historical batch and don't have any `prev_events` to - derive from; we add all of these state events as the explicit state so the - rest of the historical batch can inherit the same state and state_group. - This should normally be left as None, which will cause the auth_event_ids - to be calculated based on the room state at the prev_events. + The full state at a given event. This was previously used particularly + by the MSC2716 /batch_send endpoint. This should normally be left as + None, which will cause the auth_event_ids to be calculated based on the + room state at the prev_events. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated @@ -1261,52 +1276,6 @@ class EventCreationHandler: if builder.internal_metadata.outlier: event.internal_metadata.outlier = True context = EventContext.for_outlier(self._storage_controllers) - elif ( - event.type == EventTypes.MSC2716_INSERTION - and state_event_ids - and builder.internal_metadata.is_historical() - ): - # Add explicit state to the insertion event so it has state to derive - # from even though it's floating with no `prev_events`. The rest of - # the batch can derive from this state and state_group. - # - # TODO(faster_joins): figure out how this works, and make sure that the - # old state is complete. - # https://github.com/matrix-org/synapse/issues/13003 - metadata = await self.store.get_metadata_for_events(state_event_ids) - - state_map_for_event: MutableStateMap[str] = {} - for state_id in state_event_ids: - data = metadata.get(state_id) - if data is None: - # We're trying to persist a new historical batch of events - # with the given state, e.g. via - # `RoomBatchSendEventRestServlet`. The state can be inferred - # by Synapse or set directly by the client. - # - # Either way, we should have persisted all the state before - # getting here. - raise Exception( - f"State event {state_id} not found in DB," - " Synapse should have persisted it before using it." - ) - - if data.state_key is None: - raise Exception( - f"Trying to set non-state event {state_id} as state" - ) - - state_map_for_event[(data.event_type, data.state_key)] = state_id - - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 - context = await self.state.calculate_context_info( - event, - state_ids_before_event=state_map_for_event, - partial_state=False, - ) - else: context = await self.state.calculate_context_info(event) @@ -1488,23 +1457,23 @@ class EventCreationHandler: # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - event, context = events_and_context[0] + # + # Note: mypy gets confused if we inline dl and check with twisted#11770. + # Some kind of bug in mypy's deduction? + deferreds = ( + run_in_background( + self._persist_events, + requester=requester, + events_and_context=events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, + ), + run_in_background( + self.cache_joined_hosts_for_events, events_and_context + ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), + ) result, _ = await make_deferred_yieldable( - gather_results( - ( - run_in_background( - self._persist_events, - requester=requester, - events_and_context=events_and_context, - ratelimit=ratelimit, - extra_users=extra_users, - ), - run_in_background( - self.cache_joined_hosts_for_events, events_and_context - ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ), - consumeErrors=True, - ) + gather_results(deferreds, consumeErrors=True) ).addErrback(unwrapFirstError) return result @@ -1633,12 +1602,11 @@ class EventCreationHandler: if state_entry.state_group in self._external_cache_joined_hosts_updates: return - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + joined_hosts = ( + await self._storage_controllers.state.get_joined_hosts( + event.room_id, state_entry + ) ) # Note that the expiry times must be larger than the expiry time in @@ -1758,6 +1726,11 @@ class EventCreationHandler: event.event_id, event.room_id ) + if event.type == EventTypes.ServerACL: + self._storage_controllers.state.get_server_acl_for_room.invalidate( + (event.room_id,) + ) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: @@ -1876,28 +1849,6 @@ class EventCreationHandler: 403, "Redacting server ACL events is not permitted" ) - # Add a little safety stop-gap to prevent people from trying to - # redact MSC2716 related events when they're in a room version - # which does not support it yet. We allow people to use MSC2716 - # events in existing room versions but only from the room - # creator since it does not require any changes to the auth - # rules and in effect, the redaction algorithm . In the - # supported room version, we add the `historical` power level to - # auth the MSC2716 related events and adjust the redaction - # algorthim to keep the `historical` field around (redacting an - # event should only strip fields which don't affect the - # structural protocol level). - is_msc2716_event = ( - original_event.type == EventTypes.MSC2716_INSERTION - or original_event.type == EventTypes.MSC2716_BATCH - or original_event.type == EventTypes.MSC2716_MARKER - ) - if not room_version_obj.msc2716_historical and is_msc2716_event: - raise AuthError( - 403, - "Redacting MSC2716 events is not supported in this room version", - ) - event_types = event_auth.auth_types_for_event(event.room_version, event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types(event_types) @@ -1935,58 +1886,12 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - if event.type == EventTypes.MSC2716_INSERTION: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - create_event = await self.store.get_create_event_for_room(event.room_id) - if not room_version_obj.msc2175_implicit_room_creator: - room_creator = create_event.content.get( - EventContentFields.ROOM_CREATOR - ) - else: - room_creator = create_event.sender - - # Only check an insertion event if the room version - # supports it or the event is from the room creator. - if room_version_obj.msc2716_historical or ( - self.config.experimental.msc2716_enabled - and event.sender == room_creator - ): - next_batch_id = event.content.get( - EventContentFields.MSC2716_NEXT_BATCH_ID - ) - conflicting_insertion_event_id = None - if next_batch_id: - conflicting_insertion_event_id = ( - await self.store.get_insertion_event_id_by_batch_id( - event.room_id, next_batch_id - ) - ) - if conflicting_insertion_event_id is not None: - # The current insertion event that we're processing is invalid - # because an insertion event already exists in the room with the - # same next_batch_id. We can't allow multiple because the batch - # pointing will get weird, e.g. we can't determine which insertion - # event the batch event is pointing to. - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Another insertion event already exists with the same next_batch_id", - errcode=Codes.INVALID_PARAM, - ) - - # Mark any `m.historical` messages as backfilled so they don't appear - # in `/sync` and have the proper decrementing `stream_ordering` as we import - backfilled = False - if event.internal_metadata.is_historical(): - backfilled = True - assert self._storage_controllers.persistence is not None ( persisted_events, max_stream_token, ) = await self._storage_controllers.persistence.persist_events( - events_and_context, backfilled=backfilled + events_and_context, ) events_and_pos = [] @@ -2004,7 +1909,10 @@ class EventCreationHandler: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_as_background_process( - "bump_presence_active_time", self._bump_active_time, requester.user + "bump_presence_active_time", + self._bump_active_time, + requester.user, + requester.device_id, ) async def _notify() -> None: @@ -2041,10 +1949,10 @@ class EventCreationHandler: logger.info("maybe_kick_guest_users %r", current_state) await self.hs.get_room_member_handler().kick_guest_users(current_state) - async def _bump_active_time(self, user: UserID) -> None: + async def _bump_active_time(self, user: UserID, device_id: Optional[str]) -> None: try: presence = self.hs.get_presence_handler() - await presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user, device_id) except Exception: logger.exception("Error bumping presence active time") @@ -2060,7 +1968,10 @@ class EventCreationHandler: ) for room_id in room_ids: - dummy_event_sent = await self._send_dummy_event_for_room(room_id) + async with self._worker_lock_handler.acquire_read_write_lock( + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False + ): + dummy_event_sent = await self._send_dummy_event_for_room(room_id) if not dummy_event_sent: # Did not find a valid user in the room, so remove from future attempts diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index e7e0b5e049..24b68e0301 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1354,7 +1354,7 @@ class OidcProvider: finish_request(request) -class LogoutToken(JWTClaims): +class LogoutToken(JWTClaims): # type: ignore[misc] """ Holds and verify claims of a logout token, as per https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 63b35c8d62..878f267a4e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set - -import attr +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, cast from twisted.python.failure import Failure @@ -23,15 +21,22 @@ from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig -from synapse.handlers.room import ShutdownRoomResponse +from synapse.handlers.room import ShutdownRoomParams, ShutdownRoomResponse +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType +from synapse.types import ( + JsonDict, + JsonMapping, + Requester, + ScheduledTask, + StreamKeyType, + TaskStatus, +) from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock -from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -40,81 +45,23 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How many single event gaps we tolerate returning in a `/messages` response before we +# backfill and try to fill in the history. This is an arbitrarily picked number so feel +# free to tune it in the future. +BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3 -@attr.s(slots=True, auto_attribs=True) -class PurgeStatus: - """Object tracking the status of a purge request - - This class contains information on the progress of a purge request, for - return by get_purge_status. - """ - - STATUS_ACTIVE = 0 - STATUS_COMPLETE = 1 - STATUS_FAILED = 2 - - STATUS_TEXT = { - STATUS_ACTIVE: "active", - STATUS_COMPLETE: "complete", - STATUS_FAILED: "failed", - } - - # Save the error message if an error occurs - error: str = "" - # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}. - status: int = STATUS_ACTIVE +# This is used to avoid purging a room several time at the same moment, +# and also paginating during a purge. Pagination can trigger backfill, +# which would create old events locally, and would potentially clash with the room delete. +PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock" - def asdict(self) -> JsonDict: - ret = {"status": PurgeStatus.STATUS_TEXT[self.status]} - if self.error: - ret["error"] = self.error - return ret +PURGE_HISTORY_ACTION_NAME = "purge_history" -@attr.s(slots=True, auto_attribs=True) -class DeleteStatus: - """Object tracking the status of a delete room request +PURGE_ROOM_ACTION_NAME = "purge_room" - This class contains information on the progress of a delete room request, for - return by get_delete_status. - """ - - STATUS_PURGING = 0 - STATUS_COMPLETE = 1 - STATUS_FAILED = 2 - STATUS_SHUTTING_DOWN = 3 - - STATUS_TEXT = { - STATUS_PURGING: "purging", - STATUS_COMPLETE: "complete", - STATUS_FAILED: "failed", - STATUS_SHUTTING_DOWN: "shutting_down", - } - - # Tracks whether this request has completed. - # One of STATUS_{PURGING,COMPLETE,FAILED,SHUTTING_DOWN}. - status: int = STATUS_PURGING - - # Save the error message if an error occurs - error: str = "" - - # Saves the result of an action to give it back to REST API - shutdown_room: ShutdownRoomResponse = { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } - - def asdict(self) -> JsonDict: - ret = { - "status": DeleteStatus.STATUS_TEXT[self.status], - "shutdown_room": self.shutdown_room, - } - if self.error: - ret["error"] = self.error - return ret +SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME = "shutdown_and_purge_room" class PaginationHandler: @@ -124,9 +71,6 @@ class PaginationHandler: paginating during a purge. """ - # when to remove a completed deletion/purge from the results map - CLEAR_PURGE_AFTER_MS = 1000 * 3600 * 24 # 24 hours - def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() @@ -137,17 +81,12 @@ class PaginationHandler: self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() self._relations_handler = hs.get_relations_handler() + self._worker_locks = hs.get_worker_locks_handler() + self._task_scheduler = hs.get_task_scheduler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. self._purges_in_progress_by_room: Set[str] = set() - # map from purge id to PurgeStatus - self._purges_by_id: Dict[str, PurgeStatus] = {} - # map from purge id to DeleteStatus - self._delete_by_id: Dict[str, DeleteStatus] = {} - # map from room id to delete ids - # Dict[`room_id`, List[`delete_id`]] - self._delete_by_room: Dict[str, List[str]] = {} self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( @@ -160,6 +99,9 @@ class PaginationHandler: self._retention_allowed_lifetime_max = ( hs.config.retention.retention_allowed_lifetime_max ) + self._forgotten_room_retention_period = ( + hs.config.server.forgotten_room_retention_period + ) self._is_master = hs.config.worker.worker_app is None if hs.config.retention.retention_enabled and self._is_master: @@ -176,6 +118,14 @@ class PaginationHandler: job.longest_max_lifetime, ) + self._task_scheduler.register_action( + self._purge_history, PURGE_HISTORY_ACTION_NAME + ) + self._task_scheduler.register_action(self._purge_room, PURGE_ROOM_ACTION_NAME) + self._task_scheduler.register_action( + self._shutdown_and_purge_room, SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME + ) + async def purge_history_for_rooms_in_range( self, min_ms: Optional[int], max_ms: Optional[int] ) -> None: @@ -211,7 +161,7 @@ class PaginationHandler: include_null = False logger.info( - "[purge] Running purge job for %s < max_lifetime <= %s (include NULLs = %s)", + "[purge] Running retention purge job for %s < max_lifetime <= %s (include NULLs = %s)", min_ms, max_ms, include_null, @@ -226,10 +176,10 @@ class PaginationHandler: for room_id, retention_policy in rooms.items(): logger.info("[purge] Attempting to purge messages in room %s", room_id) - if room_id in self._purges_in_progress_by_room: + if len(await self.get_delete_tasks_by_room(room_id, only_active=True)) > 0: logger.warning( - "[purge] not purging room %s as there's an ongoing purge running" - " for this room", + "[purge] not purging room %s for retention as there's an ongoing purge" + " running for this room", room_id, ) continue @@ -282,27 +232,20 @@ class PaginationHandler: (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) - purge_id = random_string(16) - - self._purges_by_id[purge_id] = PurgeStatus() - - logger.info( - "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id) - ) + logger.info("Starting purging events in room %s", room_id) # We want to purge everything, including local events, and to run the purge in # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", - self._purge_history, - purge_id, + PURGE_HISTORY_ACTION_NAME, + self.purge_history, room_id, token, True, ) - def start_purge_history( + async def start_purge_history( self, room_id: str, token: str, delete_local_events: bool = False ) -> str: """Start off a history purge on a room. @@ -316,106 +259,150 @@ class PaginationHandler: Returns: unique ID for this purge transaction. """ - if room_id in self._purges_in_progress_by_room: - raise SynapseError( - 400, "History purge already in progress for %s" % (room_id,) - ) - - purge_id = random_string(16) + purge_id = await self._task_scheduler.schedule_task( + PURGE_HISTORY_ACTION_NAME, + resource_id=room_id, + params={"token": token, "delete_local_events": delete_local_events}, + ) # we log the purge_id here so that it can be tied back to the # request id in the log lines. logger.info("[purge] starting purge_id %s", purge_id) - self._purges_by_id[purge_id] = PurgeStatus() - run_as_background_process( - "purge_history", - self._purge_history, - purge_id, - room_id, - token, - delete_local_events, - ) return purge_id async def _purge_history( - self, purge_id: str, room_id: str, token: str, delete_local_events: bool - ) -> None: + self, + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + """ + Scheduler action to purge some history of a room. + """ + if ( + task.resource_id is None + or task.params is None + or "token" not in task.params + or "delete_local_events" not in task.params + ): + return ( + TaskStatus.FAILED, + None, + "Not enough parameters passed to _purge_history", + ) + err = await self.purge_history( + task.resource_id, + task.params["token"], + task.params["delete_local_events"], + ) + if err is not None: + return TaskStatus.FAILED, None, err + return TaskStatus.COMPLETE, None, None + + async def purge_history( + self, + room_id: str, + token: str, + delete_local_events: bool, + ) -> Optional[str]: """Carry out a history purge on a room. Args: - purge_id: The ID for this purge. room_id: The room to purge from token: topological token to delete events before delete_local_events: True to delete local events as well as remote ones """ - self._purges_in_progress_by_room.add(room_id) try: - async with self.pagination_lock.write(room_id): + async with self._worker_locks.acquire_read_write_lock( + PURGE_PAGINATION_LOCK_NAME, room_id, write=True + ): await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") - self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE + return None except Exception: f = Failure() logger.error( - "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore - ) - self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED - self._purges_by_id[purge_id].error = f.getErrorMessage() - finally: - self._purges_in_progress_by_room.discard(room_id) - - # remove the purge from the list 24 hours after it completes - def clear_purge() -> None: - del self._purges_by_id[purge_id] - - self.hs.get_reactor().callLater( - PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_purge + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) ) + return f.getErrorMessage() - def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]: - """Get the current status of an active purge - - Args: - purge_id: purge_id returned by start_purge_history - """ - return self._purges_by_id.get(purge_id) - - def get_delete_status(self, delete_id: str) -> Optional[DeleteStatus]: + async def get_delete_task(self, delete_id: str) -> Optional[ScheduledTask]: """Get the current status of an active deleting Args: delete_id: delete_id returned by start_shutdown_and_purge_room + or start_purge_history. """ - return self._delete_by_id.get(delete_id) + return await self._task_scheduler.get_task(delete_id) - def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]: - """Get all active delete ids by room + async def get_delete_tasks_by_room( + self, room_id: str, only_active: Optional[bool] = False + ) -> List[ScheduledTask]: + """Get complete, failed or active delete tasks by room Args: room_id: room_id that is deleted + only_active: if True, completed&failed tasks will be omitted """ - return self._delete_by_room.get(room_id) + statuses = [TaskStatus.ACTIVE] + if not only_active: + statuses += [TaskStatus.COMPLETE, TaskStatus.FAILED] + + return await self._task_scheduler.get_tasks( + actions=[PURGE_ROOM_ACTION_NAME, SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME], + resource_id=room_id, + statuses=statuses, + ) - async def purge_room(self, room_id: str, force: bool = False) -> None: + async def _purge_room( + self, + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + """ + Scheduler action to purge a room. + """ + if not task.resource_id: + raise Exception("No room id passed to purge_room task") + params = task.params if task.params else {} + await self.purge_room(task.resource_id, params.get("force", False)) + return TaskStatus.COMPLETE, None, None + + async def purge_room( + self, + room_id: str, + force: bool, + ) -> None: """Purge the given room from the database. - This function is part the delete room v1 API. Args: room_id: room to be purged force: set true to skip checking for joined users. """ - async with self.pagination_lock.write(room_id): + logger.info("starting purge room_id=%s force=%s", room_id, force) + + async with self._worker_locks.acquire_multi_read_write_lock( + [ + (PURGE_PAGINATION_LOCK_NAME, room_id), + (NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id), + ], + write=True, + ): # first check that we have no users in this room - if not force: - joined = await self.store.is_host_joined(room_id, self._server_name) - if joined: + joined = await self.store.is_host_joined(room_id, self._server_name) + if joined: + if force: + logger.info( + "force-purging room %s with some local users still joined", + room_id, + ) + else: raise SynapseError(400, "Users are still joined to this room") await self._storage_controllers.purge_events.purge_room(room_id) + logger.info("purge complete for room_id %s", room_id) + @trace async def get_messages( self, @@ -466,65 +453,150 @@ class PaginationHandler: room_token = from_token.room_key - async with self.pagination_lock.read(room_id): - (membership, member_event_id) = (None, None) - if not use_admin_priviledge: - ( - membership, - member_event_id, - ) = await self.auth.check_user_in_room_or_world_readable( - room_id, requester, allow_departed_users=True - ) - - if pagin_config.direction == Direction.BACKWARDS: - # if we're going backwards, we might need to backfill. This - # requires that we have a topo token. - if room_token.topological: - curr_topo = room_token.topological - else: - curr_topo = await self.store.get_current_topological_token( - room_id, room_token.stream - ) + (membership, member_event_id) = (None, None) + if not use_admin_priviledge: + ( + membership, + member_event_id, + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) - if not use_admin_priviledge and membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. + if pagin_config.direction == Direction.BACKWARDS: + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological: + curr_topo = room_token.topological + else: + curr_topo = await self.store.get_current_topological_token( + room_id, room_token.stream + ) - # This is only None if the room is world_readable, in which - # case "JOIN" would have been returned. - assert member_event_id + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + if ( + pagin_config.direction == Direction.BACKWARDS + and not use_admin_priviledge + and membership == Membership.LEAVE + ): + # This is only None if the room is world_readable, in which case + # "Membership.JOIN" would have been returned and we should never hit + # this branch. + assert member_event_id + + leave_token = await self.store.get_topological_token_for_event( + member_event_id + ) + assert leave_token.topological is not None - leave_token = await self.store.get_topological_token_for_event( - member_event_id - ) - assert leave_token.topological is not None + if leave_token.topological < curr_topo: + from_token = from_token.copy_and_replace( + StreamKeyType.ROOM, leave_token + ) - if leave_token.topological < curr_topo: - from_token = from_token.copy_and_replace( - StreamKeyType.ROOM, leave_token - ) + to_room_key = None + if pagin_config.to_token: + to_room_key = pagin_config.to_token.room_key + + # Initially fetch the events from the database. With any luck, we can return + # these without blocking on backfill (handled below). + events, next_key = await self.store.paginate_room_events( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, + ) - await self.hs.get_federation_handler().maybe_backfill( + if pagin_config.direction == Direction.BACKWARDS: + # We use a `Set` because there can be multiple events at a given depth + # and we only care about looking at the unique continum of depths to + # find gaps. + event_depths: Set[int] = {event.depth for event in events} + sorted_event_depths = sorted(event_depths) + + # Inspect the depths of the returned events to see if there are any gaps + found_big_gap = False + number_of_gaps = 0 + previous_event_depth = ( + sorted_event_depths[0] if len(sorted_event_depths) > 0 else 0 + ) + for event_depth in sorted_event_depths: + # We don't expect a negative depth but we'll just deal with it in + # any case by taking the absolute value to get the true gap between + # any two integers. + depth_gap = abs(event_depth - previous_event_depth) + # A `depth_gap` of 1 is a normal continuous chain to the next event + # (1 <-- 2 <-- 3) so anything larger indicates a missing event (it's + # also possible there is no event at a given depth but we can't ever + # know that for sure) + if depth_gap > 1: + number_of_gaps += 1 + + # We only tolerate a small number single-event long gaps in the + # returned events because those are most likely just events we've + # failed to pull in the past. Anything longer than that is probably + # a sign that we're missing a decent chunk of history and we should + # try to backfill it. + # + # XXX: It's possible we could tolerate longer gaps if we checked + # that a given events `prev_events` is one that has failed pull + # attempts and we could just treat it like a dead branch of history + # for now or at least something that we don't need the block the + # client on to try pulling. + # + # XXX: If we had something like MSC3871 to indicate gaps in the + # timeline to the client, we could also get away with any sized gap + # and just have the client refetch the holes as they see fit. + if depth_gap > 2: + found_big_gap = True + break + previous_event_depth = event_depth + + # Backfill in the foreground if we found a big gap, have too many holes, + # or we don't have enough events to fill the limit that the client asked + # for. + missing_too_many_events = ( + number_of_gaps > BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD + ) + not_enough_events_to_fill_response = len(events) < pagin_config.limit + if ( + found_big_gap + or missing_too_many_events + or not_enough_events_to_fill_response + ): + did_backfill = await self.hs.get_federation_handler().maybe_backfill( room_id, curr_topo, limit=pagin_config.limit, ) - to_room_key = None - if pagin_config.to_token: - to_room_key = pagin_config.to_token.room_key - - events, next_key = await self.store.paginate_room_events( - room_id=room_id, - from_key=from_token.room_key, - to_key=to_room_key, - direction=pagin_config.direction, - limit=pagin_config.limit, - event_filter=event_filter, - ) + # If we did backfill something, refetch the events from the database to + # catch anything new that might have been added since we last fetched. + if did_backfill: + events, next_key = await self.store.paginate_room_events( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, + ) + else: + # Otherwise, we can backfill in the background for eventual + # consistency's sake but we don't need to block the client waiting + # for a costly federation call and processing. + run_as_background_process( + "maybe_backfill_in_the_background", + self.hs.get_federation_handler().maybe_backfill, + room_id, + curr_topo, + limit=pagin_config.limit, + ) - next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) + next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) # if no events are returned from pagination, that implies # we have reached the end of the available events. @@ -605,167 +677,72 @@ class PaginationHandler: async def _shutdown_and_purge_room( self, - delete_id: str, - room_id: str, - requester_user_id: str, - new_room_user_id: Optional[str] = None, - new_room_name: Optional[str] = None, - message: Optional[str] = None, - block: bool = False, - purge: bool = True, - force_purge: bool = False, - ) -> None: + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: """ - Shuts down and purges a room. - - See `RoomShutdownHandler.shutdown_room` for details of creation of the new room - - Args: - delete_id: The ID for this delete. - room_id: The ID of the room to shut down. - requester_user_id: - User who requested the action. Will be recorded as putting the room on the - blocking list. - new_room_user_id: - If set, a new room will be created with this user ID - as the creator and admin, and all users in the old room will be - moved into that room. If not set, no new room will be created - and the users will just be removed from the old room. - new_room_name: - A string representing the name of the room that new users will - be invited to. Defaults to `Content Violation Notification` - message: - A string containing the first message that will be sent as - `new_room_user_id` in the new room. Ideally this will clearly - convey why the original room was shut down. - Defaults to `Sharing illegal content on this server is not - permitted and rooms in violation will be blocked.` - block: - If set to `true`, this room will be added to a blocking list, - preventing future attempts to join the room. Defaults to `false`. - purge: - If set to `true`, purge the given room from the database. - force_purge: - If set to `true`, the room will be purged from database - also if it fails to remove some users from room. - - Saves a `RoomShutdownHandler.ShutdownRoomResponse` in `DeleteStatus`: + Scheduler action to shutdown and purge a room. """ + if task.resource_id is None or task.params is None: + raise Exception( + "No room id and/or no parameters passed to shutdown_and_purge_room task" + ) - self._purges_in_progress_by_room.add(room_id) - try: - async with self.pagination_lock.write(room_id): - self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN - self._delete_by_id[ - delete_id - ].shutdown_room = await self._room_shutdown_handler.shutdown_room( - room_id=room_id, - requester_user_id=requester_user_id, - new_room_user_id=new_room_user_id, - new_room_name=new_room_name, - message=message, - block=block, - ) - self._delete_by_id[delete_id].status = DeleteStatus.STATUS_PURGING + room_id = task.resource_id - if purge: - logger.info("starting purge room_id %s", room_id) + async def update_result(result: Optional[JsonMapping]) -> None: + await self._task_scheduler.update_task(task.id, result=result) - # first check that we have no users in this room - if not force_purge: - joined = await self.store.is_host_joined( - room_id, self._server_name - ) - if joined: - raise SynapseError( - 400, "Users are still joined to this room" - ) + shutdown_result = ( + cast(ShutdownRoomResponse, task.result) if task.result else None + ) - await self._storage_controllers.purge_events.purge_room(room_id) + shutdown_result = await self._room_shutdown_handler.shutdown_room( + room_id, + cast(ShutdownRoomParams, task.params), + shutdown_result, + update_result, + ) - logger.info("purge complete for room_id %s", room_id) - self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE - except Exception: - f = Failure() - logger.error( - "failed", - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore - ) - self._delete_by_id[delete_id].status = DeleteStatus.STATUS_FAILED - self._delete_by_id[delete_id].error = f.getErrorMessage() - finally: - self._purges_in_progress_by_room.discard(room_id) - - # remove the delete from the list 24 hours after it completes - def clear_delete() -> None: - del self._delete_by_id[delete_id] - self._delete_by_room[room_id].remove(delete_id) - if not self._delete_by_room[room_id]: - del self._delete_by_room[room_id] - - self.hs.get_reactor().callLater( - PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000, clear_delete + if task.params.get("purge", False): + await self.purge_room( + room_id, + task.params.get("force_purge", False), ) - def start_shutdown_and_purge_room( + return (TaskStatus.COMPLETE, shutdown_result, None) + + async def start_shutdown_and_purge_room( self, room_id: str, - requester_user_id: str, - new_room_user_id: Optional[str] = None, - new_room_name: Optional[str] = None, - message: Optional[str] = None, - block: bool = False, - purge: bool = True, - force_purge: bool = False, + shutdown_params: ShutdownRoomParams, ) -> str: """Start off shut down and purge on a room. Args: room_id: The ID of the room to shut down. - requester_user_id: - User who requested the action and put the room on the - blocking list. - new_room_user_id: - If set, a new room will be created with this user ID - as the creator and admin, and all users in the old room will be - moved into that room. If not set, no new room will be created - and the users will just be removed from the old room. - new_room_name: - A string representing the name of the room that new users will - be invited to. Defaults to `Content Violation Notification` - message: - A string containing the first message that will be sent as - `new_room_user_id` in the new room. Ideally this will clearly - convey why the original room was shut down. - Defaults to `Sharing illegal content on this server is not - permitted and rooms in violation will be blocked.` - block: - If set to `true`, this room will be added to a blocking list, - preventing future attempts to join the room. Defaults to `false`. - purge: - If set to `true`, purge the given room from the database. - force_purge: - If set to `true`, the room will be purged from database - also if it fails to remove some users from room. + shutdown_params: parameters for the shutdown Returns: unique ID for this delete transaction. """ - if room_id in self._purges_in_progress_by_room: - raise SynapseError( - 400, "History purge already in progress for %s" % (room_id,) - ) + if len(await self.get_delete_tasks_by_room(room_id, only_active=True)) > 0: + raise SynapseError(400, "Purge already in progress for %s" % (room_id,)) # This check is double to `RoomShutdownHandler.shutdown_room` # But here the requester get a direct response / error with HTTP request # and do not have to check the purge status + new_room_user_id = shutdown_params["new_room_user_id"] if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): raise SynapseError( 400, "User must be our own: %s" % (new_room_user_id,) ) - delete_id = random_string(16) + delete_id = await self._task_scheduler.schedule_task( + SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME, + resource_id=room_id, + params=shutdown_params, + ) # we log the delete_id here so that it can be tied back to the # request id in the log lines. @@ -775,19 +752,4 @@ class PaginationHandler: delete_id, ) - self._delete_by_id[delete_id] = DeleteStatus() - self._delete_by_room.setdefault(room_id, []).append(delete_id) - run_as_background_process( - "shutdown_and_purge_room", - self._shutdown_and_purge_room, - delete_id, - room_id, - requester_user_id, - new_room_user_id, - new_room_name, - message, - block, - purge, - force_purge, - ) return delete_id diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4ad2233573..7c7cda3e95 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -13,26 +13,71 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module is responsible for keeping track of presence status of local +""" +This module is responsible for keeping track of presence status of local and remote users. The methods that define policy are: - PresenceHandler._update_states - PresenceHandler._handle_timeouts - should_notify + +# Tracking local presence + +For local users, presence is tracked on a per-device basis. When a user has multiple +devices the user presence state is derived by coalescing the presence from each +device: + + BUSY > ONLINE > UNAVAILABLE > OFFLINE + +The time that each device was last active and last synced is tracked in order to +automatically downgrade a device's presence state: + + A device may move from ONLINE -> UNAVAILABLE, if it has not been active for + a period of time. + + A device may go from any state -> OFFLINE, if it is not active and has not + synced for a period of time. + +The timeouts are handled using a wheel timer, which has coarse buckets. Timings +do not need to be exact. + +Generally a device's presence state is updated whenever a user syncs (via the +set_presence parameter), when the presence API is called, or if "pro-active" +events occur, including: + +* Sending an event, receipt, read marker. +* Updating typing status. + +The busy state has special status that it cannot is not downgraded by a call to +sync with a lower priority state *and* it takes a long period of time to transition +to offline. + +# Persisting (and restoring) presence + +For all users, presence is persisted on a per-user basis. Data is kept in-memory +and persisted periodically. When Synapse starts each worker loads the current +presence state and then tracks the presence stream to keep itself up-to-date. + +When restoring presence for local users a pseudo-device is created to match the +user state; this device follows the normal timeout logic (see above) and will +automatically be replaced with any information from currently available devices. + """ import abc import contextlib +import itertools import logging from bisect import bisect from contextlib import contextmanager from types import TracebackType from typing import ( TYPE_CHECKING, + AbstractSet, Any, - Awaitable, Callable, Collection, + ContextManager, Dict, Generator, Iterable, @@ -44,17 +89,19 @@ from typing import ( ) from prometheus_client import Counter -from typing_extensions import ContextManager import synapse.metrics from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError -from synapse.api.presence import UserPresenceState +from synapse.api.presence import UserDevicePresenceState, UserPresenceState from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.replication.http.presence import ( ReplicationBumpPresenceActiveTime, ReplicationPresenceSetState, @@ -95,13 +142,12 @@ bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["reason"] + "synapse_handler_presence_notify_reason", "", ["locality", "reason"] ) state_transition_counter = Counter( - "synapse_handler_presence_state_transition", "", ["from", "to"] + "synapse_handler_presence_state_transition", "", ["locality", "from", "to"] ) - # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" LAST_ACTIVE_GRANULARITY = 60 * 1000 @@ -109,6 +155,8 @@ LAST_ACTIVE_GRANULARITY = 60 * 1000 # How long to wait until a new /events or /sync request before assuming # the client has gone. SYNC_ONLINE_TIMEOUT = 30 * 1000 +# Busy status waits longer, but does eventually go offline. +BUSY_ONLINE_TIMEOUT = 60 * 60 * 1000 # How long to wait before marking the user as idle. Compared against last active IDLE_TIMER = 5 * 60 * 1000 @@ -135,6 +183,7 @@ class BasePresenceHandler(abc.ABC): writer""" def __init__(self, hs: "HomeServer"): + self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -142,20 +191,34 @@ class BasePresenceHandler(abc.ABC): self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id + self._presence_enabled = hs.config.server.use_presence + self._federation = None if hs.should_send_federation(): self._federation = hs.get_federation_sender() self._federation_queue = PresenceFederationQueue(hs, self) - self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + self.VALID_PRESENCE: Tuple[str, ...] = ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ) + + if hs.config.experimental.msc3026_enabled: + self.VALID_PRESENCE += (PresenceState.BUSY,) active_presence = self.store.take_presence_startup_info() + # The combined status across all user devices. self.user_to_current_state = {state.user_id: state for state in active_presence} @abc.abstractmethod async def user_syncing( - self, user_id: str, affect_presence: bool, presence_state: str + self, + user_id: str, + device_id: Optional[str], + affect_presence: bool, + presence_state: str, ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -166,6 +229,7 @@ class BasePresenceHandler(abc.ABC): Args: user_id: the user that is starting a sync + device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. @@ -173,15 +237,17 @@ class BasePresenceHandler(abc.ABC): """ @abc.abstractmethod - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: - """Get an iterable of syncing users on this worker, to send to the presence handler + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get an iterable of syncing users and devices on this worker, to send to the presence handler This is called when a replication connection is established. It should return - a list of user ids, which are then sent as USER_SYNC commands to inform the - process handling presence about those users. + a list of tuples of user ID & device ID, which are then sent as USER_SYNC commands + to inform the process handling presence about those users/devices. Returns: - An iterable of user_id strings. + An iterable of tuples of user ID and device ID. """ async def get_state(self, target_user: UserID) -> UserPresenceState: @@ -242,28 +308,39 @@ class BasePresenceHandler(abc.ABC): async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ @abc.abstractmethod - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ async def update_external_syncs_row( # noqa: B027 (no-op by design) - self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + self, + process_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + sync_time_msec: int, ) -> None: """Update the syncing users for an external process as a delta. @@ -274,6 +351,7 @@ class BasePresenceHandler(abc.ABC): syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing + device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -323,8 +401,10 @@ class BasePresenceHandler(abc.ABC): states, ) - for destination, host_states in hosts_to_states.items(): - self._federation.send_presence_to_destinations(host_states, [destination]) + for destinations, host_states in hosts_to_states: + await self._federation.send_presence_to_destinations( + host_states, destinations + ) async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ @@ -369,7 +449,9 @@ class BasePresenceHandler(abc.ABC): # We set force_notify=True here so that this presence update is guaranteed to # increment the presence stream ID (which resending the current user's presence # otherwise would not do). - await self.set_state(UserID.from_string(user_id), state, force_notify=True) + await self.set_state( + UserID.from_string(user_id), None, state, force_notify=True + ) async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: raise NotImplementedError( @@ -392,28 +474,26 @@ class _NullContextManager(ContextManager[None]): class WorkerPresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs - self._presence_writer_instance = hs.config.worker.writers.presence[0] - self._presence_enabled = hs.config.server.use_presence - # Route presence EDUs to the right worker hs.get_federation_registry().register_instances_for_edu( EduTypes.PRESENCE, hs.config.worker.writers.presence, ) - # The number of ongoing syncs on this process, by user id. + # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_to_num_current_syncs: Dict[str, int] = {} + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # user_id -> last_sync_ms. Lists the users that have stopped syncing but - # we haven't notified the presence writer of that yet - self.users_going_offline: Dict[str, int] = {} + # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped + # syncing but we haven't notified the presence writer of that yet + self._user_devices_going_offline: Dict[Tuple[str, Optional[str]], int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -422,8 +502,6 @@ class WorkerPresenceHandler(BasePresenceHandler): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self._busy_presence_enabled = hs.config.experimental.msc3026_enabled - hs.get_reactor().addSystemEventTrigger( "before", "shutdown", @@ -438,42 +516,54 @@ class WorkerPresenceHandler(BasePresenceHandler): ClearUserSyncsCommand(self.instance_id) ) - def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: + def send_user_sync( + self, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, + ) -> None: if self._presence_enabled: self.hs.get_replication_command_handler().send_user_sync( - self.instance_id, user_id, is_syncing, last_sync_ms + self.instance_id, user_id, device_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id: str) -> None: + def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None: """A user has started syncing. Send a UserSync to the presence writer, unless they had recently stopped syncing. """ - going_offline = self.users_going_offline.pop(user_id, None) + going_offline = self._user_devices_going_offline.pop((user_id, device_id), None) if not going_offline: # Safe to skip because we haven't yet told the presence writer they # were offline - self.send_user_sync(user_id, True, self.clock.time_msec()) + self.send_user_sync(user_id, device_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id: str) -> None: + def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None: """A user has stopped syncing. We wait before notifying the presence writer as its likely they'll come back soon. This allows us to avoid sending a stopped syncing immediately followed by a started syncing notification to the presence writer """ - self.users_going_offline[user_id] = self.clock.time_msec() + self._user_devices_going_offline[(user_id, device_id)] = self.clock.time_msec() def send_stop_syncing(self) -> None: """Check if there are any users who have stopped syncing a while ago and haven't come back yet. If there are poke the presence writer about them. """ now = self.clock.time_msec() - for user_id, last_sync_ms in list(self.users_going_offline.items()): + for (user_id, device_id), last_sync_ms in list( + self._user_devices_going_offline.items() + ): if now - last_sync_ms > UPDATE_SYNCING_USERS_MS: - self.users_going_offline.pop(user_id, None) - self.send_user_sync(user_id, False, last_sync_ms) + self._user_devices_going_offline.pop((user_id, device_id), None) + self.send_user_sync(user_id, device_id, False, last_sync_ms) async def user_syncing( - self, user_id: str, affect_presence: bool, presence_state: str + self, + user_id: str, + device_id: Optional[str], + affect_presence: bool, + presence_state: str, ) -> ContextManager[None]: """Record that a user is syncing. @@ -483,34 +573,32 @@ class WorkerPresenceHandler(BasePresenceHandler): if not affect_presence or not self._presence_enabled: return _NullContextManager() - prev_state = await self.current_state_for_user(user_id) - if prev_state.state != PresenceState.BUSY: - # We set state here but pass ignore_status_msg = True as we don't want to - # cause the status message to be cleared. - # Note that this causes last_active_ts to be incremented which is not - # what the spec wants: see comment in the BasePresenceHandler version - # of this function. - await self.set_state( - UserID.from_string(user_id), {"presence": presence_state}, True - ) + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants. + await self.set_state( + UserID.from_string(user_id), + device_id, + state={"presence": presence_state}, + is_sync=True, + ) - curr_sync = self._user_to_num_current_syncs.get(user_id, 0) - self._user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) + self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 - # If we went from no in flight sync to some, notify replication - if self._user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) + # If this is the first in-flight sync, notify replication + if self._user_device_to_num_current_syncs[(user_id, device_id)] == 1: + self.mark_as_coming_online(user_id, device_id) def _end() -> None: # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if user_id in self._user_to_num_current_syncs: - self._user_to_num_current_syncs[user_id] -= 1 + if (user_id, device_id) in self._user_device_to_num_current_syncs: + self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 - # If we went from one in flight sync to non, notify replication - if self._user_to_num_current_syncs[user_id] == 0: - self.mark_as_going_offline(user_id) + # If there are no more in-flight syncs, notify replication + if self._user_device_to_num_current_syncs[(user_id, device_id)] == 0: + self.mark_as_going_offline(user_id, device_id) @contextlib.contextmanager def _user_syncing() -> Generator[None, None, None]: @@ -567,8 +655,8 @@ class WorkerPresenceHandler(BasePresenceHandler): for new_state in states: old_state = self.user_to_current_state.get(new_state.user_id) self.user_to_current_state[new_state.user_id] = new_state - - if not old_state or should_notify(old_state, new_state): + is_mine = self.is_mine_id(new_state.user_id) + if not old_state or should_notify(old_state, new_state, is_mine): state_to_notify.append(new_state) stream_id = token @@ -577,81 +665,80 @@ class WorkerPresenceHandler(BasePresenceHandler): # If this is a federation sender, notify about presence updates. await self.maybe_send_presence_to_interested_destinations(state_to_notify) - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: return [ - user_id - for user_id, count in self._user_to_num_current_syncs.items() + user_id_device_id + for user_id_device_id, count in self._user_device_to_num_current_syncs.items() if count > 0 ] async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ presence = state["presence"] - valid_presence = ( - PresenceState.ONLINE, - PresenceState.UNAVAILABLE, - PresenceState.OFFLINE, - PresenceState.BUSY, - ) - - if presence not in valid_presence or ( - presence == PresenceState.BUSY and not self._busy_presence_enabled - ): + if presence not in self.VALID_PRESENCE: raise SynapseError(400, "Invalid presence state") user_id = target_user.to_string() # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return # Proxy request to instance that writes presence await self._set_state_client( instance_name=self._presence_writer_instance, user_id=user_id, + device_id=device_id, state=state, - ignore_status_msg=ignore_status_msg, force_notify=force_notify, + is_sync=is_sync, ) - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return # Proxy request to instance that writes presence user_id = user.to_string() await self._bump_active_client( - instance_name=self._presence_writer_instance, user_id=user_id + instance_name=self._presence_writer_instance, + user_id=user_id, + device_id=device_id, ) class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs - self.server_name = hs.hostname self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() - self._presence_enabled = hs.config.server.use_presence federation_registry = hs.get_federation_registry() @@ -666,9 +753,27 @@ class PresenceHandler(BasePresenceHandler): lambda: len(self.user_to_current_state), ) + # The per-device presence state, maps user to devices to per-device presence state. + self._user_to_device_to_current_state: Dict[ + str, Dict[Optional[str], UserDevicePresenceState] + ] = {} + now = self.clock.time_msec() if self._presence_enabled: for state in self.user_to_current_state.values(): + # Create a psuedo-device to properly handle time outs. This will + # be overridden by any "real" devices within SYNC_ONLINE_TIMEOUT. + pseudo_device_id = None + self._user_to_device_to_current_state[state.user_id] = { + pseudo_device_id: UserDevicePresenceState( + user_id=state.user_id, + device_id=pseudo_device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) + } + self.wheel_timer.insert( now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) @@ -702,21 +807,25 @@ class PresenceHandler(BasePresenceHandler): self._on_shutdown, ) - self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs: Dict[str, int] = {} + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} # Keeps track of the number of *ongoing* syncs on other processes. - # While any sync is ongoing on another process the user will never + # + # While any sync is ongoing on another process the user's device will never # go offline. + # # Each process has a unique identifier and an update frequency. If # no update is received from that process within the update period then # we assume that all the sync requests on that process have stopped. - # Stored as a dict from process_id to set of user_id, and a dict of - # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + # Stored as a dict from process_id to set of (user_id, device_id), and + # a dict of process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs: Dict[ + str, Set[Tuple[str, Optional[str]]] + ] = {} self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -725,21 +834,16 @@ class PresenceHandler(BasePresenceHandler): # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. - def run_timeout_handler() -> Awaitable[None]: - return run_as_background_process( - "handle_presence_timeouts", self._handle_timeouts - ) - self.clock.call_later( - 30, self.clock.looping_call, run_timeout_handler, 5000 + 30, self.clock.looping_call, self._handle_timeouts, 5000 ) - def run_persister() -> Awaitable[None]: - return run_as_background_process( - "persist_presence_changes", self._persist_unpersisted_changes - ) - - self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000) + self.clock.call_later( + 60, + self.clock.looping_call, + self._persist_unpersisted_changes, + 60 * 1000, + ) LaterGauge( "synapse_handlers_presence_wheel_timer_size", @@ -785,6 +889,7 @@ class PresenceHandler(BasePresenceHandler): ) logger.info("Finished _on_shutdown") + @wrap_as_background_process("persist_presence_changes") async def _persist_unpersisted_changes(self) -> None: """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. @@ -895,11 +1000,12 @@ class PresenceHandler(BasePresenceHandler): list(to_federation_ping.values()), ) - for destination, states in hosts_to_states.items(): - self._federation_queue.send_presence_to_destinations( - states, [destination] + for destinations, states in hosts_to_states: + await self._federation_queue.send_presence_to_destinations( + states, destinations ) + @wrap_as_background_process("handle_presence_timeouts") async def _handle_timeouts(self) -> None: """Checks the presence of users that have timed out and updates as appropriate. @@ -924,7 +1030,10 @@ class PresenceHandler(BasePresenceHandler): # that were syncing on that process to see if they need to be timed # out. users_to_check.update( - self.external_process_to_current_syncs.pop(process_id, ()) + user_id + for user_id, device_id in self.external_process_to_current_syncs.pop( + process_id, () + ) ) self.external_process_last_updated_ms.pop(process_id) @@ -935,46 +1044,67 @@ class PresenceHandler(BasePresenceHandler): timers_fired_counter.inc(len(states)) - syncing_user_ids = { - user_id - for user_id, count in self.user_to_num_current_syncs.items() + # Set of user ID & device IDs which are currently syncing. + syncing_user_devices = { + user_id_device_id + for user_id_device_id, count in self._user_device_to_num_current_syncs.items() if count } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) + syncing_user_devices.update( + itertools.chain(*self.external_process_to_current_syncs.values()) + ) changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - syncing_user_ids=syncing_user_ids, + syncing_user_devices=syncing_user_devices, + user_to_devices=self._user_to_device_to_current_state, now=now, ) return await self._update_states(changes) - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return user_id = user.to_string() bump_active_time_counter.inc() - prev_state = await self.current_state_for_user(user_id) + now = self.clock.time_msec() + + # Update the device information & mark the device as online if it was + # unavailable. + devices = self._user_to_device_to_current_state.setdefault(user_id, {}) + device_state = devices.setdefault( + device_id, + UserDevicePresenceState.default(user_id, device_id), + ) + device_state.last_active_ts = now + if device_state.state == PresenceState.UNAVAILABLE: + device_state.state = PresenceState.ONLINE - new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()} - if prev_state.state == PresenceState.UNAVAILABLE: - new_fields["state"] = PresenceState.ONLINE + # Update the user state, this will always update last_active_ts and + # might update the presence state. + prev_state = await self.current_state_for_user(user_id) + new_fields: Dict[str, Any] = { + "last_active_ts": now, + "state": _combine_device_states(devices.values()), + } await self._update_states([prev_state.copy_and_replace(**new_fields)]) async def user_syncing( self, user_id: str, + device_id: Optional[str], affect_presence: bool = True, presence_state: str = PresenceState.ONLINE, ) -> ContextManager[None]: @@ -986,66 +1116,31 @@ class PresenceHandler(BasePresenceHandler): when users disconnect/reconnect. Args: - user_id + user_id: the user that is starting a sync + device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. presence_state: The presence state indicated in the sync request """ - # Override if it should affect the user's presence, if presence is - # disabled. - if not self.hs.config.server.use_presence: - affect_presence = False - - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 - - prev_state = await self.current_state_for_user(user_id) + if not affect_presence or not self._presence_enabled: + return _NullContextManager() - # If they're busy then they don't stop being busy just by syncing, - # so just update the last sync time. - if prev_state.state != PresenceState.BUSY: - # XXX: We set_state separately here and just update the last_active_ts above - # This keeps the logic as similar as possible between the worker and single - # process modes. Using set_state will actually cause last_active_ts to be - # updated always, which is not what the spec calls for, but synapse has done - # this for... forever, I think. - await self.set_state( - UserID.from_string(user_id), {"presence": presence_state}, True - ) - # Retrieve the new state for the logic below. This should come from the - # in-memory cache. - prev_state = await self.current_state_for_user(user_id) + curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) + self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 - # To keep the single process behaviour consistent with worker mode, run the - # same logic as `update_external_syncs_row`, even though it looks weird. - if prev_state.state == PresenceState.OFFLINE: - await self._update_states( - [ - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - ) - ] - ) - # otherwise, set the new presence state & update the last sync time, - # but don't update last_active_ts as this isn't an indication that - # they've been active (even though it's probably been updated by - # set_state above) - else: - await self._update_states( - [ - prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec() - ) - ] - ) + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants. + await self.set_state( + UserID.from_string(user_id), + device_id, + state={"presence": presence_state}, + is_sync=True, + ) async def _end() -> None: try: - self.user_to_num_current_syncs[user_id] -= 1 + self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 prev_state = await self.current_state_for_user(user_id) await self._update_states( @@ -1063,17 +1158,23 @@ class PresenceHandler(BasePresenceHandler): try: yield finally: - if affect_presence: - run_in_background(_end) + run_in_background(_end) return _user_syncing() - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: # since we are the process handling presence, there is nothing to do here. return [] async def update_external_syncs_row( - self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + self, + process_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + sync_time_msec: int, ) -> None: """Update the syncing users for an external process as a delta. @@ -1082,6 +1183,7 @@ class PresenceHandler(BasePresenceHandler): syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing + device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -1092,31 +1194,33 @@ class PresenceHandler(BasePresenceHandler): process_id, set() ) - updates = [] - if is_syncing and user_id not in process_presence: - if prev_state.state == PresenceState.OFFLINE: - updates.append( - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - ) - ) - else: - updates.append( - prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) - ) - process_presence.add(user_id) - elif user_id in process_presence: - updates.append( - prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + # USER_SYNC is sent when a user's device starts or stops syncing on + # a remote # process. (But only for the initial and last sync for that + # device.) + # + # When a device *starts* syncing it also calls set_state(...) which + # will update the state, last_active_ts, and last_user_sync_ts. + # Simply ensure the user & device is tracked as syncing in this case. + # + # When a device *stops* syncing, update the last_user_sync_ts and mark + # them as no longer syncing. Note this doesn't quite match the + # monolith behaviour, which updates last_user_sync_ts at the end of + # every sync, not just the last in-flight sync. + if is_syncing and (user_id, device_id) not in process_presence: + process_presence.add((user_id, device_id)) + elif not is_syncing and (user_id, device_id) in process_presence: + devices = self._user_to_device_to_current_state.setdefault(user_id, {}) + device_state = devices.setdefault( + device_id, UserDevicePresenceState.default(user_id, device_id) ) + device_state.last_sync_ts = sync_time_msec - if not is_syncing: - process_presence.discard(user_id) + new_state = prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec + ) + await self._update_states([new_state]) - if updates: - await self._update_states(updates) + process_presence.discard((user_id, device_id)) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() @@ -1130,9 +1234,24 @@ class PresenceHandler(BasePresenceHandler): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = await self.current_state_for_users(process_presence) + time_now_ms = self.clock.time_msec() + # Mark each device as having a last sync time. + updated_users = set() + for user_id, device_id in process_presence: + device_state = self._user_to_device_to_current_state.setdefault( + user_id, {} + ).setdefault( + device_id, UserDevicePresenceState.default(user_id, device_id) + ) + + device_state.last_sync_ts = time_now_ms + updated_users.add(user_id) + + # Update each user (and insert into the appropriate timers to check if + # they've gone offline). + prev_states = await self.current_state_for_users(updated_users) await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) @@ -1215,51 +1334,69 @@ class PresenceHandler(BasePresenceHandler): async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ status_msg = state.get("status_msg", None) presence = state["presence"] - valid_presence = ( - PresenceState.ONLINE, - PresenceState.UNAVAILABLE, - PresenceState.OFFLINE, - PresenceState.BUSY, - ) - - if presence not in valid_presence or ( - presence == PresenceState.BUSY and not self._busy_presence_enabled - ): + if presence not in self.VALID_PRESENCE: raise SynapseError(400, "Invalid presence state") # If presence is disabled, no-op - if not self.hs.config.server.use_presence: + if not self._presence_enabled: return user_id = target_user.to_string() + now = self.clock.time_msec() prev_state = await self.current_state_for_user(user_id) + # Syncs do not override a previous presence of busy. + # + # TODO: This is a hack for lack of multi-device support. Unfortunately + # removing this requires coordination with clients. + if prev_state.state == PresenceState.BUSY and is_sync: + presence = PresenceState.BUSY + + # Update the device specific information. + devices = self._user_to_device_to_current_state.setdefault(user_id, {}) + device_state = devices.setdefault( + device_id, + UserDevicePresenceState.default(user_id, device_id), + ) + device_state.state = presence + device_state.last_active_ts = now + if is_sync: + device_state.last_sync_ts = now + + # Based on the state of each user's device calculate the new presence state. + presence = _combine_device_states(devices.values()) + new_fields = {"state": presence} - if not ignore_status_msg: - new_fields["status_msg"] = status_msg + if presence == PresenceState.ONLINE or presence == PresenceState.BUSY: + new_fields["last_active_ts"] = now - if presence == PresenceState.ONLINE or ( - presence == PresenceState.BUSY and self._busy_presence_enabled - ): - new_fields["last_active_ts"] = self.clock.time_msec() + if is_sync: + new_fields["last_user_sync_ts"] = now + else: + # Syncs do not override the status message. + new_fields["status_msg"] = status_msg await self._update_states( [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify @@ -1483,7 +1620,7 @@ class PresenceHandler(BasePresenceHandler): or state.status_msg is not None ] - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( destinations=newly_joined_remote_hosts, states=states, ) @@ -1494,29 +1631,37 @@ class PresenceHandler(BasePresenceHandler): prev_remote_hosts or newly_joined_remote_hosts ): local_states = await self.current_state_for_users(newly_joined_local_users) - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( destinations=prev_remote_hosts | newly_joined_remote_hosts, states=list(local_states.values()), ) -def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool: +def should_notify( + old_state: UserPresenceState, new_state: UserPresenceState, is_mine: bool +) -> bool: """Decides if a presence state change should be sent to interested parties.""" + user_location = "remote" + if is_mine: + user_location = "local" + if old_state == new_state: return False if old_state.status_msg != new_state.status_msg: - notify_reason_counter.labels("status_msg_change").inc() + notify_reason_counter.labels(user_location, "status_msg_change").inc() return True if old_state.state != new_state.state: - notify_reason_counter.labels("state_change").inc() - state_transition_counter.labels(old_state.state, new_state.state).inc() + notify_reason_counter.labels(user_location, "state_change").inc() + state_transition_counter.labels( + user_location, old_state.state, new_state.state + ).inc() return True if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: - notify_reason_counter.labels("current_active_change").inc() + notify_reason_counter.labels(user_location, "current_active_change").inc() return True if ( @@ -1525,12 +1670,16 @@ def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> ): # Only notify about last active bumps if we're not currently active if not new_state.currently_active: - notify_reason_counter.labels("last_active_change_online").inc() + notify_reason_counter.labels( + user_location, "last_active_change_online" + ).inc() return True elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. - notify_reason_counter.labels("last_active_change_not_online").inc() + notify_reason_counter.labels( + user_location, "last_active_change_not_online" + ).inc() return True return False @@ -1834,7 +1983,8 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): def handle_timeouts( user_states: List[UserPresenceState], is_mine_fn: Callable[[str], bool], - syncing_user_ids: Set[str], + syncing_user_devices: AbstractSet[Tuple[str, Optional[str]]], + user_to_devices: Dict[str, Dict[Optional[str], UserDevicePresenceState]], now: int, ) -> List[UserPresenceState]: """Checks the presence of users that have timed out and updates as @@ -1843,7 +1993,8 @@ def handle_timeouts( Args: user_states: List of UserPresenceState's to check. is_mine_fn: Function that returns if a user_id is ours - syncing_user_ids: Set of user_ids with active syncs. + syncing_user_devices: A set of (user ID, device ID) tuples with active syncs.. + user_to_devices: A map of user ID to device ID to UserDevicePresenceState. now: Current time in ms. Returns: @@ -1852,9 +2003,16 @@ def handle_timeouts( changes = {} # Actual changes we need to notify people about for state in user_states: - is_mine = is_mine_fn(state.user_id) - - new_state = handle_timeout(state, is_mine, syncing_user_ids, now) + user_id = state.user_id + is_mine = is_mine_fn(user_id) + + new_state = handle_timeout( + state, + is_mine, + syncing_user_devices, + user_to_devices.get(user_id, {}), + now, + ) if new_state: changes[state.user_id] = new_state @@ -1862,14 +2020,19 @@ def handle_timeouts( def handle_timeout( - state: UserPresenceState, is_mine: bool, syncing_user_ids: Set[str], now: int + state: UserPresenceState, + is_mine: bool, + syncing_device_ids: AbstractSet[Tuple[str, Optional[str]]], + user_devices: Dict[Optional[str], UserDevicePresenceState], + now: int, ) -> Optional[UserPresenceState]: """Checks the presence of the user to see if any of the timers have elapsed Args: - state + state: UserPresenceState to check. is_mine: Whether the user is ours - syncing_user_ids: Set of user_ids with active syncs. + syncing_user_devices: A set of (user ID, device ID) tuples with active syncs.. + user_devices: A map of device ID to UserDevicePresenceState. now: Current time in ms. Returns: @@ -1880,34 +2043,63 @@ def handle_timeout( return None changed = False - user_id = state.user_id if is_mine: - if state.state == PresenceState.ONLINE: - if now - state.last_active_ts > IDLE_TIMER: - # Currently online, but last activity ages ago so auto - # idle - state = state.copy_and_replace(state=PresenceState.UNAVAILABLE) - changed = True - elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: - # So that we send down a notification that we've - # stopped updating. + # Check per-device whether the device should be considered idle or offline + # due to timeouts. + device_changed = False + offline_devices = [] + for device_id, device_state in user_devices.items(): + if device_state.state == PresenceState.ONLINE: + if now - device_state.last_active_ts > IDLE_TIMER: + # Currently online, but last activity ages ago so auto + # idle + device_state.state = PresenceState.UNAVAILABLE + device_changed = True + + # If there are have been no sync for a while (and none ongoing), + # set presence to offline. + if (state.user_id, device_id) not in syncing_device_ids: + # If the user has done something recently but hasn't synced, + # don't set them as offline. + sync_or_active = max( + device_state.last_sync_ts, device_state.last_active_ts + ) + + # Implementations aren't meant to timeout a device with a busy + # state, but it needs to timeout *eventually* or else the user + # will be stuck in that state. + online_timeout = ( + BUSY_ONLINE_TIMEOUT + if device_state.state == PresenceState.BUSY + else SYNC_ONLINE_TIMEOUT + ) + if now - sync_or_active > online_timeout: + # Mark the device as going offline. + offline_devices.append(device_id) + device_changed = True + + # Offline devices are not needed and do not add information. + for device_id in offline_devices: + user_devices.pop(device_id) + + # If the presence state of the devices changed, then (maybe) update + # the user's overall presence state. + if device_changed: + new_presence = _combine_device_states(user_devices.values()) + if new_presence != state.state: + state = state.copy_and_replace(state=new_presence) changed = True + if now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # So that we send down a notification that we've + # stopped updating. + changed = True + if now - state.last_federation_update_ts > FEDERATION_PING_INTERVAL: # Need to send ping to other servers to ensure they don't # timeout and set us to offline changed = True - - # If there are have been no sync for a while (and none ongoing), - # set presence to offline - if user_id not in syncing_user_ids: - # If the user has done something recently but hasn't synced, - # don't set them as offline. - sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) - if now - sync_or_active > SYNC_ONLINE_TIMEOUT: - state = state.copy_and_replace(state=PresenceState.OFFLINE) - changed = True else: # We expect to be poked occasionally by the other side. # This is to protect against forgetful/buggy servers, so that @@ -1982,6 +2174,13 @@ def handle_update( new_state = new_state.copy_and_replace(last_federation_update_ts=now) federation_ping = True + if new_state.state == PresenceState.BUSY: + wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + BUSY_ONLINE_TIMEOUT, + ) + else: wheel_timer.insert( now=now, @@ -1990,13 +2189,53 @@ def handle_update( ) # Check whether the change was something worth notifying about - if should_notify(prev_state, new_state): + if should_notify(prev_state, new_state, is_mine): new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True return new_state, persist_and_notify, federation_ping +PRESENCE_BY_PRIORITY = { + PresenceState.BUSY: 4, + PresenceState.ONLINE: 3, + PresenceState.UNAVAILABLE: 2, + PresenceState.OFFLINE: 1, +} + + +def _combine_device_states( + device_states: Iterable[UserDevicePresenceState], +) -> str: + """ + Find the device to use presence information from. + + Orders devices by priority, then last_active_ts. + + Args: + device_states: An iterable of device presence states + + Return: + The combined presence state. + """ + + # Based on (all) the user's devices calculate the new presence state. + presence = PresenceState.OFFLINE + last_active_ts = -1 + + # Find the device to use the presence state of based on the presence priority, + # but tie-break with how recently the device has been seen. + for device_state in device_states: + if (PRESENCE_BY_PRIORITY[device_state.state], device_state.last_active_ts) > ( + PRESENCE_BY_PRIORITY[presence], + last_active_ts, + ): + presence = device_state.state + last_active_ts = device_state.last_active_ts + + return presence + + async def get_interested_parties( store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState] ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]: @@ -2037,7 +2276,7 @@ async def get_interested_remotes( store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState], -) -> Dict[str, Set[UserPresenceState]]: +) -> List[Tuple[StrCollection, Collection[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -2051,23 +2290,26 @@ async def get_interested_remotes( Returns: A map from destinations to presence states to send to that destination. """ - hosts_and_states: Dict[str, Set[UserPresenceState]] = {} + hosts_and_states: List[Tuple[StrCollection, Collection[UserPresenceState]]] = [] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = await get_interested_parties( - store, presence_router, states - ) + for state in states: + room_ids = await store.get_rooms_for_user(state.user_id) + hosts: Set[str] = set() + for room_id in room_ids: + room_hosts = await store.get_current_hosts_in_room(room_id) + hosts.update(room_hosts) + hosts_and_states.append((hosts, [state])) - for room_id, states in room_ids_to_states.items(): - hosts = await store.get_current_hosts_in_room(room_id) - for host in hosts: - hosts_and_states.setdefault(host, set()).update(states) + # Ask a presence routing module for any additional parties if one + # is loaded. + router_users_to_states = await presence_router.get_users_for_states(states) - for user_id, states in users_to_states.items(): + for user_id, user_states in router_users_to_states.items(): host = get_domain_from_id(user_id) - hosts_and_states.setdefault(host, set()).update(states) + hosts_and_states.append(([host], user_states)) return hosts_and_states @@ -2145,7 +2387,7 @@ class PresenceFederationQueue: index = bisect(self._queue, (clear_before,)) self._queue = self._queue[index:] - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. @@ -2165,7 +2407,7 @@ class PresenceFederationQueue: return if self._federation: - self._federation.send_presence_to_destinations( + await self._federation.send_presence_to_destinations( states=states, destinations=destinations, ) @@ -2288,7 +2530,7 @@ class PresenceFederationQueue: for host, user_ids in hosts_to_users.items(): states = await self._presence_handler.current_state_for_users(user_ids) - self._federation.send_presence_to_destinations( + await self._federation.send_presence_to_destinations( states=states.values(), destinations=[host], ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index a9160c87e3..c2109036ec 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -67,8 +67,8 @@ class ProfileHandler: target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): - profileinfo = await self.store.get_profileinfo(target_user.localpart) - if profileinfo.display_name is None: + profileinfo = await self.store.get_profileinfo(target_user) + if profileinfo.display_name is None and profileinfo.avatar_url is None: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) return { @@ -99,9 +99,7 @@ class ProfileHandler: async def get_displayname(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: - displayname = await self.store.get_profile_displayname( - target_user.localpart - ) + displayname = await self.store.get_profile_displayname(target_user) except StoreError as e: if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) @@ -147,7 +145,7 @@ class ProfileHandler: raise AuthError(400, "Cannot set another user's displayname") if not by_admin and not self.hs.config.registration.enable_set_displayname: - profile = await self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user) if profile.display_name: raise SynapseError( 400, @@ -165,7 +163,7 @@ class ProfileHandler: 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - displayname_to_set: Optional[str] = new_displayname + displayname_to_set: Optional[str] = new_displayname.strip() if new_displayname == "": displayname_to_set = None @@ -180,7 +178,7 @@ class ProfileHandler: await self.store.set_profile_displayname(target_user, displayname_to_set) - profile = await self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user) await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) @@ -194,9 +192,7 @@ class ProfileHandler: async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: - avatar_url = await self.store.get_profile_avatar_url( - target_user.localpart - ) + avatar_url = await self.store.get_profile_avatar_url(target_user) except StoreError as e: if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) @@ -241,7 +237,7 @@ class ProfileHandler: raise AuthError(400, "Cannot set another user's avatar_url") if not by_admin and not self.hs.config.registration.enable_set_avatar_url: - profile = await self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user) if profile.avatar_url: raise SynapseError( 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN @@ -272,7 +268,7 @@ class ProfileHandler: await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) - profile = await self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user) await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) @@ -369,14 +365,10 @@ class ProfileHandler: response = {} try: if just_field is None or just_field == "displayname": - response["displayname"] = await self.store.get_profile_displayname( - user.localpart - ) + response["displayname"] = await self.store.get_profile_displayname(user) if just_field is None or just_field == "avatar_url": - response["avatar_url"] = await self.store.get_profile_avatar_url( - user.localpart - ) + response["avatar_url"] = await self.store.get_profile_avatar_url(user) except StoreError as e: if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py index 7ed88a3611..87b428ab1c 100644 --- a/synapse/handlers/push_rules.py +++ b/synapse/handlers/push_rules.py @@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.push_rule import RuleNotFoundException from synapse.synapse_rust.push import get_base_rule_ids -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -114,7 +114,9 @@ class PushRulesHandler: user_id: the user ID the change is for. """ stream_id = self._main_store.get_max_push_rules_stream_id() - self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) + self._notifier.on_new_event( + StreamKeyType.PUSH_RULES, stream_id, users=[user_id] + ) async def push_rules_for_user( self, user: UserID diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 49a497a860..df5a4f3e22 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -27,7 +27,6 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): - self.server_name = hs.config.server.server_name self.store = hs.get_datastores().main self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 2bacdebfb5..69ac468f75 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -19,6 +19,7 @@ from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import ( JsonDict, + JsonMapping, ReadReceipt, StreamKeyType, UserID, @@ -37,6 +38,8 @@ class ReceiptsHandler: self.server_name = hs.config.server.server_name self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() + self.event_handler = hs.get_event_handler() + self._storage_controllers = hs.get_storage_controllers() self.hs = hs @@ -81,6 +84,20 @@ class ReceiptsHandler: ) continue + # Let's check that the origin server is in the room before accepting the receipt. + # We don't want to block waiting on a partial state so take an + # approximation if needed. + domains = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( + room_id + ) + if origin not in domains: + logger.info( + "Ignoring receipt for room %r from server %s as they're not in the room", + room_id, + origin, + ) + continue + for receipt_type, users in room_values.items(): for user_id, user_values in users.items(): if get_domain_from_id(user_id) != origin: @@ -113,11 +130,10 @@ class ReceiptsHandler: async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier.""" - min_batch_id: Optional[int] = None - max_batch_id: Optional[int] = None + receipts_persisted: List[ReadReceipt] = [] for receipt in receipts: - res = await self.store.insert_receipt( + stream_id = await self.store.insert_receipt( receipt.room_id, receipt.receipt_type, receipt.user_id, @@ -126,30 +142,26 @@ class ReceiptsHandler: receipt.data, ) - if not res: - # res will be None if this receipt is 'old' + if stream_id is None: + # stream_id will be None if this receipt is 'old' continue - stream_id, max_persisted_id = res - - if min_batch_id is None or stream_id < min_batch_id: - min_batch_id = stream_id - if max_batch_id is None or max_persisted_id > max_batch_id: - max_batch_id = max_persisted_id + receipts_persisted.append(receipt) - # Either both of these should be None or neither. - if min_batch_id is None or max_batch_id is None: + if not receipts_persisted: # no new receipts return False - affected_room_ids = list({r.room_id for r in receipts}) + max_batch_id = self.store.get_max_receipt_stream_id() + + affected_room_ids = list({r.room_id for r in receipts_persisted}) self.notifier.on_new_event( StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids ) # Note that the min here shouldn't be relied upon to be accurate. await self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids + {r.user_id for r in receipts_persisted} ) return True @@ -158,17 +170,23 @@ class ReceiptsHandler: self, room_id: str, receipt_type: str, - user_id: str, + user_id: UserID, event_id: str, thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. """ + + # Ensure the room/event exists, this will raise an error if the user + # cannot view the event. + if not await self.event_handler.get_event(user_id, room_id, event_id): + return + receipt = ReadReceipt( room_id=room_id, receipt_type=receipt_type, - user_id=user_id, + user_id=user_id.to_string(), event_ids=[event_id], thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, @@ -182,15 +200,15 @@ class ReceiptsHandler: await self.federation_sender.send_read_receipt(receipt) -class ReceiptEventSource(EventSource[int, JsonDict]): +class ReceiptEventSource(EventSource[int, JsonMapping]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.config = hs.config @staticmethod def filter_out_private_receipts( - rooms: Sequence[JsonDict], user_id: str - ) -> List[JsonDict]: + rooms: Sequence[JsonMapping], user_id: str + ) -> List[JsonMapping]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) and removes private read receipts of other users. @@ -207,7 +225,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): The same as rooms, but filtered. """ - result = [] + result: List[JsonMapping] = [] # Iterate through each room's receipt content. for room in rooms: @@ -260,7 +278,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: from_key = int(from_key) to_key = self.get_current_key() @@ -279,7 +297,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): async def get_new_events_as( self, from_key: int, to_key: int, service: ApplicationService - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: """Returns a set of new read receipt events that an appservice may be interested in. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c80946c2e9..3a55056df5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -143,15 +143,10 @@ class RegistrationHandler: assigned_user_id: Optional[str] = None, inhibit_user_in_use_error: bool = False, ) -> None: - if types.contains_invalid_mxid_characters( - localpart, self.hs.config.experimental.msc4009_e164_mxids - ): - extra_chars = ( - "=_-./+" if self.hs.config.experimental.msc4009_e164_mxids else "=_-./" - ) + if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, - f"User ID can only contain characters a-z, 0-9, or '{extra_chars}'", + "User ID can only contain characters a-z, 0-9, or '=_-./+'", Codes.INVALID_USERNAME, ) @@ -315,7 +310,7 @@ class RegistrationHandler: approved=approved, ) - profile = await self.store.get_profileinfo(localpart) + profile = await self.store.get_profileinfo(user) await self.user_directory_handler.handle_local_profile_change( user_id, profile ) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 4824635162..9b13448cdd 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,7 +13,17 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Mapping, + Optional, + Sequence, +) import attr @@ -205,16 +215,22 @@ class RelationsHandler: event_id: The event IDs to look and redact relations of. initial_redaction_event: The redaction for the event referred to by event_id. - relation_types: The types of relations to look for. + relation_types: The types of relations to look for. If "*" is in the list, + all related events will be redacted regardless of the type. Raises: ShadowBanError if the requester is shadow-banned """ - related_event_ids = ( - await self._main_store.get_all_relations_for_event_with_types( - event_id, relation_types + if "*" in relation_types: + related_event_ids = await self._main_store.get_all_relations_for_event( + event_id + ) + else: + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) ) - ) for related_event_id in related_event_ids: try: @@ -239,7 +255,7 @@ class RelationsHandler: async def get_references_for_events( self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() - ) -> Dict[str, List[_RelatedEvent]]: + ) -> Mapping[str, Sequence[_RelatedEvent]]: """Get a list of references to the given events. Args: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5e1702d78a..4cdf0a8502 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,7 +20,7 @@ import random import string from collections import OrderedDict from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple import attr from typing_extensions import TypedDict @@ -54,11 +54,11 @@ from synapse.events import EventBase from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations -from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin from synapse.streams import EventSource from synapse.types import ( JsonDict, + JsonMapping, MutableStateMap, Requester, RoomAlias, @@ -454,7 +454,7 @@ class RoomCreationHandler: spam_check = await self._spam_checker_module_callbacks.user_may_create_room( user_id ) - if spam_check != NOT_SPAM: + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: raise SynapseError( 403, "You are not permitted to create rooms", @@ -768,7 +768,7 @@ class RoomCreationHandler: spam_check = await self._spam_checker_module_callbacks.user_may_create_room( user_id ) - if spam_check != NOT_SPAM: + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: raise SynapseError( 403, "You are not permitted to create rooms", @@ -872,6 +872,8 @@ class RoomCreationHandler: visibility = config.get("visibility", "private") is_public = visibility == "public" + self._validate_room_config(config, visibility) + room_id = await self._generate_and_create_room_id( creator_id=user_id, is_public=is_public, @@ -1111,23 +1113,10 @@ class RoomCreationHandler: return new_event, new_unpersisted_context - visibility = room_config.get("visibility", "private") - preset_config = room_config.get( - "preset", - RoomCreationPreset.PRIVATE_CHAT - if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT, - ) - - try: - config = self._presets_dict[preset_config] - except KeyError: - raise SynapseError( - 400, f"'{preset_config}' is not a valid preset", errcode=Codes.BAD_JSON - ) + preset_config, config = self._room_preset_config(room_config) # MSC2175 removes the creator field from the create event. - if not room_version.msc2175_implicit_room_creator: + if not room_version.implicit_room_creator: creation_content["creator"] = creator_id creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False @@ -1306,6 +1295,65 @@ class RoomCreationHandler: assert last_event.internal_metadata.stream_ordering is not None return last_event.internal_metadata.stream_ordering, last_event.event_id, depth + def _validate_room_config( + self, + config: JsonDict, + visibility: str, + ) -> None: + """Checks configuration parameters for a /createRoom request. + + If validation detects invalid parameters an exception may be raised to + cause room creation to be aborted and an error response to be returned + to the client. + + Args: + config: A dict of configuration options. Originally from the body of + the /createRoom request + visibility: One of "public" or "private" + """ + + # Validate the requested preset, raise a 400 error if not valid + preset_name, preset_config = self._room_preset_config(config) + + # If the user is trying to create an encrypted room and this is forbidden + # by the configured default_power_level_content_override, then reject the + # request before the room is created. + raw_initial_state = config.get("initial_state", []) + room_encryption_event = any( + s.get("type", "") == EventTypes.RoomEncryption for s in raw_initial_state + ) + + if preset_config["encrypted"] or room_encryption_event: + if self._default_power_level_content_override: + override = self._default_power_level_content_override.get(preset_name) + if override is not None: + event_levels = override.get("events", {}) + room_admin_level = event_levels.get(EventTypes.PowerLevels, 100) + encryption_level = event_levels.get(EventTypes.RoomEncryption, 100) + if encryption_level > room_admin_level: + raise SynapseError( + 403, + f"You cannot create an encrypted room. user_level ({room_admin_level}) < send_level ({encryption_level})", + ) + + def _room_preset_config(self, room_config: JsonDict) -> Tuple[str, dict]: + # The spec says rooms should default to private visibility if + # `visibility` is not specified. + visibility = room_config.get("visibility", "private") + preset_name = room_config.get( + "preset", + RoomCreationPreset.PRIVATE_CHAT + if visibility == "private" + else RoomCreationPreset.PUBLIC_CHAT, + ) + try: + preset_config = self._presets_dict[preset_name] + except KeyError: + raise SynapseError( + 400, f"'{preset_name}' is not a valid preset", errcode=Codes.BAD_JSON + ) + return preset_name, preset_config + def _generate_room_id(self) -> str: """Generates a random room ID. @@ -1490,7 +1538,6 @@ class RoomContextHandler: class TimestampLookupHandler: def __init__(self, hs: "HomeServer"): - self.server_name = hs.hostname self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() @@ -1661,7 +1708,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): if from_key.topological: logger.warning("Stream has topological part!!!! %r", from_key) - from_key = RoomStreamToken(None, from_key.stream) + from_key = RoomStreamToken(stream=from_key.stream) app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: @@ -1703,6 +1750,45 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): return self.store.get_current_room_stream_token_for_room_id(room_id) +class ShutdownRoomParams(TypedDict): + """ + Attributes: + requester_user_id: + User who requested the action. Will be recorded as putting the room on the + blocking list. + new_room_user_id: + If set, a new room will be created with this user ID + as the creator and admin, and all users in the old room will be + moved into that room. If not set, no new room will be created + and the users will just be removed from the old room. + new_room_name: + A string representing the name of the room that new users will + be invited to. Defaults to `Content Violation Notification` + message: + A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly + convey why the original room was shut down. + Defaults to `Sharing illegal content on this server is not + permitted and rooms in violation will be blocked.` + block: + If set to `true`, this room will be added to a blocking list, + preventing future attempts to join the room. Defaults to `false`. + purge: + If set to `true`, purge the given room from the database. + force_purge: + If set to `true`, the room will be purged from database + even if there are still users joined to the room. + """ + + requester_user_id: Optional[str] + new_room_user_id: Optional[str] + new_room_name: Optional[str] + message: Optional[str] + block: bool + purge: bool + force_purge: bool + + class ShutdownRoomResponse(TypedDict): """ Attributes: @@ -1740,12 +1826,12 @@ class RoomShutdownHandler: async def shutdown_room( self, room_id: str, - requester_user_id: str, - new_room_user_id: Optional[str] = None, - new_room_name: Optional[str] = None, - message: Optional[str] = None, - block: bool = False, - ) -> ShutdownRoomResponse: + params: ShutdownRoomParams, + result: Optional[ShutdownRoomResponse] = None, + update_result_fct: Optional[ + Callable[[Optional[JsonMapping]], Awaitable[None]] + ] = None, + ) -> Optional[ShutdownRoomResponse]: """ Shuts down a room. Moves all local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only @@ -1761,48 +1847,23 @@ class RoomShutdownHandler: Args: room_id: The ID of the room to shut down. - requester_user_id: - User who requested the action and put the room on the - blocking list. - new_room_user_id: - If set, a new room will be created with this user ID - as the creator and admin, and all users in the old room will be - moved into that room. If not set, no new room will be created - and the users will just be removed from the old room. - new_room_name: - A string representing the name of the room that new users will - be invited to. Defaults to `Content Violation Notification` - message: - A string containing the first message that will be sent as - `new_room_user_id` in the new room. Ideally this will clearly - convey why the original room was shut down. - Defaults to `Sharing illegal content on this server is not - permitted and rooms in violation will be blocked.` - block: - If set to `True`, users will be prevented from joining the old - room. This option can also be used to pre-emptively block a room, - even if it's unknown to this homeserver. In this case, the room - will be blocked, and no further action will be taken. If `False`, - attempting to delete an unknown room is invalid. - - Defaults to `False`. - - Returns: a dict containing the following keys: - kicked_users: An array of users (`user_id`) that were kicked. - failed_to_kick_users: - An array of users (`user_id`) that that were not kicked. - local_aliases: - An array of strings representing the local aliases that were - migrated from the old room to the new. - new_room_id: - A string representing the room ID of the new room, or None if - no such room was created. - """ + delete_id: The delete ID identifying this delete request + params: parameters for the shutdown, cf `ShutdownRoomParams` + result: current status of the shutdown, if it was interrupted + update_result_fct: function called when `result` is updated locally - if not new_room_name: - new_room_name = self.DEFAULT_ROOM_NAME - if not message: - message = self.DEFAULT_MESSAGE + Returns: a dict matching `ShutdownRoomResponse`. + """ + requester_user_id = params["requester_user_id"] + new_room_user_id = params["new_room_user_id"] + block = params["block"] + + new_room_name = ( + params["new_room_name"] + if params["new_room_name"] + else self.DEFAULT_ROOM_NAME + ) + message = params["message"] if params["message"] else self.DEFAULT_MESSAGE if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) @@ -1814,22 +1875,33 @@ class RoomShutdownHandler: 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN ) + result = ( + result + if result + else { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } + ) + # Action the block first (even if the room doesn't exist yet) if block: + if requester_user_id is None: + raise ValueError( + "shutdown_room: block=True not allowed when requester_user_id is None." + ) # This will work even if the room is already blocked, but that is # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): # if we don't know about the room, there is nothing left to do. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } + return result - if new_room_user_id is not None: + new_room_id = result.get("new_room_id") + if new_room_user_id is not None and new_room_id is None: if not self.hs.is_mine_id(new_room_user_id): raise SynapseError( 400, "User must be our own: %s" % (new_room_user_id,) @@ -1849,6 +1921,10 @@ class RoomShutdownHandler: ratelimit=False, ) + result["new_room_id"] = new_room_id + if update_result_fct: + await update_result_fct(result) + logger.info( "Shutting down room %r, joining to new room: %r", room_id, new_room_id ) @@ -1862,12 +1938,9 @@ class RoomShutdownHandler: stream_id, ) else: - new_room_id = None logger.info("Shutting down room %r", room_id) users = await self.store.get_users_in_room(room_id) - kicked_users = [] - failed_to_kick_users = [] for user_id in users: if not self.hs.is_mine_id(user_id): continue @@ -1896,7 +1969,9 @@ class RoomShutdownHandler: stream_id, ) - await self.room_member_handler.forget(target_requester.user, room_id) + await self.room_member_handler.forget( + target_requester.user, room_id, do_not_schedule_purge=True + ) # Join users to new room if new_room_user_id: @@ -1911,15 +1986,23 @@ class RoomShutdownHandler: require_consent=False, ) - kicked_users.append(user_id) + result["kicked_users"].append(user_id) + if update_result_fct: + await update_result_fct(result) except Exception: logger.exception( "Failed to leave old room and join new room for %r", user_id ) - failed_to_kick_users.append(user_id) + result["failed_to_kick_users"].append(user_id) + if update_result_fct: + await update_result_fct(result) # Send message in new room and move aliases if new_room_user_id: + room_creator_requester = create_requester( + new_room_user_id, authenticated_entity=requester_user_id + ) + await self.event_creation_handler.create_and_send_nonmember_event( room_creator_requester, { @@ -1931,18 +2014,15 @@ class RoomShutdownHandler: ratelimit=False, ) - aliases_for_room = await self.store.get_aliases_for_room(room_id) + result["local_aliases"] = list( + await self.store.get_aliases_for_room(room_id) + ) assert new_room_id is not None await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id ) else: - aliases_for_room = [] + result["local_aliases"] = [] - return { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": list(aliases_for_room), - "new_room_id": new_room_id, - } + return result diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py deleted file mode 100644 index bf9df60218..0000000000 --- a/synapse/handlers/room_batch.py +++ /dev/null @@ -1,466 +0,0 @@ -import logging -from typing import TYPE_CHECKING, List, Tuple - -from synapse.api.constants import EventContentFields, EventTypes -from synapse.appservice import ApplicationService -from synapse.http.servlet import assert_params_in_dict -from synapse.types import JsonDict, Requester, UserID, create_requester -from synapse.util.stringutils import random_string - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class RoomBatchHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.store = hs.get_datastores().main - self._state_storage_controller = hs.get_storage_controllers().state - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: - """Finds the depth which would sort it after the most-recent - prev_event_id but before the successors of those events. If no - successors are found, we assume it's an historical extremity part of the - current batch and use the same depth of the prev_event_ids. - - Args: - prev_event_ids: List of prev event IDs - - Returns: - Inherited depth - """ - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - assert most_recent_prev_event_id is not None - successor_event_ids = await self.store.get_successor_events( - most_recent_prev_event_id - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ) -> JsonDict: - """Creates an event dict for an "insertion" event with the proper fields - and a random batch ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - The new event dictionary to insert. - """ - - next_batch_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user. - - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - - async def get_most_recent_full_state_ids_from_event_id_list( - self, event_ids: List[str] - ) -> List[str]: - """Find the most recent event_id and grab the full state at that event. - We will use this as a base to auth our historical messages against. - - Args: - event_ids: List of event ID's to look at - - Returns: - List of event ID's - """ - - ( - most_recent_event_id, - _, - ) = await self.store.get_max_depth_of(event_ids) - # mapping from (type, state_key) -> state_event_id - assert most_recent_event_id is not None - prev_state_map = await self._state_storage_controller.get_state_ids_for_event( - most_recent_event_id - ) - # List of state event ID's - full_state_ids = list(prev_state_map.values()) - - return full_state_ids - - async def persist_state_events_at_start( - self, - state_events_at_start: List[JsonDict], - room_id: str, - initial_state_event_ids: List[str], - app_service_requester: Requester, - ) -> List[str]: - """Takes all `state_events_at_start` event dictionaries and creates/persists - them in a floating state event chain which don't resolve into the current room - state. They are floating because they reference no prev_events which disconnects - them from the normal DAG. - - Args: - state_events_at_start: - room_id: Room where you want the events persisted in. - initial_state_event_ids: - The base set of state for the historical batch which the floating - state chain will derive from. This should probably be the state - from the `prev_event` defined by `/batch_send?prev_event_id=$abc`. - app_service_requester: The requester of an application service. - - Returns: - List of state event ID's we just persisted - """ - assert app_service_requester.app_service - - state_event_ids_at_start = [] - state_event_ids = initial_state_event_ids.copy() - - # Make the state events float off on their own by specifying no - # prev_events for the first one in the chain so we don't have a bunch of - # `@mxid joined the room` noise between each batch. - prev_event_ids_for_state_chain: List[str] = [] - - for index, state_event in enumerate(state_events_at_start): - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s", state_event - ) - - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self.create_requester_for_user_id_from_app_service( - state_event["sender"], app_service_requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - historical=True, - # Only the first event in the state chain should be floating. - # The rest should hang off each other in a chain. - allow_no_prev_events=index == 0, - prev_event_ids=prev_event_ids_for_state_chain, - # The first event in the state chain is floating with no - # `prev_events` which means it can't derive state from - # anywhere automatically. So we need to set some state - # explicitly. - # - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append - # later. - state_event_ids=state_event_ids.copy(), - ) - else: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self.create_requester_for_user_id_from_app_service( - state_event["sender"], app_service_requester.app_service - ), - event_dict, - historical=True, - # Only the first event in the state chain should be floating. - # The rest should hang off each other in a chain. - allow_no_prev_events=index == 0, - prev_event_ids=prev_event_ids_for_state_chain, - # The first event in the state chain is floating with no - # `prev_events` which means it can't derive state from - # anywhere automatically. So we need to set some state - # explicitly. - # - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - state_event_ids=state_event_ids.copy(), - ) - event_id = event.event_id - - state_event_ids_at_start.append(event_id) - state_event_ids.append(event_id) - # Connect all the state in a floating chain - prev_event_ids_for_state_chain = [event_id] - - return state_event_ids_at_start - - async def persist_historical_events( - self, - events_to_create: List[JsonDict], - room_id: str, - inherited_depth: int, - initial_state_event_ids: List[str], - app_service_requester: Requester, - ) -> List[str]: - """Create and persists all events provided sequentially. Handles the - complexity of creating events in chronological order so they can - reference each other by prev_event but still persists in - reverse-chronoloical order so they have the correct - (topological_ordering, stream_ordering) and sort correctly from - /messages. - - Args: - events_to_create: List of historical events to create in JSON - dictionary format. - room_id: Room where you want the events persisted in. - inherited_depth: The depth to create the events at (you will - probably by calling inherit_depth_from_prev_ids(...)). - initial_state_event_ids: - This is used to set explicit state for the insertion event at - the start of the historical batch since it's floating with no - prev_events to derive state from automatically. - app_service_requester: The requester of an application service. - - Returns: - List of persisted event IDs - """ - assert app_service_requester.app_service - - # We expect the first event in a historical batch to be an insertion event - assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION - # We expect the last event in a historical batch to be an batch event - assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH - - # Make the historical event chain float off on its own by specifying no - # prev_events for the first event in the chain which causes the HS to - # ask for the state at the start of the batch later. - prev_event_ids: List[str] = [] - - event_ids = [] - events_to_persist = [] - for index, ev in enumerate(events_to_create): - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % ( - ev["sender"], - ) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, unpersisted_context = await self.event_creation_handler.create_event( - await self.create_requester_for_user_id_from_app_service( - ev["sender"], app_service_requester.app_service - ), - event_dict, - # Only the first event (which is the insertion event) in the - # chain should be floating. The rest should hang off each other - # in a chain. - allow_no_prev_events=index == 0, - prev_event_ids=event_dict.get("prev_events"), - # Since the first event (which is the insertion event) in the - # chain is floating with no `prev_events`, it can't derive state - # from anywhere automatically. So we need to set some state - # explicitly. - state_event_ids=initial_state_event_ids if index == 0 else None, - historical=True, - depth=inherited_depth, - ) - context = await unpersisted_context.persist(event) - assert context._state_group - - # Normally this is done when persisting the event but we have to - # pre-emptively do it here because we create all the events first, - # then persist them in another pass below. And we want to share - # state_groups across the whole batch so this lookup needs to work - # for the next event in the batch in this loop. - await self.store.store_state_group_id_for_event_id( - event_id=event.event_id, - state_group_id=context._state_group, - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s", - event, - prev_event_ids, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for event, context in reversed(events_to_persist): - # This call can't raise `PartialStateConflictError` since we forbid - # use of the historical batch API during partial state - await self.event_creation_handler.handle_new_client_event( - await self.create_requester_for_user_id_from_app_service( - event.sender, app_service_requester.app_service - ), - events_and_context=[(event, context)], - ) - - return event_ids - - async def handle_batch_of_events( - self, - events_to_create: List[JsonDict], - room_id: str, - batch_id_to_connect_to: str, - inherited_depth: int, - initial_state_event_ids: List[str], - app_service_requester: Requester, - ) -> Tuple[List[str], str]: - """ - Handles creating and persisting all of the historical events as well as - insertion and batch meta events to make the batch navigable in the DAG. - - Args: - events_to_create: List of historical events to create in JSON - dictionary format. - room_id: Room where you want the events created in. - batch_id_to_connect_to: The batch_id from the insertion event you - want this batch to connect to. - inherited_depth: The depth to create the events at (you will - probably by calling inherit_depth_from_prev_ids(...)). - initial_state_event_ids: - This is used to set explicit state for the insertion event at - the start of the historical batch since it's floating with no - prev_events to derive state from automatically. This should - probably be the state from the `prev_event` defined by - `/batch_send?prev_event_id=$abc` plus the outcome of - `persist_state_events_at_start` - app_service_requester: The requester of an application service. - - Returns: - Tuple containing a list of created events and the next_batch_id - """ - - # Connect this current batch to the insertion event from the previous batch - last_event_in_batch = events_to_create[-1] - batch_event = { - "type": EventTypes.MSC2716_BATCH, - "sender": app_service_requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the batch event is put at the end of the batch, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_batch["origin_server_ts"], - } - # Add the batch event to the end of the batch (newest-in-time) - events_to_create.append(batch_event) - - # Add an "insertion" event to the start of each batch (next to the oldest-in-time - # event in the batch) so the next batch can be connected to this one. - insertion_event = self.create_insertion_event_dict( - sender=app_service_requester.user.to_string(), - room_id=room_id, - # Since the insertion event is put at the start of the batch, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], - ) - next_batch_id = insertion_event["content"][ - EventContentFields.MSC2716_NEXT_BATCH_ID - ] - # Prepend the insertion event to the start of the batch (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - # Create and persist all of the historical events - event_ids = await self.persist_historical_events( - events_to_create=events_to_create, - room_id=room_id, - inherited_depth=inherited_depth, - initial_state_event_ids=initial_state_event_ids, - app_service_requester=app_service_requester, - ) - - return event_ids, next_batch_id diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 49a7df973f..6049708aba 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -33,7 +33,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.types import JsonDict, PublicRoom, ThirdPartyInstanceID +from synapse.types import JsonDict, JsonMapping, PublicRoom, ThirdPartyInstanceID from synapse.util import filter_none from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache @@ -195,12 +195,14 @@ class RoomListHandler: nb_modules = len(self._module_api_callbacks.fetch_public_rooms_callbacks) - module_range = range(0, nb_modules) + module_range = range(nb_modules) # if not forwards: # module_range = reversed(module_range) for module_index in module_range: - fetch_public_rooms = self._module_api_callbacks.fetch_public_rooms_callbacks[module_index] + fetch_public_rooms = ( + self._module_api_callbacks.fetch_public_rooms_callbacks[module_index] + ) # Ask each module for a list of public rooms given the last_joined_members # value from the since token and the probing limit # last_joined_members needs to be reduce by one if this module has already @@ -239,7 +241,6 @@ class RoomListHandler: # rooms.reverse() results += rooms - print([(r.room_id, r.num_joined_members) for r in results]) response: JsonDict = {} @@ -322,7 +323,7 @@ class RoomListHandler: cache_context: _CacheContext, with_alias: bool = True, allow_private: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Returns the entry for a room Args: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index af0ca5c26d..1b50495af1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -37,12 +37,13 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.handlers.pagination import PURGE_ROOM_ACTION_NAME from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.module_api import NOT_SPAM from synapse.types import ( JsonDict, Requester, @@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.event_creation_handler = hs.get_event_creation_handler() self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() + self._worker_lock_handler = hs.get_worker_locks_handler() self.member_linearizer: Linearizer = Linearizer(name="member") self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") @@ -110,8 +112,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_limiter_local = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, - burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + cfg=hs.config.ratelimiting.rc_joins_local, ) # Tracks joins from local users to rooms this server isn't a member of. # I.e. joins this server makes by requesting /make_join /send_join from @@ -119,8 +120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, - burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + cfg=hs.config.ratelimiting.rc_joins_remote, ) # TODO: find a better place to keep this Ratelimiter. # It needs to be @@ -133,8 +133,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_joins_per_room, ) # Ratelimiter for invites, keyed by room (across all issuers, all @@ -142,8 +141,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_room, ) # Ratelimiter for invites, keyed by recipient (across all rooms, all @@ -151,8 +149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_recipient_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_user, ) # Ratelimiter for invites, keyed by issuer (across all rooms, all @@ -160,21 +157,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_issuer_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_issuer, ) self._third_party_invite_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, - burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, + cfg=hs.config.ratelimiting.rc_third_party_invite, ) self.request_ratelimiter = hs.get_request_ratelimiter() hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room) - self._msc3970_enabled = hs.config.experimental.msc3970_enabled + self._forgotten_room_retention_period = ( + hs.config.server.forgotten_room_retention_period + ) def _on_user_joined_room(self, event_id: str, room_id: str) -> None: """Notify the rate limiter that a room join has occurred. @@ -285,7 +282,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() - async def forget(self, user: UserID, room_id: str) -> None: + async def forget( + self, user: UserID, room_id: str, do_not_schedule_purge: bool = False + ) -> None: user_id = user.to_string() member = await self._storage_controllers.state.get_current_state_event( @@ -305,6 +304,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # the table `current_state_events` and `get_current_state_events` is `None`. await self.store.forget(user_id, room_id) + # If everyone locally has left the room, then there is no reason for us to keep the + # room around and we automatically purge room after a little bit + if ( + not do_not_schedule_purge + and self._forgotten_room_retention_period + and await self.store.is_locally_forgotten_room(room_id) + ): + await self.hs.get_task_scheduler().schedule_task( + PURGE_ROOM_ACTION_NAME, + resource_id=room_id, + timestamp=self.clock.time_msec() + + self._forgotten_room_retention_period, + ) + async def ratelimit_multiple_invites( self, requester: Optional[Requester], @@ -362,7 +375,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: Optional[dict] = None, require_consent: bool = True, outlier: bool = False, - historical: bool = False, origin_server_ts: Optional[int] = None, ) -> Tuple[str, int]: """ @@ -370,24 +382,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): and persist a new event for the new membership change. Args: - requester: - target: + requester: User requesting the membership change, i.e. the sender of the + desired membership event. + target: Use whose membership should change, i.e. the state_key of the + desired membership event. room_id: membership: allow_no_prev_events: Whether to allow this event to be created an empty list of prev_events. Normally this is prohibited just because most events should have a prev_event and we should only use this in special - cases like MSC2716. + cases (previously useful for MSC2716). prev_event_ids: The event IDs to use as the prev events state_event_ids: - The full state at a given event. This is used particularly by the MSC2716 - /batch_send endpoint. One use case is the historical `state_events_at_start`; - since each is marked as an `outlier`, the `EventContext.for_outlier()` won't - have any `state_ids` set and therefore can't derive any state even though the - prev_events are set so we need to set them ourself via this argument. - This should normally be left as None, which will cause the auth_event_ids - to be calculated based on the room state at the prev_events. + The full state at a given event. This was previously used particularly + by the MSC2716 /batch_send endpoint. This should normally be left as + None, which will cause the auth_event_ids to be calculated based on the + room state at the prev_events. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -400,16 +411,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. - historical: Indicates whether the message is being inserted - back in time around some existing events. This is used to skip - a few checks and mark the event as backfilled. origin_server_ts: The origin_server_ts to use if a new event is created. Uses the current timestamp if set to None. Returns: Tuple of event ID and stream ordering position """ - user_id = target.to_string() if content is None: @@ -423,29 +430,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # do this check just before we persist an event as well, but may as well # do it up front for efficiency.) if txn_id: - existing_event_id = None - if self._msc3970_enabled and requester.device_id: - # When MSC3970 is enabled, we lookup for events sent by the same device - # first, and fallback to the old behaviour if none were found. - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_device_id( - room_id, - requester.user.to_string(), - requester.device_id, - txn_id, - ) + existing_event_id = ( + await self.event_creation_handler.get_event_id_from_transaction( + requester, txn_id, room_id ) - - if requester.access_token_id and not existing_event_id: - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_token_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, - ) - ) - + ) if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) return existing_event_id, event_pos.stream @@ -477,32 +466,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): depth=depth, require_consent=require_consent, outlier=outlier, - historical=historical, ) context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) + StateFilter.from_types([(EventTypes.Member, user_id)]) ) prev_member_event_id = prev_state_ids.get( (EventTypes.Member, user_id), None ) - if event.membership == Membership.JOIN: - newly_joined = True - if prev_member_event_id: - prev_member_event = await self.store.get_event( - prev_member_event_id - ) - newly_joined = prev_member_event.membership != Membership.JOIN - - # Only rate-limit if the user actually joined the room, otherwise we'll end - # up blocking profile updates. - if newly_joined and ratelimit: - await self._join_rate_limiter_local.ratelimit(requester) - await self._join_rate_per_room_limiter.ratelimit( - requester, key=room_id, update=False - ) with opentracing.start_active_span("handle_new_client_event"): result_event = ( await self.event_creation_handler.handle_new_client_event( @@ -585,7 +558,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, - historical: bool = False, allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, @@ -610,22 +582,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. - historical: Indicates whether the message is being inserted - back in time around some existing events. This is used to skip - a few checks and mark the event as backfilled. allow_no_prev_events: Whether to allow this event to be created an empty list of prev_events. Normally this is prohibited just because most events should have a prev_event and we should only use this in special - cases like MSC2716. + cases (previously useful for MSC2716). prev_event_ids: The event IDs to use as the prev events state_event_ids: - The full state at a given event. This is used particularly by the MSC2716 - /batch_send endpoint. One use case is the historical `state_events_at_start`; - since each is marked as an `outlier`, the `EventContext.for_outlier()` won't - have any `state_ids` set and therefore can't derive any state even though the - prev_events are set so we need to set them ourself via this argument. - This should normally be left as None, which will cause the auth_event_ids - to be calculated based on the room state at the prev_events. + The full state at a given event. This was previously used particularly + by the MSC2716 /batch_send endpoint. This should normally be left as + None, which will cause the auth_event_ids to be calculated based on the + room state at the prev_events. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -638,6 +604,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Raises: ShadowBanError if a shadow-banned requester attempts to send an invite. """ + if ratelimit: + if action == Membership.JOIN: + # Only rate-limit if the user isn't already joined to the room, otherwise + # we'll end up blocking profile updates. + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + requester.user.to_string(), + room_id, + ) + if current_membership != Membership.JOIN: + await self._join_rate_limiter_local.ratelimit(requester) + await self._join_rate_per_room_limiter.ratelimit( + requester, key=room_id, update=False + ) + elif action == Membership.INVITE: + await self.ratelimit_invite(requester, room_id, target.to_string()) + if action == Membership.INVITE and requester.shadow_banned: # We randomly sleep a bit just to annoy the requester. await self.clock.sleep(random.randint(1, 10)) @@ -653,27 +638,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # by application services), and then by room ID. async with self.member_as_limiter.queue(as_id): async with self.member_linearizer.queue(key): - with opentracing.start_active_span("update_membership_locked"): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - historical=historical, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - state_event_ids=state_event_ids, - depth=depth, - origin_server_ts=origin_server_ts, - ) + async with self._worker_lock_handler.acquire_read_write_lock( + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False + ): + with opentracing.start_active_span("update_membership_locked"): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + depth=depth, + origin_server_ts=origin_server_ts, + ) return result @@ -691,7 +678,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): new_room: bool = False, require_consent: bool = True, outlier: bool = False, - historical: bool = False, allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, @@ -718,22 +704,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. - historical: Indicates whether the message is being inserted - back in time around some existing events. This is used to skip - a few checks and mark the event as backfilled. allow_no_prev_events: Whether to allow this event to be created an empty list of prev_events. Normally this is prohibited just because most events should have a prev_event and we should only use this in special - cases like MSC2716. + cases (previously useful for MSC2716). prev_event_ids: The event IDs to use as the prev events state_event_ids: - The full state at a given event. This is used particularly by the MSC2716 - /batch_send endpoint. One use case is the historical `state_events_at_start`; - since each is marked as an `outlier`, the `EventContext.for_outlier()` won't - have any `state_ids` set and therefore can't derive any state even though the - prev_events are set so we need to set them ourself via this argument. - This should normally be left as None, which will cause the auth_event_ids - to be calculated based on the room state at the prev_events. + The full state at a given event. This was previously used particularly + by the MSC2716 /batch_send endpoint. This should normally be left as + None, which will cause the auth_event_ids to be calculated based on the + room state at the prev_events. depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. @@ -819,8 +799,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if effective_membership_state == Membership.INVITE: target_id = target.to_string() - if ratelimit: - await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -849,7 +827,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): spam_check = await self._spam_checker_module_callbacks.user_may_invite( requester.user.to_string(), target_id, room_id ) - if spam_check != NOT_SPAM: + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: logger.info("Blocking invite due to spam checker") block_invite_result = spam_check @@ -877,7 +855,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content=content, require_consent=require_consent, outlier=outlier, - historical=historical, origin_server_ts=origin_server_ts, ) @@ -985,7 +962,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): target.to_string(), room_id, is_invited=inviter is not None ) ) - if spam_check != NOT_SPAM: + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: raise SynapseError( 403, "Not allowed to join this room", @@ -1364,7 +1341,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): requester = types.create_requester(target_user) prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.GuestAccess, None)]) + StateFilter.from_types([(EventTypes.GuestAccess, "")]) ) if event.membership == Membership.JOIN: if requester.is_guest: @@ -1386,11 +1363,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit=ratelimit, ) - prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, event.state_key), None - ) - if event.membership == Membership.LEAVE: + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, event.state_key)]) + ) + prev_member_event_id = prev_state_ids.get( + (EventTypes.Member, event.state_key), None + ) + if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: @@ -1498,7 +1478,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # put the server which owns the alias at the front of the server list. if room_alias.domain in servers: servers.remove(room_alias.domain) - servers.insert(0, room_alias.domain) + servers.insert(0, room_alias.domain) return RoomID.from_string(room_id), servers @@ -1600,7 +1580,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id=room_id, ) ) - if spam_check != NOT_SPAM: + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: raise SynapseError( 403, "Cannot send threepid invite", diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 807245160d..dd559b4c45 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -35,6 +35,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache @@ -94,7 +95,9 @@ class RoomSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() self._ratelimiter = Ratelimiter( - store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + store=self._store, + clock=hs.get_clock(), + cfg=RatelimitSettings("<room summary>", per_second=5, burst_count=10), ) # If a user tries to fetch the same page multiple times in quick succession, @@ -564,9 +567,9 @@ class RoomSummaryHandler: join_rule = join_rules_event.content.get("join_rule") if ( join_rule == JoinRules.PUBLIC - or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK) + or (room_version.knock_join_rule and join_rule == JoinRules.KNOCK) or ( - room_version.msc3787_knock_restricted_join_rule + room_version.knock_restricted_join_rule and join_rule == JoinRules.KNOCK_RESTRICTED ) ): diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 874860d461..d00035c332 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -27,9 +27,9 @@ from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi from synapse.types import ( + MXID_LOCALPART_ALLOWED_CHARACTERS, UserID, map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, ) from synapse.util.iterutils import chunk_seq @@ -74,12 +74,13 @@ class SamlHandler: self.idp_id = "saml" # user-facing name of this auth provider - self.idp_name = "SAML" + self.idp_name = hs.config.saml2.idp_name - # we do not currently support icons/brands for SAML auth, but this is required by - # the SsoIdentityProvider protocol type. - self.idp_icon = None - self.idp_brand = None + # MXC URI for icon for this auth provider + self.idp_icon = hs.config.saml2.idp_icon + + # optional brand identifier for this auth provider + self.idp_brand = hs.config.saml2.idp_brand # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {} @@ -371,7 +372,7 @@ class SamlHandler: DOT_REPLACE_PATTERN = re.compile( - "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),) + "[^%s]" % (re.escape("".join(MXID_LOCALPART_ALLOWED_CHARACTERS)),) ) diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 804cc6e81e..657d9b3559 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -17,15 +17,17 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from io import BytesIO -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from pkg_resources import parse_version import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.endpoints import HostnameEndpoint +from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory from twisted.internet.ssl import optionsForClientTLS from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory +from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.logging.context import make_deferred_yieldable from synapse.types import ISynapseReactor @@ -97,6 +99,7 @@ async def _sendmail( **kwargs, ) + factory: IProtocolFactory if _is_old_twisted: # before twisted 21.2, we have to override the ESMTPSender protocol to disable # TLS @@ -110,22 +113,13 @@ async def _sendmail( factory = build_sender_factory(hostname=smtphost if enable_tls else None) if force_tls: - reactor.connectSSL( - smtphost, - smtpport, - factory, - optionsForClientTLS(smtphost), - timeout=30, - bindAddress=None, - ) - else: - reactor.connectTCP( - smtphost, - smtpport, - factory, - timeout=30, - bindAddress=None, - ) + factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory) + + endpoint = HostnameEndpoint( + reactor, smtphost, smtpport, timeout=30, bindAddress=None + ) + + await make_deferred_yieldable(endpoint.connect(factory)) await make_deferred_yieldable(d) @@ -157,6 +151,7 @@ class SendEmailHandler: app_name: str, html: str, text: str, + additional_headers: Optional[Dict[str, str]] = None, ) -> None: """Send a multipart email with the given information. @@ -166,6 +161,7 @@ class SendEmailHandler: app_name: The app name to include in the From header. html: The HTML content to include in the email. text: The plain text content to include in the email. + additional_headers: A map of additional headers to include. """ try: from_string = self._from % {"app": app_name} @@ -178,8 +174,8 @@ class SendEmailHandler: if raw_to == "": raise RuntimeError("Invalid 'to' address") - html_part = MIMEText(html, "html", "utf8") - text_part = MIMEText(text, "plain", "utf8") + html_part = MIMEText(html, "html", "utf-8") + text_part = MIMEText(text, "plain", "utf-8") multipart_msg = MIMEMultipart("alternative") multipart_msg["Subject"] = subject @@ -187,6 +183,7 @@ class SendEmailHandler: multipart_msg["To"] = email_address multipart_msg["Date"] = email.utils.formatdate() multipart_msg["Message-ID"] = email.utils.make_msgid() + # Discourage automatic responses to Synapse's emails. # Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted" # header is present with any value other than "no". See @@ -200,6 +197,11 @@ class SendEmailHandler: # https://stackoverflow.com/a/25324691/5252017 # https://stackoverflow.com/a/61646381/5252017 multipart_msg["X-Auto-Response-Suppress"] = "All" + + if additional_headers: + for header, value in additional_headers.items(): + multipart_msg[header] = value + multipart_msg.attach(text_part) multipart_msg.attach(html_part) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index c3a51722bd..e9a544e754 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -24,13 +24,14 @@ from typing import ( Iterable, List, Mapping, + NoReturn, Optional, Set, ) from urllib.parse import urlencode import attr -from typing_extensions import NoReturn, Protocol +from typing_extensions import Protocol from twisted.web.iweb import IRequest from twisted.web.server import Request @@ -225,8 +226,6 @@ class SsoHandler: self._consent_at_registration = hs.config.consent.user_consent_at_registration - self._e164_mxids = hs.config.experimental.msc4009_e164_mxids - def register_identity_provider(self, p: SsoIdentityProvider) -> None: p_id = p.idp_id assert p_id not in self._identity_providers @@ -713,7 +712,7 @@ class SsoHandler: # Since the localpart is provided via a potentially untrusted module, # ensure the MXID is valid before registering. if not attributes.localpart or contains_invalid_mxid_characters( - attributes.localpart, self._e164_mxids + attributes.localpart ): raise MappingException("localpart is invalid: %s" % (attributes.localpart,)) @@ -793,7 +792,7 @@ class SsoHandler: if code != 200: raise Exception( - "GET request to download sso avatar image returned {}".format(code) + f"GET request to download sso avatar image returned {code}" ) # upload name includes hash of the image file's content so that we can @@ -946,7 +945,7 @@ class SsoHandler: localpart, ) - if contains_invalid_mxid_characters(localpart, self._e164_mxids): + if contains_invalid_mxid_characters(localpart): raise SynapseError(400, "localpart is invalid: %s" % (localpart,)) user_id = UserID(localpart, self._server_name).to_string() user_infos = await self._store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 5c01482acf..3dde19fc81 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -14,9 +14,15 @@ # limitations under the License. import logging from collections import Counter -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple - -from typing_extensions import Counter as CounterType +from typing import ( + TYPE_CHECKING, + Any, + Counter as CounterType, + Dict, + Iterable, + Optional, + Tuple, +) from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import event_processing_positions @@ -42,7 +48,6 @@ class StatsHandler: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() - self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c010405be6..744e080309 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME from synapse.handlers.relations import BundledAggregations from synapse.logging import issue9533_logger from synapse.logging.context import current_context @@ -56,6 +57,7 @@ from synapse.storage.roommember import MemberSummary from synapse.types import ( DeviceListUpdates, JsonDict, + JsonMapping, MutableStateMap, Requester, RoomStreamToken, @@ -233,7 +235,7 @@ class SyncResult: archived: List[ArchivedSyncResult] to_device: List[JsonDict] device_lists: DeviceListUpdates - device_one_time_keys_count: JsonDict + device_one_time_keys_count: JsonMapping device_unused_fallback_key_types: List[str] def __bool__(self) -> bool: @@ -268,6 +270,7 @@ class SyncHandler: self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state self._device_handler = hs.get_device_handler() + self._task_scheduler = hs.get_task_scheduler() self.should_calculate_push_rules = hs.config.push.enable_push @@ -360,13 +363,36 @@ class SyncHandler: # (since we now know that the device has received them) if since_token is not None: since_stream_id = since_token.to_device_key + # Fast path: delete a limited number of to-device messages up front. + # We do this to avoid the overhead of scheduling a task for every + # sync. + device_deletion_limit = 100 deleted = await self.store.delete_messages_for_device( - sync_config.user.to_string(), sync_config.device_id, since_stream_id + sync_config.user.to_string(), + sync_config.device_id, + since_stream_id, + limit=device_deletion_limit, ) logger.debug( "Deleted %d to-device messages up to %d", deleted, since_stream_id ) + # If we hit the limit, schedule a background task to delete the rest. + if deleted >= device_deletion_limit: + await self._task_scheduler.schedule_task( + DELETE_DEVICE_MSGS_TASK_NAME, + resource_id=sync_config.device_id, + params={ + "user_id": sync_config.user.to_string(), + "device_id": sync_config.device_id, + "up_to_stream_id": since_stream_id, + }, + ) + logger.debug( + "Deletion of to-device messages up to %d scheduled", + since_stream_id, + ) + if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. @@ -387,16 +413,16 @@ class SyncHandler: from_token=since_token, ) - # if nothing has happened in any of the users' rooms since /sync was called, - # the resultant next_batch will be the same as since_token (since the result - # is generated when wait_for_events is first called, and not regenerated - # when wait_for_events times out). - # - # If that happens, we mustn't cache it, so that when the client comes back - # with the same cache token, we don't immediately return the same empty - # result, causing a tightloop. (#8518) - if result.next_batch == since_token: - cache_context.should_cache = False + # if nothing has happened in any of the users' rooms since /sync was called, + # the resultant next_batch will be the same as since_token (since the result + # is generated when wait_for_events is first called, and not regenerated + # when wait_for_events times out). + # + # If that happens, we mustn't cache it, so that when the client comes back + # with the same cache token, we don't immediately return the same empty + # result, causing a tightloop. (#8518) + if result.next_batch == since_token: + cache_context.should_cache = False if result: if sync_config.filter_collection.lazy_load_members(): @@ -1442,11 +1468,9 @@ class SyncHandler: # Now we have our list of joined room IDs, exclude as configured and freeze joined_room_ids = frozenset( - ( - room_id - for room_id in mutable_joined_room_ids - if room_id not in mutable_rooms_to_exclude - ) + room_id + for room_id in mutable_joined_room_ids + if room_id not in mutable_rooms_to_exclude ) logger.debug( @@ -1534,7 +1558,7 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_keys_count: JsonDict = {} + one_time_keys_count: JsonMapping = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: @@ -1770,19 +1794,23 @@ class SyncHandler: ) if push_rules_changed: - global_account_data = dict(global_account_data) - global_account_data[ - AccountDataTypes.PUSH_RULES - ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) + global_account_data = { + AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( + sync_config.user + ), + **global_account_data, + } else: all_global_account_data = await self.store.get_global_account_data_for_user( user_id ) - global_account_data = dict(all_global_account_data) - global_account_data[ - AccountDataTypes.PUSH_RULES - ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) + global_account_data = { + AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( + sync_config.user + ), + **all_global_account_data, + } account_data_for_user = ( await sync_config.filter_collection.filter_global_account_data( @@ -1886,7 +1914,7 @@ class SyncHandler: blocks_all_rooms or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data() ): - account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {} + account_data_by_room: Mapping[str, Mapping[str, JsonMapping]] = {} elif since_token and not sync_result_builder.full_state: account_data_by_room = ( await self.store.get_updated_room_account_data_for_user( @@ -2305,7 +2333,7 @@ class SyncHandler: continue leave_token = now_token.copy_and_replace( - StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering) + StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering) ) room_entries.append( RoomSyncResultBuilder( @@ -2326,8 +2354,8 @@ class SyncHandler: sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[Mapping[str, Mapping[str, Any]]], - account_data: Mapping[str, JsonDict], + tags: Optional[Mapping[str, JsonMapping]], + account_data: Mapping[str, JsonMapping], always_include: bool = False, ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 7aeae5319c..bdefa7f26f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -26,9 +26,17 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StreamKeyType, UserID +from synapse.types import ( + JsonDict, + JsonMapping, + Requester, + StrCollection, + StreamKeyType, + UserID, +) from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure +from synapse.util.retryutils import filter_destinations_by_retry_limiter from synapse.util.wheel_timer import WheelTimer if TYPE_CHECKING: @@ -150,8 +158,15 @@ class FollowerTypingHandler: now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - hosts = await self._storage_controllers.state.get_current_hosts_in_room( - member.room_id + hosts: StrCollection = ( + await self._storage_controllers.state.get_current_hosts_in_room( + member.room_id + ) + ) + hosts = await filter_destinations_by_retry_limiter( + hosts, + clock=self.clock, + store=self.store, ) for domain in hosts: if not self.is_mine_server_name(domain): @@ -479,7 +494,7 @@ class TypingWriterHandler(FollowerTypingHandler): raise Exception("Typing writer instance got typing info over replication") -class TypingNotificationEventSource(EventSource[int, JsonDict]): +class TypingNotificationEventSource(EventSource[int, JsonMapping]): def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main self.clock = hs.get_clock() @@ -489,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id: str) -> JsonDict: + def _make_event_for(self, room_id: str) -> JsonMapping: typing = self.get_typing_handler()._room_typing[room_id] return { "type": EduTypes.TYPING, @@ -499,7 +514,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): async def get_new_events_as( self, from_key: int, service: ApplicationService - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: """Returns a set of new typing events that an appservice may be interested in. @@ -543,7 +558,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 05197edc95..a0f5568000 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -94,6 +94,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.worker.should_update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users + self.show_locked_users = hs.config.userdirectory.show_locked_users self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._hs = hs @@ -144,7 +145,9 @@ class UserDirectoryHandler(StateDeltasHandler): ] } """ - results = await self.store.search_user_dir(user_id, search_term, limit) + results = await self.store.search_user_dir( + user_id, search_term, limit, self.show_locked_users + ) # Remove any spammy users from the results. non_spammy_users = [] diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py new file mode 100644 index 0000000000..58efe7116b --- /dev/null +++ b/synapse/handlers/worker_lock.py @@ -0,0 +1,337 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from types import TracebackType +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + Collection, + Dict, + Optional, + Tuple, + Type, + Union, +) +from weakref import WeakSet + +import attr + +from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime + +from synapse.logging.context import PreserveLoggingContext +from synapse.logging.opentracing import start_active_span +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage.databases.main.lock import Lock, LockStore +from synapse.util.async_helpers import timeout_deferred + +if TYPE_CHECKING: + from synapse.logging.opentracing import opentracing + from synapse.server import HomeServer + + +# This lock is used to avoid creating an event while we are purging the room. +# We take a read lock when creating an event, and a write one when purging a room. +# This is because it is fine to create several events concurrently, since referenced events +# will not disappear under our feet as long as we don't delete the room. +NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock" + + +class WorkerLocksHandler: + """A class for waiting on taking out locks, rather than using the storage + functions directly (which don't support awaiting). + """ + + def __init__(self, hs: "HomeServer") -> None: + self._reactor = hs.get_reactor() + self._store = hs.get_datastores().main + self._clock = hs.get_clock() + self._notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() + + # Map from lock name/key to set of `WaitingLock` that are active for + # that lock. + self._locks: Dict[ + Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]] + ] = {} + + self._clock.looping_call(self._cleanup_locks, 30_000) + + self._notifier.add_lock_released_callback(self._on_lock_released) + + def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock": + """Acquire a standard lock, returns a context manager that will block + until the lock is acquired. + + Note: Care must be taken to avoid deadlocks. In particular, this + function does *not* timeout. + + Usage: + async with handler.acquire_lock(name, key): + # Do work while holding the lock... + """ + + lock = WaitingLock( + reactor=self._reactor, + store=self._store, + handler=self, + lock_name=lock_name, + lock_key=lock_key, + write=None, + ) + + self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock) + + return lock + + def acquire_read_write_lock( + self, + lock_name: str, + lock_key: str, + *, + write: bool, + ) -> "WaitingLock": + """Acquire a read/write lock, returns a context manager that will block + until the lock is acquired. + + Note: Care must be taken to avoid deadlocks. In particular, this + function does *not* timeout. + + Usage: + async with handler.acquire_read_write_lock(name, key, write=True): + # Do work while holding the lock... + """ + + lock = WaitingLock( + reactor=self._reactor, + store=self._store, + handler=self, + lock_name=lock_name, + lock_key=lock_key, + write=write, + ) + + self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock) + + return lock + + def acquire_multi_read_write_lock( + self, + lock_names: Collection[Tuple[str, str]], + *, + write: bool, + ) -> "WaitingMultiLock": + """Acquires multi read/write locks at once, returns a context manager + that will block until all the locks are acquired. + + This will try and acquire all locks at once, and will never hold on to a + subset of the locks. (This avoids accidentally creating deadlocks). + + Note: Care must be taken to avoid deadlocks. In particular, this + function does *not* timeout. + """ + + lock = WaitingMultiLock( + lock_names=lock_names, + write=write, + reactor=self._reactor, + store=self._store, + handler=self, + ) + + for lock_name, lock_key in lock_names: + self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock) + + return lock + + def notify_lock_released(self, lock_name: str, lock_key: str) -> None: + """Notify that a lock has been released. + + Pokes both the notifier and replication. + """ + + self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key) + + def _on_lock_released( + self, instance_name: str, lock_name: str, lock_key: str + ) -> None: + """Called when a lock has been released. + + Wakes up any locks that might be waiting on this. + """ + locks = self._locks.get((lock_name, lock_key)) + if not locks: + return + + def _wake_deferred(deferred: defer.Deferred) -> None: + if not deferred.called: + deferred.callback(None) + + for lock in locks: + self._clock.call_later(0, _wake_deferred, lock.deferred) + + @wrap_as_background_process("_cleanup_locks") + async def _cleanup_locks(self) -> None: + """Periodically cleans out stale entries in the locks map""" + self._locks = {key: value for key, value in self._locks.items() if value} + + +@attr.s(auto_attribs=True, eq=False) +class WaitingLock: + reactor: IReactorTime + store: LockStore + handler: WorkerLocksHandler + lock_name: str + lock_key: str + write: Optional[bool] + deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred) + _inner_lock: Optional[Lock] = None + _retry_interval: float = 0.1 + _lock_span: "opentracing.Scope" = attr.Factory( + lambda: start_active_span("WaitingLock.lock") + ) + + async def __aenter__(self) -> None: + self._lock_span.__enter__() + + with start_active_span("WaitingLock.waiting_for_lock"): + while self._inner_lock is None: + self.deferred = defer.Deferred() + + if self.write is not None: + lock = await self.store.try_acquire_read_write_lock( + self.lock_name, self.lock_key, write=self.write + ) + else: + lock = await self.store.try_acquire_lock( + self.lock_name, self.lock_key + ) + + if lock: + self._inner_lock = lock + break + + try: + # Wait until the we get notified the lock might have been + # released (by the deferred being resolved). We also + # periodically wake up in case the lock was released but we + # weren't notified. + with PreserveLoggingContext(): + await timeout_deferred( + deferred=self.deferred, + timeout=self._get_next_retry_interval(), + reactor=self.reactor, + ) + except Exception: + pass + + return await self._inner_lock.__aenter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + assert self._inner_lock + + self.handler.notify_lock_released(self.lock_name, self.lock_key) + + try: + r = await self._inner_lock.__aexit__(exc_type, exc, tb) + finally: + self._lock_span.__exit__(exc_type, exc, tb) + + return r + + def _get_next_retry_interval(self) -> float: + next = self._retry_interval + self._retry_interval = max(5, next * 2) + return next * random.uniform(0.9, 1.1) + + +@attr.s(auto_attribs=True, eq=False) +class WaitingMultiLock: + lock_names: Collection[Tuple[str, str]] + + write: bool + + reactor: IReactorTime + store: LockStore + handler: WorkerLocksHandler + + deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred) + + _inner_lock_cm: Optional[AsyncContextManager] = None + _retry_interval: float = 0.1 + _lock_span: "opentracing.Scope" = attr.Factory( + lambda: start_active_span("WaitingLock.lock") + ) + + async def __aenter__(self) -> None: + self._lock_span.__enter__() + + with start_active_span("WaitingLock.waiting_for_lock"): + while self._inner_lock_cm is None: + self.deferred = defer.Deferred() + + lock_cm = await self.store.try_acquire_multi_read_write_lock( + self.lock_names, write=self.write + ) + + if lock_cm: + self._inner_lock_cm = lock_cm + break + + try: + # Wait until the we get notified the lock might have been + # released (by the deferred being resolved). We also + # periodically wake up in case the lock was released but we + # weren't notified. + with PreserveLoggingContext(): + await timeout_deferred( + deferred=self.deferred, + timeout=self._get_next_retry_interval(), + reactor=self.reactor, + ) + except Exception: + pass + + assert self._inner_lock_cm + await self._inner_lock_cm.__aenter__() + return + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + assert self._inner_lock_cm + + for lock_name, lock_key in self.lock_names: + self.handler.notify_lock_released(lock_name, lock_key) + + try: + r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb) + finally: + self._lock_span.__exit__(exc_type, exc, tb) + + return r + + def _get_next_retry_interval(self) -> float: + next = self._retry_interval + self._retry_interval = max(5, next * 2) + return next * random.uniform(0.9, 1.1) diff --git a/synapse/http/client.py b/synapse/http/client.py index f1ab7a8bc9..c750e03b36 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags -from synapse.types import ISynapseReactor +from synapse.types import ISynapseReactor, StrSequence from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred @@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu # the value actually has to be a List, but List is invariant so we can't specify that # the entries can either be Lists or bytes. RawHeaderValue = Union[ - List[str], + StrSequence, List[bytes], List[Union[str, bytes]], - Tuple[str, ...], Tuple[bytes, ...], Tuple[Union[str, bytes], ...], ] @@ -835,6 +834,7 @@ class ReplicationClient(BaseHttpClient): self.agent: IAgent = ReplicationAgent( hs.get_reactor(), + hs.config.worker.instance_map, contextFactory=hs.get_http_client_context_factory(), pool=pool, ) @@ -1036,7 +1036,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): - # stolen from https://github.com/twisted/treq/pull/49/files + # This applies to requests which don't set `Content-Length` or a + # `Transfer-Encoding` in the response because in this case the end of the + # response is indicated by the connection being closed, an event which may + # also be due to a transient network problem or other error. But since this + # behavior is expected of some servers (like YouTube), let's ignore it. + # Stolen from https://github.com/twisted/treq/pull/49/files # http://twistedmatrix.com/trac/ticket/4840 self.deferred.callback(self.length) else: diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index 23a60af171..636efc33e8 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import base64 import logging from typing import Optional, Union @@ -39,8 +40,14 @@ class ProxyConnectError(ConnectError): pass -@attr.s(auto_attribs=True) class ProxyCredentials: + @abc.abstractmethod + def as_proxy_authorization_value(self) -> bytes: + raise NotImplementedError() + + +@attr.s(auto_attribs=True) +class BasicProxyCredentials(ProxyCredentials): username_password: bytes def as_proxy_authorization_value(self) -> bytes: @@ -55,6 +62,17 @@ class ProxyCredentials: return b"Basic " + base64.encodebytes(self.username_password) +@attr.s(auto_attribs=True) +class BearerProxyCredentials(ProxyCredentials): + access_token: bytes + + def as_proxy_authorization_value(self) -> bytes: + """ + Return the value for a Proxy-Authorization header (i.e. 'Bearer xxx'). + """ + return b"Bearer " + self.access_token + + @implementer(IStreamClientEndpoint) class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 7e8cf31682..a3a396bb37 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -51,8 +51,10 @@ logger = logging.getLogger(__name__) @implementer(IAgent) class MatrixFederationAgent: """An Agent-like thing which provides a `request` method which correctly - handles resolving matrix server names when using matrix://. Handles standard - https URIs as normal. + handles resolving matrix server names when using `matrix-federation://`. Handles + standard https URIs as normal. The `matrix-federation://` scheme is internal to + Synapse and we purposely want to avoid colliding with the `matrix://` URL scheme + which is now specced. Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) @@ -167,14 +169,14 @@ class MatrixFederationAgent: # There must be a valid hostname. assert parsed_uri.hostname - # If this is a matrix:// URI check if the server has delegated matrix + # If this is a matrix-federation:// URI check if the server has delegated matrix # traffic using well-known delegation. # # We have to do this here and not in the endpoint as we need to rewrite # the host header with the delegated server name. delegated_server = None if ( - parsed_uri.scheme == b"matrix" + parsed_uri.scheme == b"matrix-federation" and not _is_ip_literal(parsed_uri.hostname) and not parsed_uri.port ): @@ -250,7 +252,7 @@ class MatrixHostnameEndpointFactory: @implementer(IStreamClientEndpoint) class MatrixHostnameEndpoint: - """An endpoint that resolves matrix:// URLs using Matrix server name + """An endpoint that resolves matrix-federation:// URLs using Matrix server name resolution (i.e. via SRV). Does not check for well-known delegation. Args: @@ -379,7 +381,7 @@ class MatrixHostnameEndpoint: connect to. """ - if self._parsed_uri.scheme != b"matrix": + if self._parsed_uri.scheme != b"matrix-federation": return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)] # Note: We don't do well-known lookup as that needs to have happened @@ -397,15 +399,34 @@ class MatrixHostnameEndpoint: if port or _is_ip_literal(host): return [Server(host, port or 8448)] + # Check _matrix-fed._tcp SRV record. logger.debug("Looking up SRV record for %s", host.decode(errors="replace")) + server_list = await self._srv_resolver.resolve_service( + b"_matrix-fed._tcp." + host + ) + + if server_list: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Got %s from SRV lookup for %s", + ", ".join(map(str, server_list)), + host.decode(errors="replace"), + ) + return server_list + + # No _matrix-fed._tcp SRV record, fallback to legacy _matrix._tcp SRV record. + logger.debug( + "Looking up deprecated SRV record for %s", host.decode(errors="replace") + ) server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host) if server_list: - logger.debug( - "Got %s from SRV lookup for %s", - ", ".join(map(str, server_list)), - host.decode(errors="replace"), - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Got %s from deprecated SRV lookup for %s", + ", ".join(map(str, server_list)), + host.decode(errors="replace"), + ) return server_list # No SRV records, so we fallback to host and 8448 diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 9094dab0fe..08c7fc1631 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -50,7 +50,7 @@ from twisted.internet.interfaces import IReactorTime from twisted.internet.task import Cooperator from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers -from twisted.web.iweb import IBodyProducer, IResponse +from twisted.web.iweb import IAgent, IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils @@ -71,7 +71,9 @@ from synapse.http.client import ( encode_query_args, read_body_with_max_size, ) +from synapse.http.connectproxyclient import BearerProxyCredentials from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.proxyagent import ProxyAgent from synapse.http.types import QueryParams from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -95,8 +97,6 @@ incoming_responses_counter = Counter( ) -MAX_LONG_RETRIES = 10 -MAX_SHORT_RETRIES = 3 MAXINT = sys.maxsize @@ -174,7 +174,14 @@ class MatrixFederationRequest: # The object is frozen so we can pre-compute this. uri = urllib.parse.urlunparse( - (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ( + b"matrix-federation", + destination_bytes, + path_bytes, + None, + query_bytes, + b"", + ) ) object.__setattr__(self, "uri", uri) @@ -236,7 +243,7 @@ class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]): return ( isinstance(v, list) and len(v) == 2 - and type(v[0]) == int + and type(v[0]) == int # noqa: E721 and isinstance(v[1], dict) ) @@ -388,17 +395,41 @@ class MatrixFederationHttpClient: if hs.config.server.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) - federation_agent = MatrixFederationAgent( - self.reactor, - tls_client_options_factory, - user_agent.encode("ascii"), - hs.config.server.federation_ip_range_allowlist, - hs.config.server.federation_ip_range_blocklist, + outbound_federation_restricted_to = ( + hs.config.worker.outbound_federation_restricted_to ) + if hs.get_instance_name() in outbound_federation_restricted_to: + # Talk to federation directly + federation_agent: IAgent = MatrixFederationAgent( + self.reactor, + tls_client_options_factory, + user_agent.encode("ascii"), + hs.config.server.federation_ip_range_allowlist, + hs.config.server.federation_ip_range_blocklist, + ) + else: + proxy_authorization_secret = hs.config.worker.worker_replication_secret + assert ( + proxy_authorization_secret is not None + ), "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)" + federation_proxy_credentials = BearerProxyCredentials( + proxy_authorization_secret.encode("ascii") + ) + + # We need to talk to federation via the proxy via one of the configured + # locations + federation_proxy_locations = outbound_federation_restricted_to.locations + federation_agent = ProxyAgent( + self.reactor, + self.reactor, + tls_client_options_factory, + federation_proxy_locations=federation_proxy_locations, + federation_proxy_credentials=federation_proxy_credentials, + ) # Use a BlocklistingAgentWrapper to prevent circumventing the IP # blocking via IP literals in server names - self.agent = BlocklistingAgentWrapper( + self.agent: IAgent = BlocklistingAgentWrapper( federation_agent, ip_blocklist=hs.config.server.federation_ip_range_blocklist, ) @@ -406,7 +437,15 @@ class MatrixFederationHttpClient: self.clock = hs.get_clock() self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") - self.default_timeout = 60 + self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000 + self.max_long_retry_delay_seconds = ( + hs.config.federation.max_long_retry_delay_ms / 1000 + ) + self.max_short_retry_delay_seconds = ( + hs.config.federation.max_short_retry_delay_ms / 1000 + ) + self.max_long_retries = hs.config.federation.max_long_retries + self.max_short_retries = hs.config.federation.max_short_retries self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) @@ -473,6 +512,7 @@ class MatrixFederationHttpClient: long_retries: bool = False, ignore_backoff: bool = False, backoff_on_404: bool = False, + backoff_on_all_error_codes: bool = False, ) -> IResponse: """ Sends a request to the given server. @@ -499,13 +539,21 @@ class MatrixFederationHttpClient: Note that the above intervals are *in addition* to the time spent waiting for the request to complete (up to `timeout` ms). - NB: the long retry algorithm takes over 20 minutes to complete, with - a default timeout of 60s! + NB: the long retry algorithm takes over 20 minutes to complete, with a + default timeout of 60s! It's best not to use the `long_retries` option + for something that is blocking a client so we don't make them wait for + aaaaages, whereas some things like sending transactions (server to + server) we can be a lot more lenient but its very fuzzy / hand-wavey. + + In the future, we could be more intelligent about doing this sort of + thing by looking at things with the bigger picture in mind, + https://github.com/matrix-org/synapse/issues/8917 ignore_backoff: true to ignore the historical backoff data and try the request anyway. backoff_on_404: Back off if we get a 404 + backoff_on_all_error_codes: Back off if we get any error response Returns: Resolves with the HTTP response object on success. @@ -528,10 +576,10 @@ class MatrixFederationHttpClient: logger.exception(f"Invalid destination: {request.destination}.") raise FederationDeniedError(request.destination) - if timeout: + if timeout is not None: _sec_timeout = timeout / 1000 else: - _sec_timeout = self.default_timeout + _sec_timeout = self.default_timeout_seconds if ( self.hs.config.federation.federation_domain_whitelist is not None @@ -548,6 +596,7 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, notifier=self.hs.get_notifier(), replication_client=self.hs.get_replication_command_handler(), + backoff_on_all_error_codes=backoff_on_all_error_codes, ) method_bytes = request.method.encode("ascii") @@ -576,9 +625,9 @@ class MatrixFederationHttpClient: # XXX: Would be much nicer to retry only at the transaction-layer # (once we have reliable transactions in place) if long_retries: - retries_left = MAX_LONG_RETRIES + retries_left = self.max_long_retries else: - retries_left = MAX_SHORT_RETRIES + retries_left = self.max_short_retries url_bytes = request.uri url_str = url_bytes.decode("ascii") @@ -723,24 +772,34 @@ class MatrixFederationHttpClient: if retries_left and not timeout: if long_retries: - delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left) - delay = min(delay, 60) - delay *= random.uniform(0.8, 1.4) + delay_seconds = 4 ** ( + self.max_long_retries + 1 - retries_left + ) + delay_seconds = min( + delay_seconds, self.max_long_retry_delay_seconds + ) + delay_seconds *= random.uniform(0.8, 1.4) else: - delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) - delay = min(delay, 2) - delay *= random.uniform(0.8, 1.4) + delay_seconds = 0.5 * 2 ** ( + self.max_short_retries - retries_left + ) + delay_seconds = min( + delay_seconds, self.max_short_retry_delay_seconds + ) + delay_seconds *= random.uniform(0.8, 1.4) logger.debug( "{%s} [%s] Waiting %ss before re-sending...", request.txn_id, request.destination, - delay, + delay_seconds, ) # Sleep for the calculated delay, or wake up immediately # if we get notified that the server is back up. - await self._sleeper.sleep(request.destination, delay * 1000) + await self._sleeper.sleep( + request.destination, delay_seconds * 1000 + ) retries_left -= 1 else: raise @@ -833,6 +892,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, + backoff_on_all_error_codes: bool = False, ) -> JsonDict: ... @@ -850,6 +910,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, ) -> T: ... @@ -866,6 +927,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, ) -> Union[JsonDict, T]: """Sends the specified json data using PUT @@ -901,6 +963,7 @@ class MatrixFederationHttpClient: enabled. parser: The parser to use to decode the response. Defaults to parsing as JSON. + backoff_on_all_error_codes: Back off if we get any error response Returns: Succeeds when we get a 2xx HTTP response. The @@ -934,12 +997,13 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, long_retries=long_retries, timeout=timeout, + backoff_on_all_error_codes=backoff_on_all_error_codes, ) if timeout is not None: _sec_timeout = timeout / 1000 else: - _sec_timeout = self.default_timeout + _sec_timeout = self.default_timeout_seconds if parser is None: parser = cast(ByteParser[T], JsonParser()) @@ -1017,10 +1081,10 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, ) - if timeout: + if timeout is not None: _sec_timeout = timeout / 1000 else: - _sec_timeout = self.default_timeout + _sec_timeout = self.default_timeout_seconds body = await _handle_response( self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() @@ -1110,6 +1174,101 @@ class MatrixFederationHttpClient: RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. """ + json_dict, _ = await self.get_json_with_headers( + destination=destination, + path=path, + args=args, + retry_on_dns_fail=retry_on_dns_fail, + timeout=timeout, + ignore_backoff=ignore_backoff, + try_trailing_slash_on_400=try_trailing_slash_on_400, + parser=parser, + ) + return json_dict + + @overload + async def get_json_with_headers( + self, + destination: str, + path: str, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: + ... + + @overload + async def get_json_with_headers( + self, + destination: str, + path: str, + args: Optional[QueryParams] = ..., + retry_on_dns_fail: bool = ..., + timeout: Optional[int] = ..., + ignore_backoff: bool = ..., + try_trailing_slash_on_400: bool = ..., + parser: ByteParser[T] = ..., + ) -> Tuple[T, Dict[bytes, List[bytes]]]: + ... + + async def get_json_with_headers( + self, + destination: str, + path: str, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + ) -> Tuple[Union[JsonDict, T], Dict[bytes, List[bytes]]]: + """GETs some json from the given host homeserver and path + + Args: + destination: The remote server to send the HTTP request to. + + path: The HTTP path. + + args: A dictionary used to create query strings, defaults to + None. + + retry_on_dns_fail: true if the request should be retried on DNS failures + + timeout: number of milliseconds to wait for the response. + self._default_timeout (60s) by default. + + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + + ignore_backoff: true to ignore the historical backoff data + and try the request anyway. + + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end of + the request. Workaround for #3622 in Synapse <= v0.99.3. + + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + + Returns: + Succeeds when we get a 2xx HTTP response. The result will be a tuple of the + decoded JSON body and a dict of the response headers. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. + """ request = MatrixFederationRequest( method="GET", destination=destination, path=path, query=args ) @@ -1125,10 +1284,12 @@ class MatrixFederationHttpClient: timeout=timeout, ) + headers = dict(response.headers.getAllRawHeaders()) + if timeout is not None: _sec_timeout = timeout / 1000 else: - _sec_timeout = self.default_timeout + _sec_timeout = self.default_timeout_seconds if parser is None: parser = cast(ByteParser[T], JsonParser()) @@ -1142,7 +1303,7 @@ class MatrixFederationHttpClient: parser=parser, ) - return body + return body, headers async def delete_json( self, @@ -1204,7 +1365,7 @@ class MatrixFederationHttpClient: if timeout is not None: _sec_timeout = timeout / 1000 else: - _sec_timeout = self.default_timeout + _sec_timeout = self.default_timeout_seconds body = await _handle_response( self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() @@ -1256,7 +1417,7 @@ class MatrixFederationHttpClient: try: d = read_body_with_max_size(response, output_stream, max_size) - d.addTimeout(self.default_timeout, self.reactor) + d.addTimeout(self.default_timeout_seconds, self.reactor) length = await make_deferred_yieldable(d) except BodyExceededMaxSize: msg = "Requested file is too large > %r bytes" % (max_size,) diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py new file mode 100644 index 0000000000..c9f51e51bc --- /dev/null +++ b/synapse/http/proxy.py @@ -0,0 +1,283 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import logging +import urllib.parse +from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, cast + +from twisted.internet import protocol +from twisted.internet.interfaces import ITCPTransport +from twisted.internet.protocol import connectionDone +from twisted.python import failure +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse +from twisted.web.resource import IResource +from twisted.web.server import Request, Site + +from synapse.api.errors import Codes, InvalidProxyCredentialsError +from synapse.http import QuieterFileBodyProducer +from synapse.http.server import _AsyncResource +from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor +from synapse.util.async_helpers import timeout_deferred + +if TYPE_CHECKING: + from synapse.http.site import SynapseRequest + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616 +# section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be +# consumed by the immediate recipient and not be forwarded on. +HOP_BY_HOP_HEADERS = { + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "TE", + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + + +def parse_connection_header_value( + connection_header_value: Optional[bytes], +) -> Set[str]: + """ + Parse the `Connection` header to determine which headers we should not be copied + over from the remote response. + + As defined by RFC2616 section 14.10 and RFC9110 section 7.6.1 + + Example: `Connection: close, X-Foo, X-Bar` will return `{"Close", "X-Foo", "X-Bar"}` + + Even though "close" is a special directive, let's just treat it as just another + header for simplicity. If people want to check for this directive, they can simply + check for `"Close" in headers`. + + Args: + connection_header_value: The value of the `Connection` header. + + Returns: + The set of header names that should not be copied over from the remote response. + The keys are capitalized in canonical capitalization. + """ + headers = Headers() + extra_headers_to_remove: Set[str] = set() + if connection_header_value: + extra_headers_to_remove = { + headers._canonicalNameCaps(connection_option.strip()).decode("ascii") + for connection_option in connection_header_value.split(b",") + } + + return extra_headers_to_remove + + +class ProxyResource(_AsyncResource): + """ + A stub resource that proxies any requests with a `matrix-federation://` scheme + through the given `federation_agent` to the remote homeserver and ferries back the + info. + """ + + isLeaf = True + + def __init__(self, reactor: ISynapseReactor, hs: "HomeServer"): + super().__init__(True) + + self.reactor = reactor + self.agent = hs.get_federation_http_client().agent + + self._proxy_authorization_secret = hs.config.worker.worker_replication_secret + + def _check_auth(self, request: Request) -> None: + # The `matrix-federation://` proxy functionality can only be used with auth. + # Protect homserver admins forgetting to configure a secret. + assert self._proxy_authorization_secret is not None + + # Get the authorization header. + auth_headers = request.requestHeaders.getRawHeaders(b"Proxy-Authorization") + + if not auth_headers: + raise InvalidProxyCredentialsError( + "Missing Proxy-Authorization header.", Codes.MISSING_TOKEN + ) + if len(auth_headers) > 1: + raise InvalidProxyCredentialsError( + "Too many Proxy-Authorization headers.", Codes.UNAUTHORIZED + ) + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + received_secret = parts[1].decode("ascii") + if self._proxy_authorization_secret == received_secret: + # Success! + return + + raise InvalidProxyCredentialsError( + "Invalid Proxy-Authorization header.", Codes.UNAUTHORIZED + ) + + async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]: + uri = urllib.parse.urlparse(request.uri) + assert uri.scheme == b"matrix-federation" + + # Check the authorization headers before handling the request. + self._check_auth(request) + + headers = Headers() + for header_name in (b"User-Agent", b"Authorization", b"Content-Type"): + header_value = request.getHeader(header_name) + if header_value: + headers.addRawHeader(header_name, header_value) + + request_deferred = run_in_background( + self.agent.request, + request.method, + request.uri, + headers=headers, + bodyProducer=QuieterFileBodyProducer(request.content), + ) + request_deferred = timeout_deferred( + request_deferred, + # This should be set longer than the timeout in `MatrixFederationHttpClient` + # so that it has enough time to complete and pass us the data before we give + # up. + timeout=90, + reactor=self.reactor, + ) + + response = await make_deferred_yieldable(request_deferred) + + return response.code, response + + def _send_response( + self, + request: "SynapseRequest", + code: int, + response_object: Any, + ) -> None: + response = cast(IResponse, response_object) + response_headers = cast(Headers, response.headers) + + request.setResponseCode(code) + + # The `Connection` header also defines which headers should not be copied over. + connection_header = response_headers.getRawHeaders(b"connection") + extra_headers_to_remove = parse_connection_header_value( + connection_header[0] if connection_header else None + ) + + # Copy headers. + for k, v in response_headers.getAllRawHeaders(): + # Do not copy over any hop-by-hop headers. These are meant to only be + # consumed by the immediate recipient and not be forwarded on. + header_key = k.decode("ascii") + if ( + header_key in HOP_BY_HOP_HEADERS + or header_key in extra_headers_to_remove + ): + continue + + request.responseHeaders.setRawHeaders(k, v) + + response.deliverBody(_ProxyResponseBody(request)) + + def _send_error_response( + self, + f: failure.Failure, + request: "SynapseRequest", + ) -> None: + if isinstance(f.value, InvalidProxyCredentialsError): + error_response_code = f.value.code + error_response_json = {"errcode": f.value.errcode, "err": f.value.msg} + else: + error_response_code = 502 + error_response_json = { + "errcode": Codes.UNKNOWN, + "err": "ProxyResource: Error when proxying request: %s %s -> %s" + % ( + request.method.decode("ascii"), + request.uri.decode("ascii"), + f, + ), + } + + request.setResponseCode(error_response_code) + request.setHeader(b"Content-Type", b"application/json") + request.write((json.dumps(error_response_json)).encode()) + request.finish() + + +class _ProxyResponseBody(protocol.Protocol): + """ + A protocol that proxies the given remote response data back out to the given local + request. + """ + + transport: Optional[ITCPTransport] = None + + def __init__(self, request: "SynapseRequest") -> None: + self._request = request + + def dataReceived(self, data: bytes) -> None: + # Avoid sending response data to the local request that already disconnected + if self._request._disconnected and self.transport is not None: + # Close the connection (forcefully) since all the data will get + # discarded anyway. + self.transport.abortConnection() + return + + self._request.write(data) + + def connectionLost(self, reason: Failure = connectionDone) -> None: + # If the local request is already finished (successfully or failed), don't + # worry about sending anything back. + if self._request.finished: + return + + if reason.check(ResponseDone): + self._request.finish() + else: + # Abort the underlying request since our remote request also failed. + self._request.transport.abortConnection() + + +class ProxySite(Site): + """ + Proxies any requests with a `matrix-federation://` scheme through the given + `federation_agent`. Otherwise, behaves like a normal `Site`. + """ + + def __init__( + self, + resource: IResource, + reactor: ISynapseReactor, + hs: "HomeServer", + ): + super().__init__(resource, reactor=reactor) + + self._proxy_resource = ProxyResource(reactor, hs=hs) + + def getResourceFor(self, request: "SynapseRequest") -> IResource: + uri = urllib.parse.urlparse(request.uri) + if uri.scheme == b"matrix-federation": + return self._proxy_resource + + return super().getResourceFor(request) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7bdc4acae7..59ab8fad35 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random import re -from typing import Any, Dict, Optional, Tuple +from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple from urllib.parse import urlparse from urllib.request import ( # type: ignore[attr-defined] getproxies_environment, @@ -23,8 +24,17 @@ from urllib.request import ( # type: ignore[attr-defined] from zope.interface import implementer from twisted.internet import defer -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.endpoints import ( + HostnameEndpoint, + UNIXClientEndpoint, + wrapClientTLS, +) +from twisted.internet.interfaces import ( + IProtocol, + IProtocolFactory, + IReactorCore, + IStreamClientEndpoint, +) from twisted.python.failure import Failure from twisted.web.client import ( URI, @@ -36,8 +46,18 @@ from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse +from synapse.config.workers import ( + InstanceLocationConfig, + InstanceTcpLocationConfig, + InstanceUnixLocationConfig, +) from synapse.http import redact_uri -from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials +from synapse.http.connectproxyclient import ( + BasicProxyCredentials, + HTTPConnectProxyEndpoint, + ProxyCredentials, +) +from synapse.logging.context import run_in_background logger = logging.getLogger(__name__) @@ -74,6 +94,14 @@ class ProxyAgent(_AgentBase): use_proxy: Whether proxy settings should be discovered and used from conventional environment variables. + federation_proxy_locations: An optional list of locations to proxy outbound federation + traffic through (only requests that use the `matrix-federation://` scheme + will be proxied). + + federation_proxy_credentials: Required if `federation_proxy_locations` is set. The + credentials to use when proxying outbound federation traffic through another + worker. + Raises: ValueError if use_proxy is set and the environment variables contain an invalid proxy specification. @@ -89,6 +117,8 @@ class ProxyAgent(_AgentBase): bindAddress: Optional[bytes] = None, pool: Optional[HTTPConnectionPool] = None, use_proxy: bool = False, + federation_proxy_locations: Collection[InstanceLocationConfig] = (), + federation_proxy_credentials: Optional[ProxyCredentials] = None, ): contextFactory = contextFactory or BrowserLikePolicyForHTTPS() @@ -127,6 +157,47 @@ class ProxyAgent(_AgentBase): self._policy_for_https = contextFactory self._reactor = reactor + self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None + self._federation_proxy_credentials: Optional[ProxyCredentials] = None + if federation_proxy_locations: + assert ( + federation_proxy_credentials is not None + ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + + endpoints: List[IStreamClientEndpoint] = [] + for federation_proxy_location in federation_proxy_locations: + endpoint: IStreamClientEndpoint + if isinstance(federation_proxy_location, InstanceTcpLocationConfig): + endpoint = HostnameEndpoint( + self.proxy_reactor, + federation_proxy_location.host, + federation_proxy_location.port, + ) + if federation_proxy_location.tls: + tls_connection_creator = ( + self._policy_for_https.creatorForNetloc( + federation_proxy_location.host.encode("utf-8"), + federation_proxy_location.port, + ) + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + + elif isinstance(federation_proxy_location, InstanceUnixLocationConfig): + endpoint = UNIXClientEndpoint( + self.proxy_reactor, federation_proxy_location.path + ) + + else: + # It is supremely unlikely we ever hit this + raise SchemeNotSupported( + f"Unknown type of Endpoint requested, check {federation_proxy_location}" + ) + + endpoints.append(endpoint) + + self._federation_proxy_endpoint = _RandomSampleEndpoints(endpoints) + self._federation_proxy_credentials = federation_proxy_credentials + def request( self, method: bytes, @@ -214,6 +285,25 @@ class ProxyAgent(_AgentBase): parsed_uri.port, self.https_proxy_creds, ) + elif ( + parsed_uri.scheme == b"matrix-federation" + and self._federation_proxy_endpoint + ): + assert ( + self._federation_proxy_credentials is not None + ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + + # Set a Proxy-Authorization header + if headers is None: + headers = Headers() + # We always need authentication for the outbound federation proxy + headers.addRawHeader( + b"Proxy-Authorization", + self._federation_proxy_credentials.as_proxy_authorization_value(), + ) + + endpoint = self._federation_proxy_endpoint + request_path = uri else: # not using a proxy endpoint = HostnameEndpoint( @@ -233,6 +323,11 @@ class ProxyAgent(_AgentBase): endpoint = wrapClientTLS(tls_connection_creator, endpoint) elif parsed_uri.scheme == b"http": pass + elif ( + parsed_uri.scheme == b"matrix-federation" + and self._federation_proxy_endpoint + ): + pass else: return defer.fail( Failure( @@ -334,6 +429,42 @@ def parse_proxy( credentials = None if url.username and url.password: - credentials = ProxyCredentials(b"".join([url.username, b":", url.password])) + credentials = BasicProxyCredentials( + b"".join([url.username, b":", url.password]) + ) return url.scheme, url.hostname, url.port or default_port, credentials + + +@implementer(IStreamClientEndpoint) +class _RandomSampleEndpoints: + """An endpoint that randomly iterates through a given list of endpoints at + each connection attempt. + """ + + def __init__( + self, + endpoints: Sequence[IStreamClientEndpoint], + ) -> None: + assert endpoints + self._endpoints = endpoints + + def __repr__(self) -> str: + return f"<_RandomSampleEndpoints endpoints={self._endpoints}>" + + def connect( + self, protocol_factory: IProtocolFactory + ) -> "defer.Deferred[IProtocol]": + """Implements IStreamClientEndpoint interface""" + + return run_in_background(self._do_connect, protocol_factory) + + async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol: + failures: List[Failure] = [] + for endpoint in random.sample(self._endpoints, k=len(self._endpoints)): + try: + return await endpoint.connect(protocol_factory) + except Exception: + failures.append(Failure()) + + failures.pop().raiseException() diff --git a/synapse/http/replicationagent.py b/synapse/http/replicationagent.py index 5ecd08be0f..3ba2f22dfd 100644 --- a/synapse/http/replicationagent.py +++ b/synapse/http/replicationagent.py @@ -13,12 +13,16 @@ # limitations under the License. import logging -from typing import Optional +from typing import Dict, Optional from zope.interface import implementer from twisted.internet import defer -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.internet.endpoints import ( + HostnameEndpoint, + UNIXClientEndpoint, + wrapClientTLS, +) from twisted.internet.interfaces import IStreamClientEndpoint from twisted.python.failure import Failure from twisted.web.client import URI, HTTPConnectionPool, _AgentBase @@ -32,6 +36,11 @@ from twisted.web.iweb import ( IResponse, ) +from synapse.config.workers import ( + InstanceLocationConfig, + InstanceTcpLocationConfig, + InstanceUnixLocationConfig, +) from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -39,14 +48,16 @@ logger = logging.getLogger(__name__) @implementer(IAgentEndpointFactory) class ReplicationEndpointFactory: - """Connect to a given TCP socket""" + """Connect to a given TCP or UNIX socket""" def __init__( self, reactor: ISynapseReactor, + instance_map: Dict[str, InstanceLocationConfig], context_factory: IPolicyForHTTPS, ) -> None: self.reactor = reactor + self.instance_map = instance_map self.context_factory = context_factory def endpointForURI(self, uri: URI) -> IStreamClientEndpoint: @@ -58,15 +69,32 @@ class ReplicationEndpointFactory: Returns: The correct client endpoint object """ - if uri.scheme in (b"http", b"https"): - endpoint = HostnameEndpoint(self.reactor, uri.host, uri.port) - if uri.scheme == b"https": + # The given URI has a special scheme and includes the worker name. The + # actual connection details are pulled from the instance map. + worker_name = uri.netloc.decode("utf-8") + location_config = self.instance_map[worker_name] + scheme = location_config.scheme() + + if isinstance(location_config, InstanceTcpLocationConfig): + endpoint = HostnameEndpoint( + self.reactor, + location_config.host, + location_config.port, + ) + if scheme == "https": endpoint = wrapClientTLS( - self.context_factory.creatorForNetloc(uri.host, uri.port), endpoint + # The 'port' argument below isn't actually used by the function + self.context_factory.creatorForNetloc( + location_config.host.encode("utf-8"), + location_config.port, + ), + endpoint, ) return endpoint + elif isinstance(location_config, InstanceUnixLocationConfig): + return UNIXClientEndpoint(self.reactor, location_config.path) else: - raise SchemeNotSupported(f"Unsupported scheme: {uri.scheme!r}") + raise SchemeNotSupported(f"Unsupported scheme: {scheme}") @implementer(IAgent) @@ -80,6 +108,7 @@ class ReplicationAgent(_AgentBase): def __init__( self, reactor: ISynapseReactor, + instance_map: Dict[str, InstanceLocationConfig], contextFactory: IPolicyForHTTPS, connectTimeout: Optional[float] = None, bindAddress: Optional[bytes] = None, @@ -102,7 +131,9 @@ class ReplicationAgent(_AgentBase): created. """ _AgentBase.__init__(self, reactor, pool) - endpoint_factory = ReplicationEndpointFactory(reactor, contextFactory) + endpoint_factory = ReplicationEndpointFactory( + reactor, instance_map, contextFactory + ) self._endpointFactory = endpoint_factory def request( @@ -118,13 +149,16 @@ class ReplicationAgent(_AgentBase): An existing connection from the connection pool may be used or a new one may be created. - Currently, HTTP and HTTPS schemes are supported in uri. + Currently, HTTP, HTTPS and UNIX schemes are supported in uri. This is copied from twisted.web.client.Agent, except: - * It uses a different pool key (combining the host & port). - * It does not call _ensureValidURI(...) since it breaks on some - UNIX paths. + * It uses a different pool key (combining the scheme with either host & port or + socket path). + * It does not call _ensureValidURI(...) as the strictness of IDNA2008 is not + required when using a worker's name as a 'hostname' for Synapse HTTP + Replication machinery. Specifically, this allows a range of ascii characters + such as '+' and '_' in hostnames/worker's names. See: twisted.web.iweb.IAgent.request """ @@ -134,9 +168,12 @@ class ReplicationAgent(_AgentBase): except SchemeNotSupported: return defer.fail(Failure()) + worker_name = parsedURI.netloc.decode("utf-8") + key_scheme = self._endpointFactory.instance_map[worker_name].scheme() + key_netloc = self._endpointFactory.instance_map[worker_name].netloc() # This sets the Pool key to be: - # (http(s), <host:ip>) - key = (parsedURI.scheme, parsedURI.netloc) + # (http(s), <host:port>) or (unix, <socket_path>) + key = (key_scheme, key_netloc) # _requestWithEndpoint comes from _AgentBase class return self._requestWithEndpoint( diff --git a/synapse/http/server.py b/synapse/http/server.py index 101dc2e747..1e4e56f36b 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -18,6 +18,7 @@ import html import logging import types import urllib +import urllib.parse from http import HTTPStatus from http.client import FOUND from inspect import isawaitable @@ -65,7 +66,6 @@ from synapse.api.errors import ( UnrecognizedRequestError, ) from synapse.config.homeserver import HomeServerConfig -from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.util import json_encoder @@ -76,6 +76,7 @@ from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: import opentracing + from synapse.http.site import SynapseRequest from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -102,16 +103,25 @@ HTTP_STATUS_REQUEST_CANCELLED = 499 def return_json_error( - f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig] + f: failure.Failure, request: "SynapseRequest", config: Optional[HomeServerConfig] ) -> None: """Sends a JSON error response to clients.""" if f.check(SynapseError): # mypy doesn't understand that f.check asserts the type. - exc: SynapseError = f.value # type: ignore + exc: SynapseError = f.value error_code = exc.code error_dict = exc.error_dict(config) - logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) + if exc.headers is not None: + for header, value in exc.headers.items(): + request.setHeader(header, value) + error_ctx = exc.debug_context + if error_ctx: + logger.info( + "%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx + ) + else: + logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): error_code = HTTP_STATUS_REQUEST_CANCELLED error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} @@ -121,7 +131,7 @@ def return_json_error( "Got cancellation before client disconnection from %r: %r", request.request_metrics.name, request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + exc_info=(f.type, f.value, f.getTracebackObject()), ) else: error_code = 500 @@ -131,7 +141,7 @@ def return_json_error( "Failed handle request via %r: %r", request.request_metrics.name, request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + exc_info=(f.type, f.value, f.getTracebackObject()), ) # Only respond with an error response if we haven't already started writing, @@ -169,9 +179,12 @@ def return_html_error( """ if f.check(CodeMessageException): # mypy doesn't understand that f.check asserts the type. - cme: CodeMessageException = f.value # type: ignore + cme: CodeMessageException = f.value code = cme.code msg = cme.msg + if cme.headers is not None: + for header, value in cme.headers.items(): + request.setHeader(header, value) if isinstance(cme, RedirectException): logger.info("%s redirect to %s", request, cme.location) @@ -183,7 +196,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + exc_info=(f.type, f.value, f.getTracebackObject()), ) elif f.check(CancelledError): code = HTTP_STATUS_REQUEST_CANCELLED @@ -193,7 +206,7 @@ def return_html_error( logger.error( "Got cancellation before client disconnection when handling request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + exc_info=(f.type, f.value, f.getTracebackObject()), ) else: code = HTTPStatus.INTERNAL_SERVER_ERROR @@ -202,7 +215,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type] + exc_info=(f.type, f.value, f.getTracebackObject()), ) if isinstance(error_template, str): @@ -214,8 +227,8 @@ def return_html_error( def wrap_async_request_handler( - h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] -) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: + h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]] +) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. This helps ensure that work done by the request handler after the request is completed @@ -229,7 +242,7 @@ def wrap_async_request_handler( """ async def wrapped_async_request_handler( - self: "_AsyncResource", request: SynapseRequest + self: "_AsyncResource", request: "SynapseRequest" ) -> None: with request.processing(): await h(self, request) @@ -253,7 +266,7 @@ class HttpServer(Protocol): def register_paths( self, method: str, - path_patterns: Iterable[Pattern], + path_patterns: Iterable[Pattern[str]], callback: ServletCallback, servlet_classname: str, ) -> None: @@ -294,7 +307,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): self._extract_context = extract_context - def render(self, request: SynapseRequest) -> int: + def render(self, request: "SynapseRequest") -> int: """This gets called by twisted every time someone sends us a request.""" request.render_deferred = defer.ensureDeferred( self._async_render_wrapper(request) @@ -302,7 +315,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request: SynapseRequest) -> None: + async def _async_render_wrapper(self, request: "SynapseRequest") -> None: """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -322,7 +335,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: + async def _async_render( + self, request: "SynapseRequest" + ) -> Optional[Tuple[int, Any]]: """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -352,7 +367,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): @abc.abstractmethod def _send_response( self, - request: SynapseRequest, + request: "SynapseRequest", code: int, response_object: Any, ) -> None: @@ -362,7 +377,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: raise NotImplementedError() @@ -378,7 +393,7 @@ class DirectServeJsonResource(_AsyncResource): def _send_response( self, - request: SynapseRequest, + request: "SynapseRequest", code: int, response_object: Any, ) -> None: @@ -395,7 +410,7 @@ class DirectServeJsonResource(_AsyncResource): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: """Implements _AsyncResource._send_error_response""" return_json_error(f, request, None) @@ -467,7 +482,7 @@ class JsonResource(DirectServeJsonResource): ) def _get_handler_for_request( - self, request: SynapseRequest + self, request: "SynapseRequest" ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. @@ -497,7 +512,7 @@ class JsonResource(DirectServeJsonResource): # Huh. No one wanted to handle that? Fiiiiiine. raise UnrecognizedRequestError(code=404) - async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: + async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) request.is_render_cancellable = is_function_cancellable(callback) @@ -529,7 +544,7 @@ class JsonResource(DirectServeJsonResource): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: """Implements _AsyncResource._send_error_response""" return_json_error(f, request, self.hs.config) @@ -545,7 +560,7 @@ class DirectServeHtmlResource(_AsyncResource): def _send_response( self, - request: SynapseRequest, + request: "SynapseRequest", code: int, response_object: Any, ) -> None: @@ -559,7 +574,7 @@ class DirectServeHtmlResource(_AsyncResource): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: """Implements _AsyncResource._send_error_response""" return_html_error(f, request, self.ERROR_TEMPLATE) @@ -586,7 +601,7 @@ class UnrecognizedRequestResource(resource.Resource): errcode of M_UNRECOGNIZED. """ - def render(self, request: SynapseRequest) -> int: + def render(self, request: "SynapseRequest") -> int: f = failure.Failure(UnrecognizedRequestError(code=404)) return_json_error(f, request, None) # A response has already been sent but Twisted requires either NOT_DONE_YET @@ -616,7 +631,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: SynapseRequest) -> bytes: + def render_OPTIONS(self, request: "SynapseRequest") -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -731,7 +746,7 @@ def _encode_json_bytes(json_object: object) -> bytes: def respond_with_json( - request: SynapseRequest, + request: "SynapseRequest", code: int, json_object: Any, send_cors: bool = False, @@ -781,7 +796,7 @@ def respond_with_json( def respond_with_json_bytes( - request: SynapseRequest, + request: "SynapseRequest", code: int, json_bytes: bytes, send_cors: bool = False, @@ -819,7 +834,7 @@ def respond_with_json_bytes( async def _async_write_json_to_request_in_thread( - request: SynapseRequest, + request: "SynapseRequest", json_encoder: Callable[[Any], bytes], json_object: Any, ) -> None: @@ -877,7 +892,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: SynapseRequest) -> None: +def set_cors_headers(request: "SynapseRequest") -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -904,7 +919,7 @@ def set_cors_headers(request: SynapseRequest) -> None: ) request.setHeader( b"Access-Control-Expose-Headers", - b"Synapse-Trace-Id", + b"Synapse-Trace-Id, Server", ) @@ -975,7 +990,7 @@ def set_clickjacking_protection_headers(request: Request) -> None: def respond_with_redirect( - request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False + request: "SynapseRequest", url: bytes, statusCode: int = FOUND, cors: bool = False ) -> None: """ Write a 302 (or other specified status code) response to the request, if it is still alive. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fc62793628..d9d5655c95 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -18,7 +18,6 @@ import logging from http import HTTPStatus from typing import ( TYPE_CHECKING, - Iterable, List, Mapping, Optional, @@ -29,8 +28,15 @@ from typing import ( overload, ) -from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError -from pydantic.error_wrappers import ErrorWrapper +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, MissingError, PydanticValueError, ValidationError + from pydantic.v1.error_wrappers import ErrorWrapper +else: + from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError + from pydantic.error_wrappers import ErrorWrapper + from typing_extensions import Literal from twisted.web.server import Request @@ -38,7 +44,7 @@ from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http import redact_uri from synapse.http.server import HttpServer -from synapse.types import JsonDict, RoomAlias, RoomID +from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection from synapse.util import json_decoder if TYPE_CHECKING: @@ -340,7 +346,7 @@ def parse_string( name: str, default: str, *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> str: ... @@ -352,7 +358,7 @@ def parse_string( name: str, *, required: Literal[True], - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> str: ... @@ -365,7 +371,7 @@ def parse_string( *, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: ... @@ -376,7 +382,7 @@ def parse_string( name: str, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: """ @@ -485,7 +491,7 @@ def parse_enum( def _parse_string_value( value: bytes, - allowed_values: Optional[Iterable[str]], + allowed_values: Optional[StrCollection], name: str, encoding: str, ) -> str: @@ -511,7 +517,7 @@ def parse_strings_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[List[str]]: ... @@ -523,7 +529,7 @@ def parse_strings_from_args( name: str, default: List[str], *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> List[str]: ... @@ -535,7 +541,7 @@ def parse_strings_from_args( name: str, *, required: Literal[True], - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> List[str]: ... @@ -548,7 +554,7 @@ def parse_strings_from_args( default: Optional[List[str]] = None, *, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[List[str]]: ... @@ -559,7 +565,7 @@ def parse_strings_from_args( name: str, default: Optional[List[str]] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[List[str]]: """ @@ -610,7 +616,7 @@ def parse_string_from_args( name: str, default: Optional[str] = None, *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: ... @@ -623,7 +629,7 @@ def parse_string_from_args( default: Optional[str] = None, *, required: Literal[True], - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> str: ... @@ -635,7 +641,7 @@ def parse_string_from_args( name: str, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: ... @@ -646,7 +652,7 @@ def parse_string_from_args( name: str, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: """ @@ -821,7 +827,7 @@ def parse_and_validate_json_object_from_request( return validate_json_object(content, model_type) -def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: +def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None: absent = [] for k in required: if k not in body: diff --git a/synapse/http/site.py b/synapse/http/site.py index c530966ef3..a388d6cf7f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -21,25 +21,29 @@ from zope.interface import implementer from twisted.internet.address import UNIXAddress from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IAddress, IReactorTime +from twisted.internet.interfaces import IAddress from twisted.python.failure import Failure from twisted.web.http import HTTPChannel from twisted.web.resource import IResource, Resource -from twisted.web.server import Request, Site +from twisted.web.server import Request from synapse.config.server import ListenerConfig from synapse.http import get_request_user_agent, redact_uri +from synapse.http.proxy import ProxySite from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import ( ContextRequest, LoggingContext, PreserveLoggingContext, ) -from synapse.types import Requester +from synapse.types import ISynapseReactor, Requester if TYPE_CHECKING: import opentracing + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) _next_request_seq = 0 @@ -102,7 +106,7 @@ class SynapseRequest(Request): # A boolean indicating whether `render_deferred` should be cancelled if the # client disconnects early. Expected to be set by the coroutine started by # `Resource.render`, if rendering is asynchronous. - self.is_render_cancellable = False + self.is_render_cancellable: bool = False global _next_request_seq self.request_seq = _next_request_seq @@ -521,6 +525,11 @@ class SynapseRequest(Request): else: return self.getClientAddress().host + def request_info(self) -> "RequestInfo": + h = self.getHeader(b"User-Agent") + user_agent = h.decode("ascii", "replace") if h else None + return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available()) + class XForwardedForRequest(SynapseRequest): """Request object which honours proxy headers @@ -596,7 +605,7 @@ class _XForwardedForAddress: host: str -class SynapseSite(Site): +class SynapseSite(ProxySite): """ Synapse-specific twisted http Site @@ -618,7 +627,8 @@ class SynapseSite(Site): resource: IResource, server_version_string: str, max_request_body_size: int, - reactor: IReactorTime, + reactor: ISynapseReactor, + hs: "HomeServer", ): """ @@ -633,7 +643,11 @@ class SynapseSite(Site): dropping the connection reactor: reactor to be used to manage connection timeouts """ - Site.__init__(self, resource, reactor=reactor) + super().__init__( + resource=resource, + reactor=reactor, + hs=hs, + ) self.site_tag = site_tag self.reactor = reactor @@ -644,7 +658,9 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header - self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + self.experimental_cors_msc3886: bool = ( + config.http_options.experimental_cors_msc3886 + ) def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( @@ -661,3 +677,9 @@ class SynapseSite(Site): def log(self, request: SynapseRequest) -> None: pass + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RequestInfo: + user_agent: Optional[str] + ip: str diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 5a61b21eaf..284fbac524 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -18,10 +18,9 @@ import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import Callable, Optional +from typing import Callable, Deque, Optional import attr -from typing_extensions import Deque from zope.interface import implementer from twisted.application.internet import ClientService diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index b78d6e17c9..98c6038ff2 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -44,6 +44,7 @@ _IGNORED_LOG_RECORD_ATTRIBUTES = { "processName", "relativeCreated", "stack_info", + "taskName", "thread", "threadName", } diff --git a/synapse/logging/context.py b/synapse/logging/context.py index f62bea968f..bf7e311026 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -728,7 +728,7 @@ async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R: @overload -def preserve_fn( # type: ignore[misc] +def preserve_fn( f: Callable[P, Awaitable[R]], ) -> Callable[P, "defer.Deferred[R]"]: # The `type: ignore[misc]` above suppresses @@ -756,7 +756,7 @@ def preserve_fn( @overload -def run_in_background( # type: ignore[misc] +def run_in_background( f: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": # The `type: ignore[misc]` above suppresses @@ -809,23 +809,24 @@ def run_in_background( # type: ignore[misc] # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain # value. Convert it to a `Deferred`. + d: "defer.Deferred[R]" if isinstance(res, typing.Coroutine): # Wrap the coroutine in a `Deferred`. - res = defer.ensureDeferred(res) + d = defer.ensureDeferred(res) elif isinstance(res, defer.Deferred): - pass + d = res elif isinstance(res, Awaitable): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` # or `Future` from `make_awaitable`. - res = defer.ensureDeferred(_unwrap_awaitable(res)) + d = defer.ensureDeferred(_unwrap_awaitable(res)) else: # `res` is a plain value. Wrap it in a `Deferred`. - res = defer.succeed(res) + d = defer.succeed(res) - if res.called and not res.paused: + if d.called and not d.paused: # The function should have maintained the logcontext, so we can # optimise out the messing about - return res + return d # The function may have reset the context before returning, so # we need to restore it now. @@ -843,8 +844,8 @@ def run_in_background( # type: ignore[misc] # which is supposed to have a single entry and exit point. But # by spawning off another deferred, we are effectively # adding a new exit point.) - res.addBoth(_set_context_cb, ctx) - return res + d.addBoth(_set_context_cb, ctx) + return d T = TypeVar("T") @@ -877,7 +878,7 @@ def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T] ResultT = TypeVar("ResultT") -def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: +def _set_context_cb(result: ResultT, context: LoggingContextOrSentinel) -> ResultT: """A callback function which just sets the logging context""" set_current_context(context) return result diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index c70eee649c..4454fe29a5 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -171,6 +171,7 @@ from functools import wraps from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Collection, ContextManager, @@ -903,26 +904,28 @@ def _custom_sync_async_decorator( """ if inspect.iscoroutinefunction(func): + # For this branch, we handle async functions like `async def func() -> RInner`. # In this branch, R = Awaitable[RInner], for some other type RInner @wraps(func) async def _wrapper( *args: P.args, **kwargs: P.kwargs ) -> Any: # Return type is RInner - with wrapping_logic(func, *args, **kwargs): - # type-ignore: func() returns R, but mypy doesn't know that R is - # Awaitable here. - return await func(*args, **kwargs) # type: ignore[misc] + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. + with wrapping_logic(func, *args, **kwargs): # type: ignore[arg-type] + return await func(*args, **kwargs) else: - # The other case here handles both sync functions and those - # decorated with inlineDeferred. + # The other case here handles sync functions including those decorated with + # `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`. @wraps(func) - def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: scope = wrapping_logic(func, *args, **kwargs) scope.__enter__() try: result = func(*args, **kwargs) + if isinstance(result, defer.Deferred): def call_back(result: R) -> R: @@ -930,20 +933,32 @@ def _custom_sync_async_decorator( return result def err_back(result: R) -> R: + # TODO: Pass the error details into `scope.__exit__(...)` for + # consistency with the other paths. scope.__exit__(None, None, None) return result result.addCallbacks(call_back, err_back) + elif inspect.isawaitable(result): + + async def wrap_awaitable() -> Any: + try: + assert isinstance(result, Awaitable) + awaited_result = await result + scope.__exit__(None, None, None) + return awaited_result + except Exception as e: + scope.__exit__(type(e), None, e.__traceback__) + raise + + # The original method returned an awaitable, eg. a coroutine, so we + # create another awaitable wrapping it that calls + # `scope.__exit__(...)`. + return wrap_awaitable() else: - if inspect.isawaitable(result): - logger.error( - "@trace may not have wrapped %s correctly! " - "The function is not async but returned a %s.", - func.__qualname__, - type(result).__name__, - ) - + # Just a simple sync function so we can just exit the scope and + # return the result without any fuss. scope.__exit__(None, None, None) return result @@ -965,8 +980,7 @@ def trace_with_opname( See the module's doc string for usage examples. """ - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: @@ -977,11 +991,7 @@ def trace_with_opname( if not opentracing: return func - # type-ignore: mypy seems to be confused by the ParamSpecs here. - # I think the problem is https://github.com/python/mypy/issues/12909 - return _custom_sync_async_decorator( - func, _wrapping_logic # type: ignore[arg-type] - ) + return _custom_sync_async_decorator(func, _wrapping_logic) return _decorator @@ -1009,8 +1019,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: if not opentracing: return func - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: @@ -1027,9 +1036,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield - # type-ignore: mypy seems to be confused by the ParamSpecs here. - # I think the problem is https://github.com/python/mypy/issues/12909 - return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type] + return _custom_sync_async_decorator(func, _wrapping_logic) @contextlib.contextmanager @@ -1055,7 +1062,7 @@ def trace_servlet( tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, tags.HTTP_METHOD: request.get_method(), tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientAddress().host, + tags.PEER_HOST_IPV6: request.get_client_ip_if_available(), } request_name = request.request_metrics.name @@ -1076,9 +1083,11 @@ def trace_servlet( # with JsonResource). scope.span.set_operation_name(request.request_metrics.name) + # Mypy seems to think that start_context.tag below can be Optional[str], but + # that doesn't appear to be correct and works in practice. request_tags[ SynapseTags.REQUEST_TAG - ] = request.request_metrics.start_context.tag + ] = request.request_metrics.start_context.tag # type: ignore[assignment] # set the tags *after* the servlet completes, in case it decided to # prioritise the span (tags will get dropped on unprioritised spans) diff --git a/synapse/media/_base.py b/synapse/media/_base.py index ef8334ae25..13345acf75 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -26,11 +26,11 @@ from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from twisted.web.server import Request -from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.api.errors import Codes, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii, parse_and_validate_server_name +from synapse.util.stringutils import is_ascii logger = logging.getLogger(__name__) @@ -50,53 +50,46 @@ TEXT_CONTENT_TYPES = [ "text/xml", ] - -def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: - """Parses the server name, media ID and optional file name from the request URI - - Also performs some rough validation on the server name. - - Args: - request: The `Request`. - - Returns: - A tuple containing the parsed server name, media ID and optional file name. - - Raises: - SynapseError(404): if parsing or validation fail for any reason - """ - try: - # The type on postpath seems incorrect in Twisted 21.2.0. - postpath: List[bytes] = request.postpath # type: ignore - assert postpath - - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. - server_name_bytes, media_id_bytes = postpath[:2] - server_name = server_name_bytes.decode("utf-8") - media_id = media_id_bytes.decode("utf8") - - # Validate the server name, raising if invalid - parse_and_validate_server_name(server_name) - - file_name = None - if len(postpath) > 2: - try: - file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except Exception: - raise SynapseError( - 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN - ) +# A list of all content types that are "safe" to be rendered inline in a browser. +INLINE_CONTENT_TYPES = [ + "text/css", + "text/plain", + "text/csv", + "application/json", + "application/ld+json", + # We allow some media files deemed as safe, which comes from the matrix-react-sdk. + # https://github.com/matrix-org/matrix-react-sdk/blob/a70fcfd0bcf7f8c85986da18001ea11597989a7c/src/utils/blobs.ts#L51 + # SVGs are *intentionally* omitted. + "image/jpeg", + "image/gif", + "image/png", + "image/apng", + "image/webp", + "image/avif", + "video/mp4", + "video/webm", + "video/ogg", + "video/quicktime", + "audio/mp4", + "audio/webm", + "audio/aac", + "audio/mpeg", + "audio/ogg", + "audio/wave", + "audio/wav", + "audio/x-wav", + "audio/x-pn-wav", + "audio/flac", + "audio/x-flac", +] def respond_404(request: SynapseRequest) -> None: + assert request.path is not None respond_with_json( request, 404, - cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), + cs_error("Not found '%s'" % (request.path.decode(),), code=Codes.NOT_FOUND), send_cors=True, ) @@ -152,6 +145,14 @@ def add_file_headers( content_type = media_type request.setHeader(b"Content-Type", content_type.encode("UTF-8")) + + # A strict subset of content types is allowed to be inlined so that they may + # be viewed directly in a browser. Other file types are forced to be downloads. + if media_type.lower() in INLINE_CONTENT_TYPES: + disposition = "inline" + else: + disposition = "attachment" + if upload_name: # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. # @@ -173,11 +174,17 @@ def add_file_headers( # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we # may as well just do the filename* version. if _can_encode_filename_as_token(upload_name): - disposition = "inline; filename=%s" % (upload_name,) + disposition = "%s; filename=%s" % ( + disposition, + upload_name, + ) else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) + disposition = "%s; filename*=utf-8''%s" % ( + disposition, + _quote(upload_name), + ) - request.setHeader(b"Content-Disposition", disposition.encode("ascii")) + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to @@ -325,7 +332,7 @@ class ThumbnailInfo: # Content type of thumbnail, e.g. image/png type: str # The size of the media file, in bytes. - length: Optional[int] = None + length: int @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index e81c987b10..7fd46901f7 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -35,6 +35,7 @@ from synapse.api.errors import ( from synapse.config.repository import ThumbnailRequirement from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread +from synapse.logging.opentracing import trace from synapse.media._base import ( FileInfo, Responder, @@ -47,6 +48,7 @@ from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError +from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util.async_helpers import Linearizer @@ -113,7 +115,7 @@ class MediaRepository: ) storage_providers.append(provider) - self.media_storage = MediaStorage( + self.media_storage: MediaStorage = MediaStorage( self.hs, self.primary_base_path, self.filepaths, storage_providers ) @@ -141,6 +143,13 @@ class MediaRepository: MEDIA_RETENTION_CHECK_PERIOD_MS, ) + if hs.config.media.url_preview_enabled: + self.url_previewer: Optional[UrlPreviewer] = UrlPreviewer( + hs, self, self.media_storage + ) + else: + self.url_previewer = None + def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed @@ -174,6 +183,7 @@ class MediaRepository: else: self.recently_accessed_locals.add(media_id) + @trace async def create_content( self, media_type: str, @@ -212,7 +222,10 @@ class MediaRepository: user_id=auth_user, ) - await self._generate_thumbnails(None, media_id, media_id, media_type) + try: + await self._generate_thumbnails(None, media_id, media_id, media_type) + except Exception as e: + logger.info("Failed to generate thumbnails: %s", e) return MXCUri(self.server_name, media_id) @@ -611,6 +624,7 @@ class MediaRepository: height=t_height, method=t_method, type=t_type, + length=t_byte_source.tell(), ), ) @@ -681,6 +695,7 @@ class MediaRepository: height=t_height, method=t_method, type=t_type, + length=t_byte_source.tell(), ), ) @@ -710,6 +725,7 @@ class MediaRepository: # Could not generate thumbnail. return None + @trace async def _generate_thumbnails( self, server_name: Optional[str], @@ -825,6 +841,7 @@ class MediaRepository: height=t_height, method=t_method, type=t_type, + length=t_byte_source.tell(), ), ) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index a819d95407..a17ccb3d80 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -38,6 +38,7 @@ from twisted.protocols.basic import FileSender from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.logging.opentracing import start_active_span, trace, trace_with_opname from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer @@ -76,6 +77,7 @@ class MediaStorage: self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self.clock = hs.get_clock() + @trace_with_opname("MediaStorage.store_file") async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers @@ -89,16 +91,19 @@ class MediaStorage: """ with self.store_into_file(file_info) as (f, fname, finish_cb): - # Write to the main repository + # Write to the main media repository await self.write_to_file(source, f) + # Write to the other storage providers await finish_cb() return fname + @trace_with_opname("MediaStorage.write_to_file") async def write_to_file(self, source: IO, output: IO) -> None: """Asynchronously write the `source` to `output`.""" await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + @trace_with_opname("MediaStorage.store_into_file") @contextlib.contextmanager def store_into_file( self, file_info: FileInfo @@ -113,9 +118,9 @@ class MediaStorage: fname can be used to read the contents from after upload, e.g. to generate thumbnails. - finish_cb must be called and waited on after the file has been - successfully been written to. Should not be called if there was an - error. + finish_cb must be called and waited on after the file has been successfully been + written to. Should not be called if there was an error. Checks for spam and + stores the file into the configured storage providers. Args: file_info: Info about the file to store @@ -135,35 +140,48 @@ class MediaStorage: finished_called = [False] + main_media_repo_write_trace_scope = start_active_span( + "writing to main media repo" + ) + main_media_repo_write_trace_scope.__enter__() + try: with open(fname, "wb") as f: async def finish() -> None: - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True + # When someone calls finish, we assume they are done writing to the main media repo + main_media_repo_write_trace_scope.__exit__(None, None, None) + + with start_active_span("writing to other storage providers"): + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + with start_active_span(str(provider)): + await provider.store_file(path, file_info) + + finished_called[0] = True yield f, fname, finish except Exception as e: try: + main_media_repo_write_trace_scope.__exit__( + type(e), None, e.__traceback__ + ) os.remove(fname) except Exception: pass @@ -171,7 +189,11 @@ class MediaStorage: raise e from None if not finished_called: - raise Exception("Finished callback not called") + exc = Exception("Finished callback not called") + main_media_repo_write_trace_scope.__exit__( + type(exc), None, exc.__traceback__ + ) + raise exc async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache @@ -214,6 +236,7 @@ class MediaStorage: return None + @trace async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: """Ensures that the given file is in the local cache. Attempts to download it from storage providers if it isn't. @@ -259,6 +282,7 @@ class MediaStorage: raise NotFoundError() + @trace def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. @@ -301,6 +325,7 @@ class MediaStorage: return self.filepaths.local_media_filepath_rel(file_info.file_id) +@trace def _write_file_synchronously(source: IO, dest: IO) -> None: """Write `source` to the file like `dest` synchronously. Should be called from a thread. diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py index c0eaf04be5..2ce842c98d 100644 --- a/synapse/media/oembed.py +++ b/synapse/media/oembed.py @@ -14,7 +14,7 @@ import html import logging import urllib.parse -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, cast import attr @@ -98,7 +98,7 @@ class OEmbedProvider: # No match. return None - def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]: + def autodiscover_from_html(self, tree: "etree._Element") -> Optional[str]: """ Search an HTML document for oEmbed autodiscovery information. @@ -109,18 +109,22 @@ class OEmbedProvider: The URL to use for oEmbed information, or None if no URL was found. """ # Search for link elements with the proper rel and type attributes. - for tag in tree.xpath( - "//link[@rel='alternate'][@type='application/json+oembed']" + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + for tag in cast( + List["etree._Element"], + tree.xpath("//link[@rel='alternate'][@type='application/json+oembed']"), ): if "href" in tag.attrib: - return tag.attrib["href"] + return cast(str, tag.attrib["href"]) # Some providers (e.g. Flickr) use alternative instead of alternate. - for tag in tree.xpath( - "//link[@rel='alternative'][@type='application/json+oembed']" + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + for tag in cast( + List["etree._Element"], + tree.xpath("//link[@rel='alternative'][@type='application/json+oembed']"), ): if "href" in tag.attrib: - return tag.attrib["href"] + return cast(str, tag.attrib["href"]) return None @@ -200,7 +204,7 @@ class OEmbedProvider: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if type(val) is int: + if type(val) is int: # noqa: E721 open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": @@ -212,11 +216,12 @@ class OEmbedProvider: return OEmbedResult(open_graph_response, author_name, cache_age) -def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: +def _fetch_urls(tree: "etree._Element", tag_name: str) -> List[str]: results = [] - for tag in tree.xpath("//*/" + tag_name): + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + for tag in cast(List["etree._Element"], tree.xpath("//*/" + tag_name)): if "src" in tag.attrib: - results.append(tag.attrib["src"]) + results.append(cast(str, tag.attrib["src"])) return results @@ -244,11 +249,12 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> parser = etree.HTMLParser(recover=True, encoding="utf-8") # Attempt to parse the body. If this fails, log and return no metadata. - tree = etree.fromstring(html_body, parser) + # TODO Develop of lxml-stubs has this correct. + tree = etree.fromstring(html_body, parser) # type: ignore[arg-type] # The data was successfully parsed, but no tree was found. if tree is None: - return + return # type: ignore[unreachable] # Attempt to find interesting URLs (images, videos, embeds). if "og:image" not in open_graph_response: diff --git a/synapse/media/preview_html.py b/synapse/media/preview_html.py index 516d0434f0..1bc7ccb7f3 100644 --- a/synapse/media/preview_html.py +++ b/synapse/media/preview_html.py @@ -24,6 +24,7 @@ from typing import ( Optional, Set, Union, + cast, ) if TYPE_CHECKING: @@ -115,7 +116,7 @@ def _get_html_media_encodings( def decode_body( body: bytes, uri: str, content_type: Optional[str] = None -) -> Optional["etree.Element"]: +) -> Optional["etree._Element"]: """ This uses lxml to parse the HTML document. @@ -152,11 +153,12 @@ def decode_body( # Attempt to parse the body. Returns None if the body was successfully # parsed, but no tree was found. - return etree.fromstring(body, parser) + # TODO Develop of lxml-stubs has this correct. + return etree.fromstring(body, parser) # type: ignore[arg-type] def _get_meta_tags( - tree: "etree.Element", + tree: "etree._Element", property: str, prefix: str, property_mapper: Optional[Callable[[str], Optional[str]]] = None, @@ -175,9 +177,15 @@ def _get_meta_tags( Returns: A map of tag name to value. """ + # This actually returns Dict[str, str], but the caller sets this as a variable + # which is Dict[str, Optional[str]]. results: Dict[str, Optional[str]] = {} - for tag in tree.xpath( - f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + for tag in cast( + List["etree._Element"], + tree.xpath( + f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" + ), ): # if we've got more than 50 tags, someone is taking the piss if len(results) >= 50: @@ -187,14 +195,15 @@ def _get_meta_tags( ) return {} - key = tag.attrib[property] + key = cast(str, tag.attrib[property]) if property_mapper: - key = property_mapper(key) + new_key = property_mapper(key) # None is a special value used to ignore a value. - if key is None: + if new_key is None: continue + key = new_key - results[key] = tag.attrib["content"] + results[key] = cast(str, tag.attrib["content"]) return results @@ -219,7 +228,7 @@ def _map_twitter_to_open_graph(key: str) -> Optional[str]: return "og" + key[7:] -def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: +def parse_html_to_open_graph(tree: "etree._Element") -> Dict[str, Optional[str]]: """ Parse the HTML document into an Open Graph response. @@ -276,24 +285,36 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: if "og:title" not in og: # Attempt to find a title from the title tag, or the biggest header on the page. - title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + title = cast( + List["etree._ElementUnicodeResult"], + tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()"), + ) if title: og["og:title"] = title[0].strip() else: og["og:title"] = None if "og:image" not in og: - meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + meta_image = cast( + List["etree._ElementUnicodeResult"], + tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" + ), ) # If a meta image is found, use it. if meta_image: og["og:image"] = meta_image[0] else: # Try to find images which are larger than 10px by 10px. + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. # # TODO: consider inlined CSS styles as well as width & height attribs - images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = cast( + List["etree._Element"], + tree.xpath("//img[@src][number(@width)>10][number(@height)>10]"), + ) images = sorted( images, key=lambda i: ( @@ -302,20 +323,29 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: ) # If no images were found, try to find *any* images. if not images: - images = tree.xpath("//img[@src][1]") + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + images = cast(List["etree._Element"], tree.xpath("//img[@src][1]")) if images: - og["og:image"] = images[0].attrib["src"] + og["og:image"] = cast(str, images[0].attrib["src"]) # Finally, fallback to the favicon if nothing else. else: - favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + favicons = cast( + List["etree._ElementUnicodeResult"], + tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]"), + ) if favicons: og["og:image"] = favicons[0] if "og:description" not in og: # Check the first meta description tag for content. - meta_description = tree.xpath( - "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" + # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. + meta_description = cast( + List["etree._ElementUnicodeResult"], + tree.xpath( + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" + ), ) # If a meta description is found with content, use it. if meta_description: @@ -332,7 +362,7 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: return og -def parse_html_description(tree: "etree.Element") -> Optional[str]: +def parse_html_description(tree: "etree._Element") -> Optional[str]: """ Calculate a text description based on an HTML document. @@ -368,6 +398,9 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]: "canvas", "img", "picture", + # etree.Comment is a function which creates an etree._Comment element. + # The "tag" attribute of an etree._Comment instance is confusingly the + # etree.Comment function instead of a string. etree.Comment, } @@ -381,8 +414,8 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]: def _iterate_over_text( - tree: Optional["etree.Element"], - tags_to_ignore: Set[Union[str, "etree.Comment"]], + tree: Optional["etree._Element"], + tags_to_ignore: Set[object], stack_limit: int = 1024, ) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, @@ -402,7 +435,7 @@ def _iterate_over_text( # This is a stack whose items are elements to iterate over *or* strings # to be returned. - elements: List[Union[str, "etree.Element"]] = [tree] + elements: List[Union[str, "etree._Element"]] = [tree] while elements: el = elements.pop() diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py index 1c9b71d69c..70a45cfd5b 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Callable, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from synapse.logging.opentracing import start_active_span, trace_with_opname from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder @@ -86,6 +87,7 @@ class StorageProviderWrapper(StorageProvider): def __str__(self) -> str: return "StorageProviderWrapper[%s]" % (self.backend,) + @trace_with_opname("StorageProviderWrapper.store_file") async def store_file(self, path: str, file_info: FileInfo) -> None: if not file_info.server_name and not self.store_local: return None @@ -114,6 +116,7 @@ class StorageProviderWrapper(StorageProvider): run_in_background(store) + @trace_with_opname("StorageProviderWrapper.fetch") async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: if file_info.url_cache: # Files in the URL preview cache definitely aren't stored here, @@ -141,6 +144,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self) -> str: return "FileStorageProviderBackend[%s]" % (self.base_directory,) + @trace_with_opname("FileStorageProviderBackend.store_file") async def store_file(self, path: str, file_info: FileInfo) -> None: """See StorageProvider.store_file""" @@ -152,13 +156,15 @@ class FileStorageProviderBackend(StorageProvider): # mypy needs help inferring the type of the second parameter, which is generic shutil_copyfile: Callable[[str, str], str] = shutil.copyfile - await defer_to_thread( - self.hs.get_reactor(), - shutil_copyfile, - primary_fname, - backup_fname, - ) - + with start_active_span("shutil_copyfile"): + await defer_to_thread( + self.hs.get_reactor(), + shutil_copyfile, + primary_fname, + backup_fname, + ) + + @trace_with_opname("FileStorageProviderBackend.fetch") async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index f909a4fb9a..d8979813b3 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -19,6 +19,8 @@ from typing import Optional, Tuple, Type from PIL import Image +from synapse.logging.opentracing import trace + logger = logging.getLogger(__name__) EXIF_ORIENTATION_TAG = 0x0112 @@ -76,12 +78,13 @@ class Thumbnailer: image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert type(image_orientation) is int + assert type(image_orientation) is int # noqa: E721 self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF logger.info("Error parsing image EXIF information: %s", e) + @trace def transpose(self) -> Tuple[int, int]: """Transpose the image using its EXIF Orientation tag @@ -131,8 +134,9 @@ class Thumbnailer: else: with self.image: self.image = self.image.convert("RGB") - return self.image.resize((width, height), Image.ANTIALIAS) + return self.image.resize((width, height), Image.LANCZOS) + @trace def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. @@ -142,6 +146,7 @@ class Thumbnailer: with self._resize(width, height) as scaled: return self._encode_image(scaled, output_type) + @trace def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving aspect:: diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 70b32cee17..9b5a3dd5f4 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -846,9 +846,7 @@ def _is_media(content_type: str) -> bool: def _is_html(content_type: str) -> bool: content_type = content_type.lower() - return content_type.startswith("text/html") or content_type.startswith( - "application/xhtml" - ) + return content_type.startswith(("text/html", "application/xhtml")) def _is_json(content_type: str) -> bool: diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 8ce5887229..3cf2fbc3e2 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -25,7 +25,6 @@ from typing import ( Iterable, Mapping, Optional, - Sequence, Set, Tuple, Type, @@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager from synapse.metrics._twisted_exposition import MetricsResource, generate_latest from synapse.metrics._types import Collector +from synapse.types import StrSequence from synapse.util import SYNAPSE_VERSION logger = logging.getLogger(__name__) @@ -77,9 +77,11 @@ RegistryProxy = cast(CollectorRegistry, _RegistryProxy) @attr.s(slots=True, hash=True, auto_attribs=True) class LaterGauge(Collector): + """A Gauge which periodically calls a user-provided callback to produce metrics.""" + name: str desc: str - labels: Optional[Sequence[str]] = attr.ib(hash=False) + labels: Optional[StrSequence] = attr.ib(hash=False) # callback: should either return a value (if there are no labels for this metric), # or dict mapping from a label tuple to a value caller: Callable[ @@ -141,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector): self, name: str, desc: str, - labels: Sequence[str], - sub_metrics: Sequence[str], + labels: StrSequence, + sub_metrics: StrSequence, ): self.name = name self.desc = desc diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 9ea4e23b31..fceb7a9f3c 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -48,6 +48,9 @@ from synapse.metrics._types import Collector if TYPE_CHECKING: import resource + # Old versions don't have `LiteralString` + from typing_extensions import LiteralString + logger = logging.getLogger(__name__) @@ -191,7 +194,7 @@ R = TypeVar("R") def run_as_background_process( - desc: str, + desc: "LiteralString", func: Callable[..., Awaitable[Optional[R]]], *args: Any, bg_start_span: bool = True, @@ -259,7 +262,7 @@ P = ParamSpec("P") def wrap_as_background_process( - desc: str, + desc: "LiteralString", ) -> Callable[ [Callable[P, Awaitable[Optional[R]]]], Callable[P, "defer.Deferred[Optional[R]]"], @@ -322,13 +325,21 @@ class BackgroundProcessLoggingContext(LoggingContext): if instance_id is None: instance_id = id(self) super().__init__("%s-%s" % (name, instance_id)) - self._proc = _BackgroundProcess(name, self) + self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self) def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" super().start(rusage) + if self._proc is None: + logger.error( + "Background process re-entered without a proc: %s", + self.name, + stack_info=True, + ) + return + # We've become active again so we make sure we're in the list of active # procs. (Note that "start" here means we've become active, as opposed # to starting for the first time.) @@ -345,6 +356,14 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__exit__(type, value, traceback) + if self._proc is None: + logger.error( + "Background process exited without a proc: %s", + self.name, + stack_info=True, + ) + return + # The background process has finished. We explicitly remove and manually # update the metrics here so that if nothing is scraping metrics the set # doesn't infinitely grow. @@ -352,3 +371,6 @@ class BackgroundProcessLoggingContext(LoggingContext): _background_processes_active_since_last_scrape.discard(self._proc) self._proc.update_metrics() + + # Set proc to None to break the reference cycle. + self._proc = None diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 4f727e0836..c5dbe9c757 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -31,13 +31,15 @@ from typing import ( import attr import jinja2 -from typing_extensions import ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.internet import defer +from twisted.internet.interfaces import IDelayedCall from twisted.web.resource import Resource from synapse.api import errors from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.events import EventBase from synapse.events.presence_router import ( GET_INTERESTED_USERS_CALLBACK, @@ -82,6 +84,7 @@ from synapse.module_api.callbacks.public_rooms_callbacks import ( ) from synapse.module_api.callbacks.spamchecker_callbacks import ( CHECK_EVENT_FOR_SPAM_CALLBACK, + CHECK_LOGIN_FOR_SPAM_CALLBACK, CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK, CHECK_REGISTRATION_FOR_SPAM_CALLBACK, CHECK_USERNAME_FOR_SPAM_CALLBACK, @@ -124,6 +127,7 @@ from synapse.types import ( JsonMapping, Requester, RoomAlias, + RoomID, StateMap, UserID, UserInfo, @@ -256,6 +260,7 @@ class ModuleApi: self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory self._callbacks = hs.get_module_api_callbacks() + self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled try: app_name = self._hs.config.email.email_app_name @@ -303,6 +308,7 @@ class ModuleApi: CHECK_REGISTRATION_FOR_SPAM_CALLBACK ] = None, check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, + check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, ) -> None: """Registers callbacks for spam checking capabilities. @@ -320,6 +326,7 @@ class ModuleApi: check_username_for_spam=check_username_for_spam, check_registration_for_spam=check_registration_for_spam, check_media_file_for_spam=check_media_file_for_spam, + check_login_for_spam=check_login_for_spam, ) def register_account_validity_callbacks( @@ -423,6 +430,11 @@ class ModuleApi: Added in Synapse v1.46.0. """ + if self.msc3861_oauth_delegation_enabled: + raise ConfigError( + "Cannot use password auth provider callbacks when OAuth delegation is enabled" + ) + return self._password_auth_provider.register_password_auth_provider_callbacks( check_3pid_auth=check_3pid_auth, on_logged_out=on_logged_out, @@ -577,7 +589,7 @@ class ModuleApi: Returns: UserInfo object if a user was found, otherwise None """ - return await self._store.get_userinfo_by_id(user_id) + return await self._store.get_user_by_id(user_id) async def get_user_by_req( self, @@ -664,7 +676,9 @@ class ModuleApi: Returns: The profile information (i.e. display name and avatar URL). """ - return await self._store.get_profileinfo(localpart) + server_name = self._hs.hostname + user_id = UserID.from_string(f"@{localpart}:{server_name}") + return await self._store.get_profileinfo(user_id) async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: """Look up the threepids (email addresses and phone numbers) associated with the @@ -888,7 +902,7 @@ class ModuleApi: def run_db_interaction( self, desc: str, - func: Callable[P, T], + func: Callable[Concatenate[LoggingTransaction, P], T], *args: P.args, **kwargs: P.kwargs, ) -> "defer.Deferred[T]": @@ -1183,7 +1197,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain - presence_handler.get_federation_queue().send_presence_to_destinations( + await presence_handler.get_federation_queue().send_presence_to_destinations( presence_events, [destination] ) @@ -1234,6 +1248,58 @@ class ModuleApi: f, ) + def should_run_background_tasks(self) -> bool: + """ + Return true if and only if the current worker is configured to run + background tasks. + There should only be one worker configured to run background tasks, so + this is helpful when you need to only run a task on one worker but don't + have any other good way to choose which one. + + Added in Synapse v1.89.0. + """ + return self._hs.config.worker.run_background_tasks + + def delayed_background_call( + self, + msec: float, + f: Callable, + *args: object, + desc: Optional[str] = None, + **kwargs: object, + ) -> IDelayedCall: + """Wraps a function as a background process and calls it in a given number of milliseconds. + + The scheduled call is not persistent: if the current Synapse instance is + restarted before the call is made, the call will not be made. + + Added in Synapse v1.90.0. + + Args: + msec: How long to wait before calling, in milliseconds. + f: The function to call once. f can be either synchronous or + asynchronous, and must follow Synapse's logcontext rules. + More info about logcontexts is available at + https://matrix-org.github.io/synapse/latest/log_contexts.html + *args: Positional arguments to pass to function. + desc: The background task's description. Default to the function's name. + **kwargs: Keyword arguments to pass to function. + + Returns: + IDelayedCall handle from twisted, which allows to cancel the delayed call if desired. + """ + + if desc is None: + desc = f.__name__ + + return self._clock.call_later( + # convert ms to seconds as needed by call_later. + msec * 0.001, + run_as_background_process, + desc, + lambda: maybe_awaitable(f(*args, **kwargs)), + ) + async def sleep(self, seconds: float) -> None: """Sleeps for the given number of seconds. @@ -1580,6 +1646,32 @@ class ModuleApi: start_timestamp, end_timestamp ) + async def get_canonical_room_alias(self, room_id: RoomID) -> Optional[RoomAlias]: + """ + Retrieve the given room's current canonical alias. + + A room may declare an alias as "canonical", meaning that it is the + preferred alias to use when referring to the room. This function + retrieves that alias from the room's state. + + Added in Synapse v1.86.0. + + Args: + room_id: The Room ID to find the alias of. + + Returns: + None if the room ID does not exist, or if the room exists but has no canonical alias. + Otherwise, the parsed room alias. + """ + room_alias_str = ( + await self._storage_controllers.state.get_canonical_alias_for_room( + room_id.to_string() + ) + ) + if room_alias_str: + return RoomAlias.from_string(room_alias_str) + return None + async def lookup_room_alias(self, room_alias: str) -> Tuple[str, List[str]]: """ Get the room ID associated with a room alias. @@ -1655,6 +1747,30 @@ class ModuleApi: room_alias_str = room_alias.to_string() if room_alias else None return room_id, room_alias_str + async def delete_room(self, room_id: str) -> None: + """ + Schedules the deletion of a room from Synapse's database. + + If the room is already being deleted, this method does nothing. + This method does not wait for the room to be deleted. + + Added in Synapse v1.89.0. + """ + # Future extensions to this method might want to e.g. allow use of `force_purge`. + # TODO In the future we should make sure this is persistent. + await self._hs.get_pagination_handler().start_shutdown_and_purge_room( + room_id, + { + "new_room_user_id": None, + "new_room_name": None, + "message": None, + "requester_user_id": None, + "block": False, + "purge": True, + "force_purge": False, + }, + ) + async def set_displayname( self, user_id: UserID, @@ -1790,7 +1906,7 @@ class AccountDataManager: raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}") # Ensure the user exists, so we don't just write to users that aren't there. - if await self._store.get_userinfo_by_id(user_id) is None: + if await self._store.get_user_by_id(user_id) is None: raise ValueError(f"User {user_id} does not exist on this server.") await self._handler.add_account_data_for_user(user_id, data_type, new_data) diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py index 4456d1b81e..32db7cce8d 100644 --- a/synapse/module_api/callbacks/spamchecker_callbacks.py +++ b/synapse/module_api/callbacks/spamchecker_callbacks.py @@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[ ] ], ] +CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[ + [ + str, + Optional[str], + Optional[str], + Collection[Tuple[Optional[str], str]], + Optional[str], + ], + Awaitable[ + Union[ + Literal["NOT_SPAM"], + Codes, + # Highly experimental, not officially part of the spamchecker API, may + # disappear without warning depending on the results of ongoing + # experiments. + # Use this to return additional information as part of an error. + Tuple[Codes, JsonDict], + ] + ], +] def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: @@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks: self._check_media_file_for_spam_callbacks: List[ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK ] = [] + self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = [] def register_callbacks( self, @@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks: CHECK_REGISTRATION_FOR_SPAM_CALLBACK ] = None, check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, + check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, ) -> None: """Register callbacks from module for each hook.""" if check_event_for_spam is not None: @@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks: if check_media_file_for_spam is not None: self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam) + if check_login_for_spam is not None: + self._check_login_for_spam_callbacks.append(check_login_for_spam) + @trace async def check_event_for_spam( self, event: "synapse.events.EventBase" @@ -401,9 +426,7 @@ class SpamCheckerModuleApiCallbacks: generally discouraged as it doesn't support internationalization. """ for callback in self._check_event_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(event)) if res is False or res == self.NOT_SPAM: # This spam-checker accepts the event. @@ -456,9 +479,7 @@ class SpamCheckerModuleApiCallbacks: True if the event should be silently dropped """ for callback in self._should_drop_federated_event_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res: Union[bool, str] = await delay_cancellation(callback(event)) if res: return res @@ -480,9 +501,7 @@ class SpamCheckerModuleApiCallbacks: NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise. """ for callback in self._user_may_join_room_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(user_id, room_id, is_invited)) # Normalize return values to `Codes` or `"NOT_SPAM"`. if res is True or res is self.NOT_SPAM: @@ -521,9 +540,7 @@ class SpamCheckerModuleApiCallbacks: NOT_SPAM if the operation is permitted, Codes otherwise. """ for callback in self._user_may_invite_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation( callback(inviter_userid, invitee_userid, room_id) ) @@ -568,9 +585,7 @@ class SpamCheckerModuleApiCallbacks: NOT_SPAM if the operation is permitted, Codes otherwise. """ for callback in self._user_may_send_3pid_invite_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation( callback(inviter_userid, medium, address, room_id) ) @@ -605,9 +620,7 @@ class SpamCheckerModuleApiCallbacks: userid: The ID of the user attempting to create a room """ for callback in self._user_may_create_room_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(userid)) if res is True or res is self.NOT_SPAM: continue @@ -641,9 +654,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._user_may_create_room_alias_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(userid, room_alias)) if res is True or res is self.NOT_SPAM: continue @@ -676,9 +687,7 @@ class SpamCheckerModuleApiCallbacks: room_id: The ID of the room that would be published """ for callback in self._user_may_publish_room_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(userid, room_id)) if res is True or res is self.NOT_SPAM: continue @@ -717,9 +726,7 @@ class SpamCheckerModuleApiCallbacks: True if the user is spammy. """ for callback in self._check_username_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): # Make a copy of the user profile object to ensure the spam checker cannot # modify it. res = await delay_cancellation(callback(user_profile.copy())) @@ -751,9 +758,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._check_registration_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): behaviour = await delay_cancellation( callback(email_threepid, username, request_info, auth_provider_id) ) @@ -763,6 +768,7 @@ class SpamCheckerModuleApiCallbacks: return RegistrationBehaviour.ALLOW + @trace async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]: @@ -794,9 +800,7 @@ class SpamCheckerModuleApiCallbacks: """ for callback in self._check_media_file_for_spam_callbacks: - with Measure( - self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) - ): + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): res = await delay_cancellation(callback(file_wrapper, file_info)) # Normalize return values to `Codes` or `"NOT_SPAM"`. if res is False or res is self.NOT_SPAM: @@ -819,3 +823,56 @@ class SpamCheckerModuleApiCallbacks: return synapse.api.errors.Codes.FORBIDDEN, {} return self.NOT_SPAM + + async def check_login_for_spam( + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + request_info: Collection[Tuple[Optional[str], str]], + auth_provider_id: Optional[str] = None, + ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]: + """Checks if we should allow the given registration request. + + Args: + user_id: The request user ID + request_info: List of tuples of user agent and IP that + were used during the registration process. + auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", + "cas". If any. Note this does not include users registered + via a password provider. + + Returns: + Enum for how the request should be handled + """ + + for callback in self._check_login_for_spam_callbacks: + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + res = await delay_cancellation( + callback( + user_id, + device_id, + initial_display_name, + request_info, + auth_provider_id, + ) + ) + # Normalize return values to `Codes` or `"NOT_SPAM"`. + if res is self.NOT_SPAM: + continue + elif isinstance(res, synapse.api.errors.Codes): + return res, {} + elif ( + isinstance(res, tuple) + and len(res) == 2 + and isinstance(res[0], synapse.api.errors.Codes) + and isinstance(res[1], dict) + ): + return res + else: + logger.warning( + "Module returned invalid value, rejecting login as spam" + ) + return synapse.api.errors.Codes.FORBIDDEN, {} + + return self.NOT_SPAM diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py index 911f37ba42..ecaeef3511 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py @@ -40,7 +40,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] -CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] +CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[Optional[str], str], Awaitable[bool]] CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] @@ -429,12 +429,17 @@ class ThirdPartyEventRulesModuleApiCallbacks: "Failed to run module API callback %s: %s", callback, e ) - async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool: + async def check_can_shutdown_room( + self, user_id: Optional[str], room_id: str + ) -> bool: """Intercept requests to shutdown a room. If `False` is returned, the room must not be shut down. Args: - requester: The ID of the user requesting the shutdown. + user_id: The ID of the user requesting the shutdown. + If no user ID is supplied, then the room is being shut down through + some mechanism other than a user's request, e.g. through a module's + request. room_id: The ID of the room. """ for callback in self._check_can_shutdown_room_callbacks: diff --git a/synapse/notifier.py b/synapse/notifier.py index 897272ad5b..99e7715896 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -104,7 +104,7 @@ class _NotifierUserStream: def __init__( self, user_id: str, - rooms: Collection[str], + rooms: StrCollection, current_token: StreamToken, time_now_ms: int, ): @@ -126,7 +126,7 @@ class _NotifierUserStream: def notify( self, - stream_key: str, + stream_key: StreamKeyType, stream_id: Union[int, RoomStreamToken], time_now_ms: int, ) -> None: @@ -234,6 +234,9 @@ class Notifier: self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules + # List of callbacks to be notified when a lock is released + self._lock_released_callback: List[Callable[[str, str, str], None]] = [] + self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() self._pusher_pool = hs.get_pusherpool() @@ -451,10 +454,10 @@ class Notifier: def on_new_event( self, - stream_key: str, + stream_key: StreamKeyType, new_token: Union[int, RoomStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, - rooms: Optional[Collection[str]] = None, + rooms: Optional[StrCollection] = None, ) -> None: """Used to inform listeners that something has happened event wise. @@ -526,7 +529,7 @@ class Notifier: user_id: str, timeout: int, callback: Callable[[StreamToken, StreamToken], Awaitable[T]], - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, from_token: StreamToken = StreamToken.START, ) -> T: """Wait until the callback returns a non empty response or the @@ -652,30 +655,29 @@ class Notifier: events: List[Union[JsonDict, EventBase]] = [] end_token = from_token - for name, source in self.event_sources.sources.get_sources(): - keyname = "%s_key" % name - before_id = getattr(before_token, keyname) - after_id = getattr(after_token, keyname) + for keyname, source in self.event_sources.sources.get_sources(): + before_id = before_token.get_field(keyname) + after_id = after_token.get_field(keyname) if before_id == after_id: continue new_events, new_key = await source.get_new_events( user=user, - from_key=getattr(from_token, keyname), + from_key=from_token.get_field(keyname), limit=limit, is_guest=is_peeking, room_ids=room_ids, explicit_room_id=explicit_room_id, ) - if name == "room": + if keyname == StreamKeyType.ROOM: new_events = await filter_events_for_client( self._storage_controllers, user.to_string(), new_events, is_peeking=is_peeking, ) - elif name == "presence": + elif keyname == StreamKeyType.PRESENCE: now = self.clock.time_msec() new_events[:] = [ { @@ -785,6 +787,19 @@ class Notifier: # that any in flight requests can be immediately retried. self._federation_client.wake_destination(server) + def add_lock_released_callback( + self, callback: Callable[[str, str, str], None] + ) -> None: + """Add a function to be called whenever we are notified about a released lock.""" + self._lock_released_callback.append(callback) + + def notify_lock_released( + self, instance_name: str, lock_name: str, lock_key: str + ) -> None: + """Notify the callbacks that a lock has been released.""" + for cb in self._lock_released_callback: + cb(instance_name, lock_name, lock_key) + @attr.s(auto_attribs=True) class ReplicationNotifier: diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 9e3a98741a..4d405f2a0c 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -101,7 +101,7 @@ if TYPE_CHECKING: class PusherConfig: """Parameters necessary to configure a pusher.""" - id: Optional[str] + id: Optional[int] user_name: str profile_tag: str @@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: + def on_new_receipts(self) -> None: raise NotImplementedError() @abc.abstractmethod diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 320084f5f5..14784312dc 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -120,9 +120,6 @@ class BulkPushRuleEvaluator: self.should_calculate_push_rules = self.hs.config.push.enable_push self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled - self._intentional_mentions_enabled = ( - self.hs.config.experimental.msc3952_intentional_mentions - ) self.room_push_rule_cache_metrics = register_cache( "cache", @@ -134,7 +131,7 @@ class BulkPushRuleEvaluator: async def _get_rules_for_event( self, event: EventBase, - ) -> Dict[str, FilteredPushRules]: + ) -> Mapping[str, FilteredPushRules]: """Get the push rules for all users who may need to be notified about the event. @@ -325,7 +322,6 @@ class BulkPushRuleEvaluator: ) -> None: if ( not event.internal_metadata.is_notifiable() - or event.internal_metadata.is_historical() or event.room_id in self.hs.config.server.rooms_to_exclude_from_sync ): # Push rules for events that aren't notifiable can't be processed by this and @@ -379,21 +375,18 @@ class BulkPushRuleEvaluator: # _get_power_levels_and_sender_level in its call to get_user_power_level # (even for room V10.) notification_levels = power_levels.get("notifications", {}) - if not event.room_version.msc3667_int_only_power_levels: + if not event.room_version.enforce_int_power_levels: keys = list(notification_levels.keys()) for key in keys: level = notification_levels.get(key, SENTINEL) - if level is not SENTINEL and type(level) is not int: + if level is not SENTINEL and type(level) is not int: # noqa: E721 try: notification_levels[key] = int(level) except (TypeError, ValueError): del notification_levels[key] # Pull out any user and room mentions. - has_mentions = ( - self._intentional_mentions_enabled - and EventContentFields.MSC3952_MENTIONS in event.content - ) + has_mentions = EventContentFields.MENTIONS in event.content evaluator = PushRuleEvaluator( _flatten_dict(event), @@ -479,7 +472,11 @@ StateGroup = Union[object, int] def _is_simple_value(value: Any) -> bool: - return isinstance(value, (bool, str)) or type(value) is int or value is None + return ( + isinstance(value, (bool, str)) + or type(value) is int # noqa: E721 + or value is None + ) def _flatten_dict( diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 88b52c26a0..735cef0aed 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -41,12 +41,7 @@ def format_push_rules_for_user( rulearray.append(template_rule) - for type_key in ("pattern", "value"): - type_value = template_rule.pop(f"{type_key}_type", None) - if type_value == "user_id": - template_rule[type_key] = user.to_string() - elif type_value == "user_localpart": - template_rule[type_key] = user.localpart + _convert_type_to_value(template_rule, user) template_rule["enabled"] = enabled @@ -63,19 +58,20 @@ def format_push_rules_for_user( for c in template_rule["conditions"]: c.pop("_cache_key", None) - pattern_type = c.pop("pattern_type", None) - if pattern_type == "user_id": - c["pattern"] = user.to_string() - elif pattern_type == "user_localpart": - c["pattern"] = user.localpart - - sender_type = c.pop("sender_type", None) - if sender_type == "user_id": - c["sender"] = user.to_string() + _convert_type_to_value(c, user) return rules +def _convert_type_to_value(rule_or_cond: Dict[str, Any], user: UserID) -> None: + for type_key in ("pattern", "value"): + type_value = rule_or_cond.pop(f"{type_key}_type", None) + if type_value == "user_id": + rule_or_cond[type_key] = user.to_string() + elif type_value == "user_localpart": + rule_or_cond[type_key] = user.localpart + + def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]: for pc in PRIORITY_CLASS_MAP.keys(): d[pc] = [] diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 1710dd51b9..cf45fd09a8 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -99,7 +99,7 @@ class EmailPusher(Pusher): pass self.timed_call = None - def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: + def on_new_receipts(self) -> None: # We could wake up and cancel the timer but there tend to be quite a # lot of read receipts so it's probably less work to just let the # timer fire diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 50027680cb..725910a659 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -160,7 +160,7 @@ class HttpPusher(Pusher): if should_check_for_notifs: self._start_processing() - def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: + def on_new_receipts(self) -> None: # Note that the min here shouldn't be relied upon to be accurate. # We could check the receipts are actually m.read receipts here, diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 491a09b71d..b6cad18c2d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -247,7 +247,7 @@ class Mailer: try: user_display_name = await self.store.get_profile_displayname( - UserID.from_string(user_id).localpart + UserID.from_string(user_id) ) if user_display_name is None: user_display_name = user_id @@ -298,20 +298,26 @@ class Mailer: notifs_by_room, state_by_room, notif_events, reason ) + unsubscribe_link = self._make_unsubscribe_link(user_id, app_id, email_address) + template_vars: TemplateVars = { "user_display_name": user_display_name, - "unsubscribe_link": self._make_unsubscribe_link( - user_id, app_id, email_address - ), + "unsubscribe_link": unsubscribe_link, "summary_text": summary_text, "rooms": rooms, "reason": reason, } - await self.send_email(email_address, summary_text, template_vars) + await self.send_email( + email_address, summary_text, template_vars, unsubscribe_link + ) async def send_email( - self, email_address: str, subject: str, extra_template_vars: TemplateVars + self, + email_address: str, + subject: str, + extra_template_vars: TemplateVars, + unsubscribe_link: Optional[str] = None, ) -> None: """Send an email with the given information and template text""" template_vars: TemplateVars = { @@ -330,6 +336,23 @@ class Mailer: app_name=self.app_name, html=html_text, text=plain_text, + # Include the List-Unsubscribe header which some clients render in the UI. + # Per RFC 2369, this can be a URL or mailto URL. See + # https://www.rfc-editor.org/rfc/rfc2369.html#section-3.2 + # + # It is preferred to use email, but Synapse doesn't support incoming email. + # + # Also include the List-Unsubscribe-Post header from RFC 8058. See + # https://www.rfc-editor.org/rfc/rfc8058.html#section-3.1 + # + # Note that many email clients will not render the unsubscribe link + # unless DKIM, etc. is properly setup. + additional_headers={ + "List-Unsubscribe-Post": "List-Unsubscribe=One-Click", + "List-Unsubscribe": f"<{unsubscribe_link}>", + } + if unsubscribe_link + else None, ) async def _get_room_vars( diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 7ee07e4bee..a94a6e97c1 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage.controllers import StorageControllers @@ -49,7 +50,41 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - async def get_context_for_event( storage: StorageControllers, ev: EventBase, user_id: str ) -> Dict[str, str]: - ctx = {} + ctx: Dict[str, str] = {} + + if ev.internal_metadata.outlier: + # We don't have state for outliers, so we can't compute the context + # except for invites and knocks. (Such events are known as 'out-of-band + # memberships' for the user). + if ev.type != EventTypes.Member: + return ctx + + # We might be able to pull out the display name for the sender straight + # from the membership event + event_display_name = ev.content.get("displayname") + if event_display_name and ev.state_key == ev.sender: + ctx["sender_display_name"] = event_display_name + + room_state = [] + if ev.content.get("membership") == Membership.INVITE: + room_state = ev.unsigned.get("invite_room_state", []) + elif ev.content.get("membership") == Membership.KNOCK: + room_state = ev.unsigned.get("knock_room_state", []) + + # Ideally we'd reuse the logic in `calculate_room_name`, but that gets + # complicated to handle partial events vs pulling events from the DB. + for state_dict in room_state: + type_tuple = (state_dict["type"], state_dict.get("state_key")) + if type_tuple == (EventTypes.Member, ev.sender): + display_name = state_dict["content"].get("displayname") + if display_name: + ctx["sender_display_name"] = display_name + elif type_tuple == (EventTypes.Name, ""): + room_name = state_dict["content"].get("name") + if room_name: + ctx["name"] = room_name + + return ctx room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 6517e3566f..15a2cc932f 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -292,20 +292,12 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - async def on_new_receipts( - self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str] - ) -> None: + async def on_new_receipts(self, users_affected: StrCollection) -> None: if not self.pushers: # nothing to do here. return try: - # Need to subtract 1 from the minimum because the lower bound here - # is not inclusive - users_affected = await self.store.get_users_sent_receipts_between( - min_stream_id - 1, max_stream_id - ) - for u in users_affected: # Don't push if the user account has expired expired = await self._account_validity_handler.is_user_expired(u) @@ -314,7 +306,7 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_receipts(min_stream_id, max_stream_id) + p.on_new_receipts() except Exception: logger.exception("Exception in pusher on_new_receipts") diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index dc7820f963..63cf24a14d 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -219,11 +219,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") - if instance_name in instance_map: - host = instance_map[instance_name].host - port = instance_map[instance_name].port - tls = instance_map[instance_name].tls - else: + if instance_name not in instance_map: raise Exception( "Instance %r not in 'instance_map' config" % (instance_name,) ) @@ -271,13 +267,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): "Unknown METHOD on %s replication endpoint" % (cls.NAME,) ) - # Here the protocol is hard coded to be http by default or https in case the replication - # port is set to have tls true. - scheme = "https" if tls else "http" - uri = "%s://%s:%s/_synapse/replication/%s/%s" % ( - scheme, - host, - port, + # Hard code a special scheme to show this only used for replication. The + # instance_name will be passed into the ReplicationEndpointFactory to + # determine connection details from the instance_map. + uri = "synapse-replication://%s/_synapse/replication/%s/%s" % ( + instance_name, cls.NAME, "/".join(url_args), ) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index f874f072f9..b8198e059c 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -20,7 +20,7 @@ from twisted.web.server import Request from synapse.http.server import HttpServer from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping if TYPE_CHECKING: from synapse.server import HomeServer @@ -62,7 +62,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): NAME = "multi_user_device_resync" PATH_ARGS = () - CACHE = False + CACHE = True def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -82,7 +82,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict - ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: + ) -> Tuple[int, Dict[str, Optional[JsonMapping]]]: user_ids: List[str] = content["user_ids"] logger.info("Resync for %r", user_ids) @@ -107,8 +107,7 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): Calls to e2e_keys_handler.upload_keys_for_user(user_id, device_id, keys) on the main process to accomplish this. - Defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload - Request format(borrowed and expanded from KeyUploadServlet): + Request format for this endpoint (borrowed and expanded from KeyUploadServlet): POST /_synapse/replication/upload_keys_for_user @@ -117,6 +116,7 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): "device_id": "<device_id>", "keys": { ....this part can be found in KeyUploadServlet in rest/client/keys.py.... + or as defined in https://spec.matrix.org/v1.4/client-server-api/#post_matrixclientv3keysupload } } diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index db16aac9c2..6c9e79fb07 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -51,14 +51,14 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] - return {} + async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict: # type: ignore[override] + return {"device_id": device_id} async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( - UserID.from_string(user_id) + UserID.from_string(user_id), content.get("device_id") ) return (200, {}) @@ -73,8 +73,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint): { "state": { ... }, - "ignore_status_msg": false, - "force_notify": false + "force_notify": false, + "is_sync": false } 200 OK @@ -95,14 +95,16 @@ class ReplicationPresenceSetState(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> JsonDict: return { + "device_id": device_id, "state": state, - "ignore_status_msg": ignore_status_msg, "force_notify": force_notify, + "is_sync": is_sync, } async def _handle_request( # type: ignore[override] @@ -110,9 +112,10 @@ class ReplicationPresenceSetState(ReplicationEndpoint): ) -> Tuple[int, JsonDict]: await self._presence_handler.set_state( UserID.from_string(user_id), + content.get("device_id"), content["state"], - content["ignore_status_msg"], content["force_notify"], + content.get("is_sync", False), ) return (200, {}) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 139f57cf86..d5337fe588 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,7 +14,9 @@ """A replication client for use by synapse workers. """ import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple + +from sortedcontainers import SortedList from twisted.internet import defer from twisted.internet.defer import Deferred @@ -84,7 +86,9 @@ class ReplicationDataHandler: # Map from stream and instance to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. - self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {} + self._streams_to_waiters: Dict[ + Tuple[str, str], SortedList[Tuple[int, Deferred]] + ] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -125,9 +129,7 @@ class ReplicationDataHandler: self.notifier.on_new_event( StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] ) - await self._pusher_pool.on_new_receipts( - token, token, {row.room_id for row in rows} - ) + await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) elif stream_name == ToDeviceStream.NAME: entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: @@ -201,6 +203,12 @@ class ReplicationDataHandler: self.notifier.notify_user_joined_room( row.data.event_id, row.data.room_id ) + + # If this is a server ACL event, clear the cache in the storage controller. + if row.data.type == EventTypes.ServerACL: + self._state_storage_controller.get_server_acl_for_room.invalidate( + (row.data.room_id,) + ) elif stream_name == UnPartialStatedRoomStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedRoomStreamRow) @@ -226,7 +234,9 @@ class ReplicationDataHandler: # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position. - waiting_list = self._streams_to_waiters.get((stream_name, instance_name), []) + waiting_list = self._streams_to_waiters.get((stream_name, instance_name)) + if not waiting_list: + return # Index of first item with a position after the current token, i.e we # have called all deferreds before this index. If not overwritten by @@ -250,7 +260,7 @@ class ReplicationDataHandler: # Drop all entries in the waiting list that were called in the above # loop. (This maintains the order so no need to resort) - waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + del waiting_list[:index_of_first_deferred_not_called] for deferred in deferreds_to_callback: try: @@ -310,11 +320,10 @@ class ReplicationDataHandler: ) waiting_list = self._streams_to_waiters.setdefault( - (stream_name, instance_name), [] + (stream_name, instance_name), SortedList(key=lambda t: t[0]) ) - waiting_list.append((position, deferred)) - waiting_list.sort(key=lambda t: t[0]) + waiting_list.add((position, deferred)) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): @@ -328,7 +337,7 @@ class ReplicationDataHandler: try: await make_deferred_yieldable(deferred) except defer.TimeoutError: - logger.error( + logger.warning( "Timed out waiting for repl stream %r to reach %s (%s)" "; currently at: %s", stream_name, @@ -405,7 +414,7 @@ class FederationSenderHandler: # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": - send_queue.process_rows_for_federation(self.federation_sender, rows) + await send_queue.process_rows_for_federation(self.federation_sender, rows) await self.update_token(token) # ... and when new receipts happen @@ -422,16 +431,14 @@ class FederationSenderHandler: for row in rows if not row.entity.startswith("@") and not row.is_signature } - for host in hosts: - self.federation_sender.send_device_messages(host, immediate=False) + await self.federation_sender.send_device_messages(hosts, immediate=False) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local # clients and remote servers, so we ignore entities that start with # '@' (since they'll be local users rather than destinations). hosts = {row.entity for row in rows if not row.entity.startswith("@")} - for host in hosts: - self.federation_sender.send_device_messages(host) + await self.federation_sender.send_device_messages(hosts) async def _on_new_receipts( self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow] diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 32f52e54d8..0f0f851b79 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,7 +18,7 @@ allowed to be sent by which side. """ import abc import logging -from typing import Optional, Tuple, Type, TypeVar +from typing import List, Optional, Tuple, Type, TypeVar from synapse.replication.tcp.streams._base import StreamRow from synapse.util import json_decoder, json_encoder @@ -74,6 +74,8 @@ SC = TypeVar("SC", bound="_SimpleCommand") class _SimpleCommand(Command): """An implementation of Command whose argument is just a 'data' string.""" + __slots__ = ["data"] + def __init__(self, data: str): self.data = data @@ -122,6 +124,8 @@ class RdataCommand(Command): RDATA presence master 59 ["@baz:example.com", "online", ...] """ + __slots__ = ["stream_name", "instance_name", "token", "row"] + NAME = "RDATA" def __init__( @@ -179,6 +183,8 @@ class PositionCommand(Command): of the stream. """ + __slots__ = ["stream_name", "instance_name", "prev_token", "new_token"] + NAME = "POSITION" def __init__( @@ -235,6 +241,8 @@ class ReplicateCommand(Command): REPLICATE """ + __slots__: List[str] = [] + NAME = "REPLICATE" def __init__(self) -> None: @@ -264,30 +272,43 @@ class UserSyncCommand(Command): Where <state> is either "start" or "end" """ + __slots__ = ["instance_id", "user_id", "device_id", "is_syncing", "last_sync_ms"] + NAME = "USER_SYNC" def __init__( - self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + self, + instance_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, ): self.instance_id = instance_id self.user_id = user_id + self.device_id = device_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": - instance_id, user_id, state, last_sync_ms = line.split(" ", 3) + device_id: Optional[str] + instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4) + + if device_id == "None": + device_id = None if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(instance_id, user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms)) def to_line(self) -> str: return " ".join( ( self.instance_id, self.user_id, + str(self.device_id), "start" if self.is_syncing else "end", str(self.last_sync_ms), ) @@ -305,6 +326,8 @@ class ClearUserSyncsCommand(Command): CLEAR_USER_SYNC <instance_id> """ + __slots__ = ["instance_id"] + NAME = "CLEAR_USER_SYNC" def __init__(self, instance_id: str): @@ -332,6 +355,8 @@ class FederationAckCommand(Command): FEDERATION_ACK <instance_name> <token> """ + __slots__ = ["instance_name", "token"] + NAME = "FEDERATION_ACK" def __init__(self, instance_name: str, token: int): @@ -357,6 +382,15 @@ class UserIpCommand(Command): USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent> """ + __slots__ = [ + "user_id", + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ] + NAME = "USER_IP" def __init__( @@ -412,8 +446,6 @@ class RemoteServerUpCommand(_SimpleCommand): """Sent when a worker has detected that a remote server is no longer "down" and retry timings should be reset. - If sent from a client the server will relay to all other workers. - Format:: REMOTE_SERVER_UP <server> @@ -422,6 +454,49 @@ class RemoteServerUpCommand(_SimpleCommand): NAME = "REMOTE_SERVER_UP" +class LockReleasedCommand(Command): + """Sent to inform other instances that a given lock has been dropped. + + Format:: + + LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"] + """ + + __slots__ = ["instance_name", "lock_name", "lock_key"] + + NAME = "LOCK_RELEASED" + + def __init__( + self, + instance_name: str, + lock_name: str, + lock_key: str, + ): + self.instance_name = instance_name + self.lock_name = lock_name + self.lock_key = lock_key + + @classmethod + def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand": + instance_name, lock_name, lock_key = json_decoder.decode(line) + + return cls(instance_name, lock_name, lock_key) + + def to_line(self) -> str: + return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) + + +class NewActiveTaskCommand(_SimpleCommand): + """Sent to inform instance handling background tasks that a new active task is available to run. + + Format:: + + NEW_ACTIVE_TASK "<task_id>" + """ + + NAME = "NEW_ACTIVE_TASK" + + _COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, @@ -435,6 +510,8 @@ _COMMANDS: Tuple[Type[Command], ...] = ( UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, + LockReleasedCommand, + NewActiveTaskCommand, ) # Map of command name to command type. @@ -448,6 +525,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, RemoteServerUpCommand.NAME, + LockReleasedCommand.NAME, ) # The commands the client is allowed to send @@ -461,6 +539,7 @@ VALID_CLIENT_COMMANDS = ( UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, + LockReleasedCommand.NAME, ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 233ad61d49..b668bb5da1 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Deque, Dict, Iterable, Iterator, @@ -29,7 +30,6 @@ from typing import ( ) from prometheus_client import Counter -from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory @@ -39,6 +39,8 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, + LockReleasedCommand, + NewActiveTaskCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -237,6 +239,10 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + self._task_scheduler = None + if hs.config.worker.run_background_tasks: + self._task_scheduler = hs.get_task_scheduler() + if hs.config.redis.redis_enabled: # If we're using Redis, it's the background worker that should # receive USER_IP commands and store the relevant client IPs. @@ -248,6 +254,9 @@ class ReplicationCommandHandler: if self._is_master or self._should_insert_client_ips: self.subscribe_to_channel("USER_IP") + if hs.config.redis.redis_enabled: + self._notifier.add_lock_released_callback(self.on_lock_released) + def subscribe_to_channel(self, channel_name: str) -> None: """ Indicates that we wish to subscribe to a Redis channel by name. @@ -352,7 +361,15 @@ class ReplicationCommandHandler: reactor = hs.get_reactor() redis_config = hs.config.redis - if hs.config.redis.redis_use_tls: + if redis_config.redis_path is not None: + reactor.connectUNIX( + redis_config.redis_path, + self._factory, + timeout=30, + checkPID=False, + ) + + elif hs.config.redis.redis_use_tls: ssl_context_factory = ClientContextFactory(hs.config.redis) reactor.connectSSL( redis_config.redis_host, @@ -411,7 +428,11 @@ class ReplicationCommandHandler: if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, + cmd.user_id, + cmd.device_id, + cmd.is_syncing, + cmd.last_sync_ms, ) else: return None @@ -623,7 +644,7 @@ class ReplicationCommandHandler: [stream.parse_row(row) for row in rows], ) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( @@ -640,6 +661,24 @@ class ReplicationCommandHandler: self._notifier.notify_remote_server_up(cmd.data) + def on_LOCK_RELEASED( + self, conn: IReplicationConnection, cmd: LockReleasedCommand + ) -> None: + """Called when we get a new LOCK_RELEASED command.""" + if cmd.instance_name == self._instance_name: + return + + self._notifier.notify_lock_released( + cmd.instance_name, cmd.lock_name, cmd.lock_key + ) + + def on_NEW_ACTIVE_TASK( + self, conn: IReplicationConnection, cmd: NewActiveTaskCommand + ) -> None: + """Called when get a new NEW_ACTIVE_TASK command.""" + if self._task_scheduler: + self._task_scheduler.launch_task_by_id(cmd.data) + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -662,9 +701,9 @@ class ReplicationCommandHandler: ) now = self._clock.time_msec() - for user_id in currently_syncing: + for user_id, device_id in currently_syncing: connection.send_command( - UserSyncCommand(self._instance_id, user_id, True, now) + UserSyncCommand(self._instance_id, user_id, device_id, True, now) ) def lost_connection(self, connection: IReplicationConnection) -> None: @@ -716,11 +755,16 @@ class ReplicationCommandHandler: self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( - self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + self, + instance_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms) ) def send_user_ip( @@ -746,6 +790,17 @@ class ReplicationCommandHandler: """ self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) + def on_lock_released( + self, instance_name: str, lock_name: str, lock_key: str + ) -> None: + """Called when we released a lock and should notify other instances.""" + if instance_name == self._instance_name: + self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) + + def send_new_active_task(self, task_id: str) -> None: + """Called when a new task has been scheduled for immediate launch and is ACTIVE.""" + self.send_command(NewActiveTaskCommand(task_id)) + UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index c8f4bf8b27..7e96145b3b 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -17,7 +17,12 @@ from inspect import isawaitable from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast import attr -import txredisapi +from txredisapi import ( + ConnectionHandler, + RedisFactory, + SubscriberProtocol, + UnixConnectionHandler, +) from zope.interface import implementer from twisted.internet.address import IPv4Address, IPv6Address @@ -68,7 +73,7 @@ class ConstantProperty(Generic[T, V]): @implementer(IReplicationConnection) -class RedisSubscriber(txredisapi.SubscriberProtocol): +class RedisSubscriber(SubscriberProtocol): """Connection to redis subscribed to replication stream. This class fulfils two functions: @@ -95,7 +100,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): synapse_handler: "ReplicationCommandHandler" synapse_stream_prefix: str synapse_channel_names: List[str] - synapse_outbound_redis_connection: txredisapi.ConnectionHandler + synapse_outbound_redis_connection: ConnectionHandler def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -229,7 +234,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): ) -class SynapseRedisFactory(txredisapi.RedisFactory): +class SynapseRedisFactory(RedisFactory): """A subclass of RedisFactory that periodically sends pings to ensure that we detect dead connections. """ @@ -245,7 +250,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory): dbid: Optional[int], poolsize: int, isLazy: bool = False, - handler: Type = txredisapi.ConnectionHandler, + handler: Type = ConnectionHandler, charset: str = "utf-8", password: Optional[str] = None, replyTimeout: int = 30, @@ -326,7 +331,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): def __init__( self, hs: "HomeServer", - outbound_redis_connection: txredisapi.ConnectionHandler, + outbound_redis_connection: ConnectionHandler, channel_names: List[str], ): super().__init__( @@ -368,7 +373,7 @@ def lazyConnection( reconnect: bool = True, password: Optional[str] = None, replyTimeout: int = 30, -) -> txredisapi.ConnectionHandler: +) -> ConnectionHandler: """Creates a connection to Redis that is lazily set up and reconnects if the connections is lost. """ @@ -380,7 +385,7 @@ def lazyConnection( dbid=dbid, poolsize=1, isLazy=True, - handler=txredisapi.ConnectionHandler, + handler=ConnectionHandler, password=password, replyTimeout=replyTimeout, ) @@ -408,3 +413,44 @@ def lazyConnection( ) return factory.handler + + +def lazyUnixConnection( + hs: "HomeServer", + path: str = "/tmp/redis.sock", + dbid: Optional[int] = None, + reconnect: bool = True, + password: Optional[str] = None, + replyTimeout: int = 30, +) -> ConnectionHandler: + """Creates a connection to Redis that is lazily set up and reconnects if the + connection is lost. + + Returns: + A subclass of ConnectionHandler, which is a UnixConnectionHandler in this case. + """ + + uuid = path + + factory = SynapseRedisFactory( + hs, + uuid=uuid, + dbid=dbid, + poolsize=1, + isLazy=True, + handler=UnixConnectionHandler, + password=password, + replyTimeout=replyTimeout, + ) + factory.continueTrying = reconnect + + reactor = hs.get_reactor() + + reactor.connectUNIX( + path, + factory, + timeout=30, + checkPID=False, + ) + + return factory.handler diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 347467d863..1d9a29d22e 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -191,7 +191,12 @@ class ReplicationStreamer: if updates: logger.info( - "Streaming: %s -> %s", stream.NAME, updates[-1][0] + "Streaming: %s -> %s (limited: %s, updates: %s, max token: %s)", + stream.NAME, + updates[-1][0], + limited, + len(updates), + current_token, ) stream_updates_counter.labels(stream.NAME).inc(len(updates)) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 1af8d99d20..1be9c47c61 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -48,7 +48,6 @@ from synapse.rest.client import ( rendezvous, report_event, room, - room_batch, room_keys, room_upgrade_rest_servlet, sendtodevice, @@ -124,7 +123,7 @@ class ClientRestResource(JsonResource): if is_main_process: report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) - notifications.register_servlets(hs, client_resource) + notifications.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) if is_main_process: thirdparty.register_servlets(hs, client_resource) @@ -132,7 +131,6 @@ class ClientRestResource(JsonResource): user_directory.register_servlets(hs, client_resource) if is_main_process: room_upgrade_rest_servlet.register_servlets(hs, client_resource) - room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) if is_main_process: account_validity.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c729364839..9bd0d764f8 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -16,11 +16,11 @@ # limitations under the License. import logging -import platform from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.handlers.pagination import PURGE_HISTORY_ACTION_NAME from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -93,7 +93,7 @@ from synapse.rest.admin.users import ( UserTokenRestServlet, WhoisRestServlet, ) -from synapse.types import JsonDict, RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken, TaskStatus from synapse.util import SYNAPSE_VERSION if TYPE_CHECKING: @@ -106,10 +106,7 @@ class VersionServlet(RestServlet): PATTERNS = admin_patterns("/server_version$") def __init__(self, hs: "HomeServer"): - self.res = { - "server_version": SYNAPSE_VERSION, - "python_version": platform.python_version(), - } + self.res = {"server_version": SYNAPSE_VERSION} def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return HTTPStatus.OK, self.res @@ -149,14 +146,14 @@ class PurgeHistoryRestServlet(RestServlet): # RoomStreamToken expects [int] not Optional[int] assert event.internal_metadata.stream_ordering is not None room_token = RoomStreamToken( - event.depth, event.internal_metadata.stream_ordering + topological=event.depth, stream=event.internal_metadata.stream_ordering ) token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if type(ts) is not int: + if type(ts) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", @@ -196,7 +193,7 @@ class PurgeHistoryRestServlet(RestServlet): errcode=Codes.BAD_JSON, ) - purge_id = self.pagination_handler.start_purge_history( + purge_id = await self.pagination_handler.start_purge_history( room_id, token, delete_local_events=delete_local_events ) @@ -215,11 +212,20 @@ class PurgeHistoryStatusRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - purge_status = self.pagination_handler.get_purge_status(purge_id) - if purge_status is None: + purge_task = await self.pagination_handler.get_delete_task(purge_id) + if purge_task is None or purge_task.action != PURGE_HISTORY_ACTION_NAME: raise NotFoundError("purge id '%s' not found" % purge_id) - return HTTPStatus.OK, purge_status.asdict() + result: JsonDict = { + "status": purge_task.status + if purge_task.status == TaskStatus.COMPLETE + or purge_task.status == TaskStatus.FAILED + else "active", + } + if purge_task.error: + result["error"] = purge_task.error + + return HTTPStatus.OK, result ######################################################################################## @@ -257,9 +263,11 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) - UserAdminServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + UserAdminServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) - UserTokenRestServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) @@ -274,9 +282,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) - ListRegistrationTokensRestServlet(hs).register(http_server) - NewRegistrationTokenRestServlet(hs).register(http_server) - RegistrationTokenRestServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + ListRegistrationTokensRestServlet(hs).register(http_server) + NewRegistrationTokenRestServlet(hs).register(http_server) + RegistrationTokenRestServlet(hs).register(http_server) DestinationMembershipRestServlet(hs).register(http_server) DestinationResetConnectionRestServlet(hs).register(http_server) DestinationRestServlet(hs).register(http_server) @@ -306,10 +315,12 @@ def register_servlets_for_client_rest_resource( # The following resources can only be run on the main process. if hs.config.worker.worker_app is None: DeactivateAccountRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) - UserRegisterServlet(hs).register(http_server) - AccountValidityRenewServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + UserRegisterServlet(hs).register(http_server) + AccountValidityRenewServlet(hs).register(http_server) # Load the media repo ones if we're using them. Otherwise load the servlets which # don't need a media repo (typically readonly admin APIs). diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index e0ee55bd0e..8a617af599 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -198,7 +198,13 @@ class DestinationMembershipRestServlet(RestServlet): rooms, total = await self._store.get_destination_rooms_paginate( destination, start, limit, direction ) - response = {"rooms": rooms, "total": total} + response = { + "rooms": [ + {"room_id": room_id, "stream_ordering": stream_ordering} + for room_id, stream_ordering in rooms + ], + "total": total, + } if (start + limit) < total: response["next_token"] = str(start + len(rooms)) diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 95e751288b..ffce92d45e 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet): else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if type(length) is not int: + if type(length) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,7 +163,8 @@ class NewRegistrationTokenRestServlet(RestServlet): uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) + uses_allowed is None + or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -172,13 +173,16 @@ class NewRegistrationTokenRestServlet(RestServlet): ) expiry_time = body.get("expiry_time", None) - if type(expiry_time) not in (int, type(None)): + if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if type(expiry_time) is int and expiry_time < self.clock.time_msec(): + if ( + type(expiry_time) is int # noqa: E721 + and expiry_time < self.clock.time_msec() + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -283,7 +287,7 @@ class RegistrationTokenRestServlet(RestServlet): uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) + or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -294,13 +298,16 @@ class RegistrationTokenRestServlet(RestServlet): if "expiry_time" in body: expiry_time = body["expiry_time"] - if type(expiry_time) not in (int, type(None)): + if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if type(expiry_time) is int and expiry_time < self.clock.time_msec(): + if ( + type(expiry_time) is int # noqa: E721 + and expiry_time < self.clock.time_msec() + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1d65560265..436718c8b2 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -19,6 +19,10 @@ from urllib import parse as urlparse from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter +from synapse.handlers.pagination import ( + PURGE_ROOM_ACTION_NAME, + SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME, +) from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, @@ -36,7 +40,7 @@ from synapse.rest.admin._base import ( ) from synapse.storage.databases.main.room import RoomSortOrder from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, RoomID, UserID, create_requester +from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester from synapse.types.state import StateFilter from synapse.util import json_decoder @@ -117,20 +121,30 @@ class RoomRestV2Servlet(RestServlet): 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN ) - delete_id = self._pagination_handler.start_shutdown_and_purge_room( + delete_id = await self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, - new_room_user_id=content.get("new_room_user_id"), - new_room_name=content.get("room_name"), - message=content.get("message"), - requester_user_id=requester.user.to_string(), - block=block, - purge=purge, - force_purge=force_purge, + shutdown_params={ + "new_room_user_id": content.get("new_room_user_id"), + "new_room_name": content.get("room_name"), + "message": content.get("message"), + "requester_user_id": requester.user.to_string(), + "block": block, + "purge": purge, + "force_purge": force_purge, + }, ) return HTTPStatus.OK, {"delete_id": delete_id} +def _convert_delete_task_to_response(task: ScheduledTask) -> JsonDict: + return { + "delete_id": task.id, + "status": task.status, + "shutdown_room": task.result, + } + + class DeleteRoomStatusByRoomIdRestServlet(RestServlet): """Get the status of the delete room background task.""" @@ -150,21 +164,16 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) - delete_ids = self._pagination_handler.get_delete_ids_by_room(room_id) - if delete_ids is None: - raise NotFoundError("No delete task for room_id '%s' found" % room_id) + delete_tasks = await self._pagination_handler.get_delete_tasks_by_room(room_id) - response = [] - for delete_id in delete_ids: - delete = self._pagination_handler.get_delete_status(delete_id) - if delete: - response += [ - { - "delete_id": delete_id, - **delete.asdict(), - } - ] - return HTTPStatus.OK, {"results": cast(JsonDict, response)} + if delete_tasks: + return HTTPStatus.OK, { + "results": [ + _convert_delete_task_to_response(task) for task in delete_tasks + ], + } + else: + raise NotFoundError("No delete task for room_id '%s' found" % room_id) class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): @@ -181,11 +190,14 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) - delete_status = self._pagination_handler.get_delete_status(delete_id) - if delete_status is None: + delete_task = await self._pagination_handler.get_delete_task(delete_id) + if delete_task is None or ( + delete_task.action != PURGE_ROOM_ACTION_NAME + and delete_task.action != SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME + ): raise NotFoundError("delete id '%s' not found" % delete_id) - return HTTPStatus.OK, cast(JsonDict, delete_status.asdict()) + return HTTPStatus.OK, _convert_delete_task_to_response(delete_task) class ListRoomRestServlet(RestServlet): @@ -349,11 +361,15 @@ class RoomRestServlet(RestServlet): ret = await room_shutdown_handler.shutdown_room( room_id=room_id, - new_room_user_id=content.get("new_room_user_id"), - new_room_name=content.get("room_name"), - message=content.get("message"), - requester_user_id=requester.user.to_string(), - block=block, + params={ + "new_room_user_id": content.get("new_room_user_id"), + "new_room_name": content.get("room_name"), + "message": content.get("message"), + "requester_user_id": requester.user.to_string(), + "block": block, + "purge": purge, + "force_purge": force_purge, + }, ) # Purge room diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 932333ae57..5b743a1d03 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -28,6 +28,7 @@ from synapse.http.servlet import ( parse_integer, parse_json_object_from_request, parse_string, + parse_strings_from_args, ) from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( @@ -38,7 +39,7 @@ from synapse.rest.admin._base import ( from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.stats import UserSortOrder -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -64,6 +65,10 @@ class UsersRestServletV2(RestServlet): The parameter `guests` can be used to exclude guest users. The parameter `deactivated` can be used to include deactivated users. The parameter `order_by` can be used to order the result. + The parameter `not_user_type` can be used to exclude certain user types. + The parameter `locked` can be used to include locked users. + Possible values are `bot`, `support` or "empty string". + "empty string" here means to exclude users without a type. """ def __init__(self, hs: "HomeServer"): @@ -71,6 +76,7 @@ class UsersRestServletV2(RestServlet): self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self._msc3866_enabled = hs.config.experimental.msc3866.enabled + self._msc3861_enabled = hs.config.experimental.msc3861.enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -94,8 +100,18 @@ class UsersRestServletV2(RestServlet): user_id = parse_string(request, "user_id") name = parse_string(request, "name") + guests = parse_boolean(request, "guests", default=True) + if self._msc3861_enabled and guests: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The guests parameter is not supported when MSC3861 is enabled.", + errcode=Codes.INVALID_PARAM, + ) + deactivated = parse_boolean(request, "deactivated", default=False) + locked = parse_boolean(request, "locked", default=False) + admins = parse_boolean(request, "admins") # If support for MSC3866 is not enabled, apply no filtering based on the # `approved` column. @@ -118,11 +134,17 @@ class UsersRestServletV2(RestServlet): UserSortOrder.AVATAR_URL.value, UserSortOrder.SHADOW_BANNED.value, UserSortOrder.CREATION_TS.value, + UserSortOrder.LAST_SEEN_TS.value, + UserSortOrder.LOCKED.value, ), ) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + not_user_types = parse_strings_from_args(args, "not_user_type") + users, total = await self.store.get_users_paginate( start, limit, @@ -130,9 +152,12 @@ class UsersRestServletV2(RestServlet): name, guests, deactivated, + admins, order_by, direction, approved, + not_user_types, + locked, ) # If support for MSC3866 is not enabled, don't show the approval flag. @@ -190,7 +215,7 @@ class UserRestServletV2(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -205,7 +230,7 @@ class UserRestServletV2(RestServlet): async def on_PUT( self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester) @@ -263,6 +288,17 @@ class UserRestServletV2(RestServlet): HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" ) + lock = body.get("locked", False) + if not isinstance(lock, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean" + ) + + if deactivate and lock: + raise SynapseError( + HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked" + ) + approved: Optional[bool] = None if "approved" in body and self._msc3866_enabled: approved = body["approved"] @@ -380,6 +416,12 @@ class UserRestServletV2(RestServlet): target_user.to_string() ) + if "locked" in body: + if lock and not user["locked"]: + await self.store.set_user_locked_status(user_id, True) + elif not lock and user["locked"]: + await self.store.set_user_locked_status(user_id, False) + if "user_type" in body: await self.store.set_user_type(target_user, user_type) @@ -620,7 +662,7 @@ class WhoisRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -1135,14 +1177,17 @@ class RateLimitRestServlet(RestServlet): messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if type(messages_per_second) is not int or messages_per_second < 0: + if ( + type(messages_per_second) is not int # noqa: E721 + or messages_per_second < 0 + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if type(burst_count) is not int or burst_count < 0: + if type(burst_count) is not int or burst_count < 0: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index 5c1c19e1f3..73c568ef75 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection logger = logging.getLogger(__name__) def client_patterns( path_regex: str, - releases: Iterable[str] = ("r0", "v3"), + releases: StrCollection = ("r0", "v3"), unstable: bool = True, v1: bool = False, ) -> Iterable[Pattern]: diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 3d0c55daa0..e74a87af4d 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -18,7 +18,12 @@ import random from typing import TYPE_CHECKING, List, Optional, Tuple from urllib.parse import urlparse -from pydantic import StrictBool, StrictStr, constr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import StrictBool, StrictStr, constr +else: + from pydantic import StrictBool, StrictStr, constr from typing_extensions import Literal from twisted.web.server import Request @@ -27,6 +32,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, + NotFoundError, SynapseError, ThreepidValidationError, ) @@ -178,85 +184,81 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. - requester = None - if self.auth.has_access_token(request): - requester = await self.auth.get_user_by_req(request) - try: + try: + requester = None + if self.auth.has_access_token(request): + requester = await self.auth.get_user_by_req(request) params, session_id = await self.auth_handler.validate_user_via_ui_auth( requester, request, - body.dict(exclude_unset=True), + body.dict(exclude_unset=True, exclude={"new_password"}), "modify your account password", ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - new_password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - new_password_hash, - ) - raise - user_id = requester.user.to_string() - else: - try: + user_id = requester.user.to_string() + else: result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], request, - body.dict(exclude_unset=True), + body.dict(exclude_unset=True, exclude={"new_password"}), "modify your account password", ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - new_password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - new_password_hash, + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if "medium" not in threepid or "address" not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid["medium"] == "email": + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + try: + threepid["address"] = validate_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) + # if using email, we must know about the email they're authing with! + threepid_user_id = await self.datastore.get_user_id_by_threepid( + threepid["medium"], threepid["address"] ) + if not threepid_user_id: + raise SynapseError( + 404, "Email address not found", Codes.NOT_FOUND + ) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type! %r", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. We only do this if we don't already have the + # password hash stored, to avoid repeatedly hashing the password. + + if not new_password: raise - if LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if "medium" not in threepid or "address" not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid["medium"] == "email": - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - try: - threepid["address"] = validate_email(threepid["address"]) - except ValueError as e: - raise SynapseError(400, str(e)) - # if using email, we must know about the email they're authing with! - threepid_user_id = await self.datastore.get_user_id_by_threepid( - threepid["medium"], threepid["address"] - ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id - else: - logger.error("Auth succeeded but no known type! %r", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + existing_session_password_hash = await self.auth_handler.get_session_data( + e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None + ) + if existing_session_password_hash: + raise + + new_password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + new_password_hash, + ) + raise # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. if new_password: password_hash: Optional[str] = await self.auth_handler.hash(new_password) elif session_id is not None: - password_hash = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None - ) + password_hash = existing_session_password_hash else: # UI validation was skipped, but the request did not include a new # password. @@ -600,6 +602,9 @@ class ThreepidRestServlet(RestServlet): # ThreePidBindRestServelet.PostBody with an `alias_generator` to handle # `threePidCreds` versus `three_pid_creds`. async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + if self.hs.config.experimental.msc3861.enabled: + raise NotFoundError(errcode=Codes.UNRECOGNIZED) + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -890,19 +895,21 @@ class AccountStatusRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.worker.worker_app is None: - EmailPasswordRequestTokenRestServlet(hs).register(http_server) - PasswordRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) - EmailThreepidRequestTokenRestServlet(hs).register(http_server) - MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) - AddThreepidEmailSubmitTokenServlet(hs).register(http_server) - AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + EmailPasswordRequestTokenRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + PasswordRestServlet(hs).register(http_server) + EmailThreepidRequestTokenRestServlet(hs).register(http_server) + MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidEmailSubmitTokenServlet(hs).register(http_server) + AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) if hs.config.worker.worker_app is None: - ThreepidAddRestServlet(hs).register(http_server) ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) - ThreepidDeleteRestServlet(hs).register(http_server) + if not hs.config.experimental.msc3861.enabled: + ThreepidAddRestServlet(hs).register(http_server) + ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) if hs.config.worker.worker_app is None and hs.config.experimental.msc3720_enabled: diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index b1f9e9dc9b..ce0c4e7742 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -20,7 +20,7 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, RoomID +from synapse.types import JsonDict, JsonMapping, RoomID from ._base import client_patterns @@ -95,7 +95,7 @@ class AccountDataServlet(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str, account_data_type: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -106,7 +106,7 @@ class AccountDataServlet(RestServlet): and account_data_type == AccountDataTypes.PUSH_RULES ): account_data: Optional[ - JsonDict + JsonMapping ] = await self._push_rules_handler.push_rules_for_user(requester.user) else: account_data = await self.store.get_global_account_data_by_type_for_user( @@ -236,7 +236,7 @@ class RoomAccountDataServlet(RestServlet): user_id: str, room_id: str, account_data_type: str, - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -253,7 +253,7 @@ class RoomAccountDataServlet(RestServlet): self._hs.config.experimental.msc4010_push_rules_account_data and account_data_type == AccountDataTypes.PUSH_RULES ): - account_data: Optional[JsonDict] = {} + account_data: Optional[JsonMapping] = {} else: account_data = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 0dbf8f6818..3154b9f77e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -65,6 +65,9 @@ class CapabilitiesRestServlet(RestServlet): "m.3pid_changes": { "enabled": self.config.registration.enable_3pid_changes }, + "m.get_login_token": { + "enabled": self.config.auth.login_via_existing_enabled, + }, } } diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index e97d0bf475..80ae937921 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,17 +14,24 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple -from pydantic import Extra, StrictStr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Extra, StrictStr +else: + from pydantic import Extra, StrictStr from synapse.api import errors -from synapse.api.errors import NotFoundError +from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_and_validate_json_object_from_request, + parse_integer, ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns, interactive_auth_handler @@ -135,6 +142,7 @@ class DeviceRestServlet(RestServlet): self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled async def on_GET( self, request: SynapseRequest, device_id: str @@ -166,6 +174,9 @@ class DeviceRestServlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: + if self._msc3861_oauth_delegation_enabled: + raise UnrecognizedRequestError(code=404) + requester = await self.auth.get_user_by_req(request) try: @@ -225,6 +236,8 @@ class DehydratedDeviceDataModel(RequestBodyModel): class DehydratedDeviceServlet(RestServlet): """Retrieve or store a dehydrated device. + Implements MSC2697. + GET /org.matrix.msc2697.v2/dehydrated_device HTTP/1.1 200 OK @@ -257,7 +270,10 @@ class DehydratedDeviceServlet(RestServlet): """ - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=()) + PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device$", + releases=(), + ) def __init__(self, hs: "HomeServer"): super().__init__() @@ -289,6 +305,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), + None, submission.device_data.dict(), submission.initial_device_display_name, ) @@ -343,11 +360,216 @@ class ClaimDehydratedDeviceServlet(RestServlet): return 200, result +class DehydratedDeviceEventsServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3814.v1/dehydrated_device/(?P<device_id>[^/]*)/events$", + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.message_handler = hs.get_device_message_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + + class PostBody(RequestBodyModel): + next_batch: Optional[StrictStr] + + async def on_POST( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + next_batch = parse_and_validate_json_object_from_request( + request, self.PostBody + ).next_batch + limit = parse_integer(request, "limit", 100) + + msgs = await self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=device_id, + since_token=next_batch, + limit=limit, + ) + + return 200, msgs + + +class DehydratedDeviceV2Servlet(RestServlet): + """Upload, retrieve, or delete a dehydrated device. + + GET /org.matrix.msc3814.v1/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + PUT /org.matrix.msc3814.v1/dehydrated_device + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + }, + "device_keys": { + "user_id": "<user_id>", + "device_id": "<device_id>", + "valid_until_ts": <millisecond_timestamp>, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ] + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures:" { + "<user_id>" { + "<algorithm>:<device_id>": "<signature_base64>" + } + } + }, + "fallback_keys": { + "<algorithm>:<device_id>": "<key_base64>", + "signed_<algorithm>:<device_id>": { + "fallback": true, + "key": "<key_base64>", + "signatures": { + "<user_id>": { + "<algorithm>:<device_id>": "<key_base64>" + } + } + } + } + "one_time_keys": { + "<algorithm>:<key_id>": "<key_base64>" + }, + + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + DELETE /org.matrix.msc3814.v1/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + } + """ + + PATTERNS = [ + *client_patterns("/org.matrix.msc3814.v1/dehydrated_device$", releases=()), + ] + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = handler + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + result = {"device_id": device_id, "device_data": device_data} + return 200, result + else: + raise errors.NotFoundError("No dehydrated device available") + + async def on_DELETE(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + + await self.device_handler.delete_dehydrated_device( + requester.user.to_string(), device_id + ) + + result = {"device_id": device_id} + + return 200, result + else: + raise errors.NotFoundError("No dehydrated device available") + + class PutBody(RequestBodyModel): + device_data: DehydratedDeviceDataModel + device_id: StrictStr + initial_device_display_name: Optional[StrictStr] + + class Config: + extra = Extra.allow + + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + submission = parse_and_validate_json_object_from_request(request, self.PutBody) + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + old_dehydrated_device = await self.device_handler.get_dehydrated_device(user_id) + + # if an old device exists, delete it before creating a new one + if old_dehydrated_device: + await self.device_handler.delete_dehydrated_device( + user_id, old_dehydrated_device[0] + ) + + device_info = submission.dict() + if "device_keys" not in device_info.keys(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Device key(s) not found, these must be provided.", + ) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission.device_id, + submission.device_data.dict(), + submission.initial_device_display_name, + device_info, + ) + + return 200, {"device_id": device_id} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.worker.worker_app is None: + if ( + hs.config.worker.worker_app is None + and not hs.config.experimental.msc3861.enabled + ): DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: DeviceRestServlet(hs).register(http_server) - DehydratedDeviceServlet(hs).register(http_server) - ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc2697_enabled: + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc3814_enabled: + DehydratedDeviceV2Servlet(hs).register(http_server) + DehydratedDeviceEventsServlet(hs).register(http_server) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index 570bb52747..82944ca711 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple -from pydantic import StrictStr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import StrictStr +else: + from pydantic import StrictStr + from typing_extensions import Literal from twisted.web.server import Request diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 04561f36d7..b5879496db 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from ._base import client_patterns, set_timeline_upper_limit @@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str, filter_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet): try: filter_collection = await self.filtering.get_user_filter( - user_localpart=target_user.localpart, filter_id=filter_id_int + user_id=target_user, filter_id=filter_id_int ) except StoreError as e: if e.code != 404: diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 9bbab5e624..70b8be1aa2 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -17,9 +17,10 @@ import logging import re from collections import Counter +from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple -from synapse.api.errors import InvalidAPICallError, SynapseError +from synapse.api.errors import Codes, InvalidAPICallError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -287,7 +288,7 @@ class OneTimeKeyServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -298,7 +299,7 @@ class OneTimeKeyServlet(RestServlet): query.setdefault(user_id, {})[device_id] = {algorithm: 1} result = await self.e2e_keys_handler.claim_one_time_keys( - query, timeout, always_include_fallback_keys=False + query, requester.user, timeout, always_include_fallback_keys=False ) return 200, result @@ -335,7 +336,7 @@ class UnstableOneTimeKeyServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -346,7 +347,7 @@ class UnstableOneTimeKeyServlet(RestServlet): query.setdefault(user_id, {})[device_id] = Counter(algorithms) result = await self.e2e_keys_handler.claim_one_time_keys( - query, timeout, always_include_fallback_keys=True + query, requester.user, timeout, always_include_fallback_keys=True ) return 200, result @@ -375,9 +376,29 @@ class SigningKeyUploadServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - if self.hs.config.experimental.msc3967_enabled: - if await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id): - # If we already have a master key then cross signing is set up and we require UIA to reset + is_cross_signing_setup = ( + await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id) + ) + + # Before MSC3967 we required UIA both when setting up cross signing for the + # first time and when resetting the device signing key. With MSC3967 we only + # require UIA when resetting cross-signing, and not when setting up the first + # time. Because there is no UIA in MSC3861, for now we throw an error if the + # user tries to reset the device signing key when MSC3861 is enabled, but allow + # first-time setup. + if self.hs.config.experimental.msc3861.enabled: + # There is no way to reset the device signing key with MSC3861 + if is_cross_signing_setup: + raise SynapseError( + HTTPStatus.NOT_IMPLEMENTED, + "Resetting cross signing keys is not yet supported with MSC3861", + Codes.UNRECOGNIZED, + ) + # But first-time setup is fine + + elif self.hs.config.experimental.msc3967_enabled: + # If we already have a master key then cross signing is set up and we require UIA to reset + if is_cross_signing_setup: await self.auth_handler.validate_user_via_ui_auth( requester, request, @@ -387,6 +408,7 @@ class SigningKeyUploadServlet(RestServlet): can_skip_ui_auth=False, ) # Otherwise we don't require UIA since we are setting up cross signing for first time + else: # Previous behaviour is to always require UIA but allow it to be skipped await self.auth_handler.validate_user_via_ui_auth( diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index afdbf821b5..7be327e26f 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -35,6 +35,7 @@ from synapse.api.errors import ( LoginError, NotApprovedError, SynapseError, + UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.api.urls import CLIENT_API_PREFIX @@ -49,7 +50,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.http.site import SynapseRequest +from synapse.http.site import RequestInfo, SynapseRequest from synapse.rest.client._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID @@ -84,6 +85,7 @@ class LoginRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self._main_store = hs.get_datastores().main # JWT configuration variables. self.jwt_enabled = hs.config.jwt.jwt_enabled @@ -102,6 +104,9 @@ class LoginRestServlet(RestServlet): and hs.config.experimental.msc3866.require_approval_for_new_accounts ) + # Whether get login token is enabled. + self._get_login_token_enabled = hs.config.auth.login_via_existing_enabled + self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -109,19 +114,18 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self._sso_handler = hs.get_sso_handler() + self._spam_checker = hs.get_module_api_callbacks().spam_checker self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( - store=hs.get_datastores().main, + store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_address, ) self._account_ratelimiter = Ratelimiter( - store=hs.get_datastores().main, + store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_account, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. @@ -140,6 +144,9 @@ class LoginRestServlet(RestServlet): # to SSO. flows.append({"type": LoginRestServlet.CAS_TYPE}) + # The login token flow requires m.login.token to be advertised. + support_login_token_flow = self._get_login_token_enabled + if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: flows.append( { @@ -151,14 +158,23 @@ class LoginRestServlet(RestServlet): } ) - # While it's valid for us to advertise this login type generally, - # synapse currently only gives out these tokens as part of the - # SSO login flow. - # Generally we don't want to advertise login flows that clients - # don't know how to implement, since they (currently) will always - # fall back to the fallback API if they don't understand one of the - # login flow types returned. - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + # SSO requires a login token to be generated, so we need to advertise that flow + support_login_token_flow = True + + # While it's valid for us to advertise this login type generally, + # synapse currently only gives out these tokens as part of the + # SSO login flow or as part of login via an existing session. + # + # Generally we don't want to advertise login flows that clients + # don't know how to implement, since they (currently) will always + # fall back to the fallback API if they don't understand one of the + # login flow types returned. + if support_login_token_flow: + tokenTypeFlow: Dict[str, Any] = {"type": LoginRestServlet.TOKEN_TYPE} + # If the login token flow is enabled advertise the get_login_token flag. + if self._get_login_token_enabled: + tokenTypeFlow["get_login_token"] = True + flows.append(tokenTypeFlow) flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) @@ -180,6 +196,8 @@ class LoginRestServlet(RestServlet): self._refresh_tokens_enabled and client_requested_refresh_token ) + request_info = request.request_info() + try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: requester = await self.auth.get_user_by_req(request) @@ -199,6 +217,7 @@ class LoginRestServlet(RestServlet): login_submission, appservice, should_issue_refresh_token=should_issue_refresh_token, + request_info=request_info, ) elif ( self.jwt_enabled @@ -210,6 +229,7 @@ class LoginRestServlet(RestServlet): result = await self._do_jwt_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, + request_info=request_info, ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: await self._address_ratelimiter.ratelimit( @@ -218,6 +238,7 @@ class LoginRestServlet(RestServlet): result = await self._do_token_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, + request_info=request_info, ) else: await self._address_ratelimiter.ratelimit( @@ -226,6 +247,7 @@ class LoginRestServlet(RestServlet): result = await self._do_other_login( login_submission, should_issue_refresh_token=should_issue_refresh_token, + request_info=request_info, ) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -248,6 +270,8 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, appservice: ApplicationService, should_issue_refresh_token: bool = False, + *, + request_info: RequestInfo, ) -> LoginResponse: identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -280,10 +304,18 @@ class LoginRestServlet(RestServlet): login_submission, ratelimit=appservice.is_rate_limited(), should_issue_refresh_token=should_issue_refresh_token, + # The user represented by an appservice's configured sender_localpart + # is not actually created in Synapse. + should_check_deactivated=qualified_user_id != appservice.sender, + request_info=request_info, ) async def _do_other_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False + self, + login_submission: JsonDict, + should_issue_refresh_token: bool = False, + *, + request_info: RequestInfo, ) -> LoginResponse: """Handle non-token/saml/jwt logins @@ -313,6 +345,7 @@ class LoginRestServlet(RestServlet): login_submission, callback, should_issue_refresh_token=should_issue_refresh_token, + request_info=request_info, ) return result @@ -326,6 +359,9 @@ class LoginRestServlet(RestServlet): auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, auth_provider_session_id: Optional[str] = None, + should_check_deactivated: bool = True, + *, + request_info: RequestInfo, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -345,6 +381,12 @@ class LoginRestServlet(RestServlet): should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. auth_provider_session_id: The session ID got during login from the SSO IdP. + should_check_deactivated: True if the user should be checked for + deactivation status before logging in. + + This exists purely for appservice's configured sender_localpart + which doesn't have an associated user in the database. + request_info: The user agent/IP address of the user. Returns: Dictionary of account information after successful login. @@ -364,6 +406,12 @@ class LoginRestServlet(RestServlet): ) user_id = canonical_uid + # If the account has been deactivated, do not proceed with the login. + if should_check_deactivated: + deactivated = await self._main_store.get_user_deactivated_status(user_id) + if deactivated: + raise UserDeactivatedError("This account has been deactivated") + device_id = login_submission.get("device_id") # If device_id is present, check that device_id is not longer than a reasonable 512 characters @@ -385,6 +433,22 @@ class LoginRestServlet(RestServlet): ) initial_display_name = login_submission.get("initial_device_display_name") + spam_check = await self._spam_checker.check_login_for_spam( + user_id, + device_id=device_id, + initial_display_name=initial_display_name, + request_info=[(request_info.user_agent, request_info.ip)], + auth_provider_id=auth_provider_id, + ) + if spam_check != self._spam_checker.NOT_SPAM: + logger.info("Blocking login due to spam checker") + raise SynapseError( + 403, + msg="Login was blocked by the server", + errcode=spam_check[0], + additional_fields=spam_check[1], + ) + ( device_id, access_token, @@ -419,7 +483,11 @@ class LoginRestServlet(RestServlet): return result async def _do_token_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False + self, + login_submission: JsonDict, + should_issue_refresh_token: bool = False, + *, + request_info: RequestInfo, ) -> LoginResponse: """ Handle token login. @@ -442,10 +510,15 @@ class LoginRestServlet(RestServlet): auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, auth_provider_session_id=res.auth_provider_session_id, + request_info=request_info, ) async def _do_jwt_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False + self, + login_submission: JsonDict, + should_issue_refresh_token: bool = False, + *, + request_info: RequestInfo, ) -> LoginResponse: """ Handle the custom JWT login. @@ -458,12 +531,13 @@ class LoginRestServlet(RestServlet): Returns: The body of the JSON response. """ - user_id = await self.hs.get_jwt_handler().validate_login(login_submission) + user_id = self.hs.get_jwt_handler().validate_login(login_submission) return await self._complete_login( user_id, login_submission, create_non_existent_users=True, should_issue_refresh_token=should_issue_refresh_token, + request_info=request_info, ) @@ -616,6 +690,9 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3861.enabled: + return + LoginRestServlet(hs).register(http_server) if ( hs.config.worker.worker_app is None diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index 43ea21d5e6..d189a923b5 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -15,6 +15,8 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -33,7 +35,7 @@ class LoginTokenRequestServlet(RestServlet): Request: - POST /login/token HTTP/1.1 + POST /login/get_token HTTP/1.1 Content-Type: application/json {} @@ -43,30 +45,48 @@ class LoginTokenRequestServlet(RestServlet): HTTP/1.1 200 OK { "login_token": "ABDEFGH", - "expires_in": 3600, + "expires_in_ms": 3600000, } """ - PATTERNS = client_patterns( - "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True - ) + PATTERNS = [ + *client_patterns( + "/login/get_token$", releases=["v1"], v1=False, unstable=False + ), + # TODO: this is no longer needed once unstable MSC3882 does not need to be supported: + *client_patterns( + "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True + ), + ] def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self.server_name = hs.config.server.server_name + self._main_store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() - self.token_timeout = hs.config.experimental.msc3882_token_timeout - self.ui_auth = hs.config.experimental.msc3882_ui_auth + self.token_timeout = hs.config.auth.login_via_existing_token_timeout + self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth + + # Ratelimit aggressively to a maximum of 1 request per minute. + # + # This endpoint can be used to spawn additional sessions and could be + # abused by a malicious client to create many sessions. + self._ratelimiter = Ratelimiter( + store=self._main_store, + clock=hs.get_clock(), + cfg=RatelimitSettings( + key="<login token request>", + per_second=1 / 60, + burst_count=1, + ), + ) @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) - if self.ui_auth: + if self._require_ui_auth: await self.auth_handler.validate_user_via_ui_auth( requester, request, @@ -75,9 +95,12 @@ class LoginTokenRequestServlet(RestServlet): can_skip_ui_auth=False, # Don't allow skipping of UI auth ) + # Ensure that this endpoint isn't being used too often. (Ensure this is + # done *after* UI auth.) + await self._ratelimiter.ratelimit(None, requester.user.to_string().lower()) + login_token = await self.auth_handler.create_login_token_for_user_id( user_id=requester.user.to_string(), - auth_provider_id="org.matrix.msc3882.login_token_request", duration_ms=self.token_timeout, ) @@ -85,11 +108,13 @@ class LoginTokenRequestServlet(RestServlet): 200, { "login_token": login_token, + # TODO: this is no longer needed once unstable MSC3882 does not need to be supported: "expires_in": self.token_timeout // 1000, + "expires_in_ms": self.token_timeout, }, ) def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.experimental.msc3882_enabled: + if hs.config.auth.login_via_existing_enabled: LoginTokenRequestServlet(hs).register(http_server) diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 6d34625ad5..2e104d4888 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -40,7 +40,9 @@ class LogoutRestServlet(RestServlet): self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req( + request, allow_expired=True, allow_locked=True + ) if requester.device_id is None: # The access token wasn't associated with a device. @@ -67,7 +69,9 @@ class LogoutAllRestServlet(RestServlet): self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req( + request, allow_expired=True, allow_locked=True + ) user_id = requester.user.to_string() # first delete all of the user's devices @@ -80,5 +84,8 @@ class LogoutAllRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3861.enabled: + return + LogoutRestServlet(hs).register(http_server) LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py index 3d7940b0fc..880f79473c 100644 --- a/synapse/rest/client/models.py +++ b/synapse/rest/client/models.py @@ -13,7 +13,12 @@ # limitations under the License. from typing import TYPE_CHECKING, Dict, Optional -from pydantic import Extra, StrictInt, StrictStr, constr, validator +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Extra, StrictInt, StrictStr, constr, validator +else: + from pydantic import Extra, StrictInt, StrictStr, constr, validator from synapse.rest.models import RequestBodyModel from synapse.util.threepids import validate_email diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index ea10042569..e7fe1332e7 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -36,6 +36,8 @@ logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") + CATEGORY = "Client API requests" + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastores().main diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index 8e193330f8..d578faa969 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -97,7 +97,7 @@ class PresenceStatusRestServlet(RestServlet): raise SynapseError(400, "Unable to parse state") if self._use_presence: - await self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, requester.device_id, state) return 200, {} diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 5c9fece3ba..5ed3b83a03 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -32,6 +32,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.types import JsonDict +from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: from synapse.server import HomeServer @@ -53,26 +54,32 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None self._push_rules_handler = hs.get_push_rules_handler() + self._push_rule_linearizer = Linearizer(name="push_rules") async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + async with self._push_rule_linearizer.queue(user_id): + return await self.handle_put(request, path, user_id) + + async def handle_put( + self, request: SynapseRequest, path: str, user_id: str + ) -> Tuple[int, JsonDict]: spec = _rule_spec_from_path(path.split("/")) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = await self.auth.get_user_by_req(request) - if "/" in spec.rule_id or "\\" in spec.rule_id: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) - user_id = requester.user.to_string() - if spec.attr: try: await self._push_rules_handler.set_rule_attr(user_id, spec, content) @@ -126,11 +133,20 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") - spec = _rule_spec_from_path(path.split("/")) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() + async with self._push_rule_linearizer.queue(user_id): + return await self.handle_delete(request, path, user_id) + + async def handle_delete( + self, + request: SynapseRequest, + path: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + spec = _rule_spec_from_path(path.split("/")) + namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}" try: diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 4f96e51eeb..15e4d56cdb 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -52,7 +52,9 @@ class ReadMarkerRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) body = parse_json_object_from_request(request) @@ -82,7 +84,7 @@ class ReadMarkerRestServlet(RestServlet): await self.receipts_handler.received_client_receipt( room_id, receipt_type, - user_id=requester.user.to_string(), + user_id=requester.user, event_id=event_id, # Setting the thread ID is not possible with the /read_markers endpoint. thread_id=None, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 316e7b9982..814d075faf 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -94,7 +94,9 @@ class ReceiptRestServlet(RestServlet): Codes.INVALID_PARAM, ) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) if receipt_type == ReceiptTypes.FULLY_READ: await self.read_marker_handler.received_client_read_marker( @@ -106,7 +108,7 @@ class ReceiptRestServlet(RestServlet): await self.receipts_handler.received_client_receipt( room_id, receipt_type, - user_id=requester.user.to_string(), + user_id=requester.user, event_id=event_id, thread_id=thread_id, ) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 7f84a17e29..132623462a 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, - burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + cfg=hs.config.ratelimiting.rc_registration_token_validity, ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: @@ -462,9 +461,9 @@ class RegisterRestServlet(RestServlet): # the auth layer will store these in sessions. desired_username = None if "username" in body: - if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") desired_username = body["username"] + if not isinstance(desired_username, str) or len(desired_username) > 512: + raise SynapseError(400, "Invalid username") # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -477,11 +476,6 @@ class RegisterRestServlet(RestServlet): "Appservice token must be provided when using a type of m.login.application_service", ) - # Set the desired user according to the AS API (which uses the - # 'user' key not 'username'). Since this is a new addition, we'll - # fallback to 'username' if they gave one. - desired_username = body.get("user", desired_username) - # XXX we should check that desired_username is valid. Currently # we give appservices carte blanche for any insanity in mxids, # because the IRC bridges rely on being able to register stupid @@ -489,7 +483,8 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if not isinstance(desired_username, str): + # Desired username is either a string or None. + if desired_username is None: raise SynapseError(400, "Desired Username is missing or not a string") result = await self._do_appservice_registration( @@ -869,6 +864,74 @@ class RegisterRestServlet(RestServlet): return 200, result +class RegisterAppServiceOnlyRestServlet(RestServlet): + """An alternative registration API endpoint that only allows ASes to register + + This replaces the regular /register endpoint if MSC3861. There are two notable + differences with the regular /register endpoint: + - It only allows the `m.login.application_service` login type + - It does not create a device or access token for the just-registered user + + Note that the exact behaviour of this endpoint is not yet finalised. It should be + just good enough to make most ASes work. + """ + + PATTERNS = client_patterns("/register$") + CATEGORY = "Registration/login requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + + self.auth = hs.get_auth() + self.registration_handler = hs.get_registration_handler() + self.ratelimiter = hs.get_registration_ratelimiter() + + @interactive_auth_handler + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + body = parse_json_object_from_request(request) + + client_addr = request.getClientAddress().host + + await self.ratelimiter.ratelimit(None, client_addr, update=False) + + kind = parse_string(request, "kind", default="user") + + if kind == "guest": + raise SynapseError(403, "Guest access is disabled") + elif kind != "user": + raise UnrecognizedRequestError( + f"Do not understand membership kind: {kind}", + ) + + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. + desired_username = body.get("username") + if not isinstance(desired_username, str) or len(desired_username) > 512: + raise SynapseError(400, "Invalid username") + + # Allow only ASes to use this API. + if body.get("type") != APP_SERVICE_REGISTRATION_TYPE: + raise SynapseError(403, "Non-application service registration type") + + if not self.auth.has_access_token(request): + raise SynapseError( + 400, + "Appservice token must be provided when using a type of m.login.application_service", + ) + + # XXX we should check that desired_username is valid. Currently + # we give appservices carte blanche for any insanity in mxids, + # because the IRC bridges rely on being able to register stupid + # IDs. + + as_token = self.auth.get_access_token_from_request(request) + + user_id = await self.registration_handler.appservice_register( + desired_username, as_token + ) + return 200, {"user_id": user_id} + + def _calculate_registration_flows( config: HomeServerConfig, auth_handler: AuthHandler ) -> List[List[str]]: @@ -955,6 +1018,10 @@ def _calculate_registration_flows( def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3861.enabled: + RegisterAppServiceOnlyRestServlet(hs).register(http_server) + return + if hs.config.worker.worker_app is None: EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index ac1a63ca27..ee93e459f6 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -55,7 +55,7 @@ class ReportEventRestServlet(RestServlet): "Param 'reason' must be a string", Codes.BAD_JSON, ) - if type(body.get("score", 0)) is not int: + if type(body.get("score", 0)) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index fb0ca69b9b..99ad387b50 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1125,7 +1125,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): # Ensure the redacts property in the content matches the one provided in # the URL. room_version = await self._store.get_room_version(room_id) - if room_version.msc2176_redaction_rules: + if room_version.updated_redaction_rules: if "redacts" in content and content["redacts"] != event_id: raise SynapseError( 400, @@ -1159,7 +1159,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } # Earlier room versions had a top-level redacts property. - if not room_version.msc2176_redaction_rules: + if not room_version.updated_redaction_rules: event_dict["redacts"] = event_id ( @@ -1237,7 +1237,9 @@ class RoomTypingRestServlet(RestServlet): content = parse_json_object_from_request(request) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py deleted file mode 100644 index 69f85112d8..0000000000 --- a/synapse/rest/client/room_batch.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple - -from synapse.api.constants import EventContentFields -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, - parse_string, - parse_strings_from_args, -) -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class RoomBatchSendEventRestServlet(RestServlet): - """ - API endpoint which can insert a batch of events historically back in time - next to the given `prev_event`. - - `batch_id` comes from `next_batch_id `in the response of the batch send - endpoint and is derived from the "insertion" events added to each batch. - It's not required for the first batch send. - - `state_events_at_start` is used to define the historical state events - needed to auth the events like join events. These events will float - outside of the normal DAG as outlier's and won't be visible in the chat - history which also allows us to insert multiple batches without having a bunch - of `@mxid joined the room` noise between each batch. - - `events` is chronological list of events you want to insert. - There is a reverse-chronological constraint on batches so once you insert - some messages, you can only insert older ones after that. - tldr; Insert batches from your most recent history -> oldest history. - - POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event_id=<eventID>&batch_id=<batchID> - { - "events": [ ... ], - "state_events_at_start": [ ... ] - } - """ - - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2716" - "/rooms/(?P<room_id>[^/]*)/batch_send$" - ), - ) - CATEGORY = "Client API requests" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.store = hs.get_datastores().main - self.event_creation_handler = hs.get_event_creation_handler() - self.auth = hs.get_auth() - self.room_batch_handler = hs.get_room_batch_handler() - - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - if not requester.app_service: - raise AuthError( - HTTPStatus.FORBIDDEN, - "Only application services can use the /batchsend endpoint", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["state_events_at_start", "events"]) - - assert request.args is not None - prev_event_ids_from_query = parse_strings_from_args( - request.args, "prev_event_id" - ) - batch_id_from_query = parse_string(request, "batch_id") - - if prev_event_ids_from_query is None: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "prev_event query parameter is required when inserting historical messages back in time", - errcode=Codes.MISSING_PARAM, - ) - - if await self.store.is_partial_state_room(room_id): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Cannot insert history batches until we have fully joined the room", - errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE, - ) - - # Verify the batch_id_from_query corresponds to an actual insertion event - # and have the batch connected. - if batch_id_from_query: - corresponding_insertion_event_id = ( - await self.store.get_insertion_event_id_by_batch_id( - room_id, batch_id_from_query - ) - ) - if corresponding_insertion_event_id is None: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "No insertion event corresponds to the given ?batch_id", - errcode=Codes.INVALID_PARAM, - ) - - # Make sure that the prev_event_ids exist and aren't outliers - ie, they are - # regular parts of the room DAG where we know the state. - non_outlier_prev_events = await self.store.have_events_in_timeline( - prev_event_ids_from_query - ) - for prev_event_id in prev_event_ids_from_query: - if prev_event_id not in non_outlier_prev_events: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "prev_event %s does not exist, or is an outlier" % (prev_event_id,), - errcode=Codes.INVALID_PARAM, - ) - - # For the event we are inserting next to (`prev_event_ids_from_query`), - # find the most recent state events that allowed that message to be - # sent. We will use that as a base to auth our historical messages - # against. - state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list( - prev_event_ids_from_query - ) - - state_event_ids_at_start = [] - # Create and persist all of the state events that float off on their own - # before the batch. These will most likely be all of the invite/member - # state events used to auth the upcoming historical messages. - if body["state_events_at_start"]: - state_event_ids_at_start = ( - await self.room_batch_handler.persist_state_events_at_start( - state_events_at_start=body["state_events_at_start"], - room_id=room_id, - initial_state_event_ids=state_event_ids, - app_service_requester=requester, - ) - ) - # Update our ongoing auth event ID list with all of the new state we - # just created - state_event_ids.extend(state_event_ids_at_start) - - inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids( - prev_event_ids_from_query - ) - - events_to_create = body["events"] - - # Figure out which batch to connect to. If they passed in - # batch_id_from_query let's use it. The batch ID passed in comes - # from the batch_id in the "insertion" event from the previous batch. - last_event_in_batch = events_to_create[-1] - base_insertion_event = None - if batch_id_from_query: - batch_id_to_connect_to = batch_id_from_query - # Otherwise, create an insertion event to act as a starting point. - # - # We don't always have an insertion event to start hanging more history - # off of (ideally there would be one in the main DAG, but that's not the - # case if we're wanting to add history to e.g. existing rooms without - # an insertion event), in which case we just create a new insertion event - # that can then get pointed to by a "marker" event later. - else: - base_insertion_event_dict = ( - self.room_batch_handler.create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_batch["origin_server_ts"], - ) - ) - base_insertion_event_dict["prev_events"] = prev_event_ids_from_query.copy() - - ( - base_insertion_event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self.room_batch_handler.create_requester_for_user_id_from_app_service( - base_insertion_event_dict["sender"], - requester.app_service, - ), - base_insertion_event_dict, - prev_event_ids=base_insertion_event_dict.get("prev_events"), - # Also set the explicit state here because we want to resolve - # any `state_events_at_start` here too. It's not strictly - # necessary to accomplish anything but if someone asks for the - # state at this point, we probably want to show them the - # historical state that was part of this batch. - state_event_ids=state_event_ids, - historical=True, - depth=inherited_depth, - ) - - batch_id_to_connect_to = base_insertion_event.content[ - EventContentFields.MSC2716_NEXT_BATCH_ID - ] - - # Create and persist all of the historical events as well as insertion - # and batch meta events to make the batch navigable in the DAG. - event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events( - events_to_create=events_to_create, - room_id=room_id, - batch_id_to_connect_to=batch_id_to_connect_to, - inherited_depth=inherited_depth, - initial_state_event_ids=state_event_ids, - app_service_requester=requester, - ) - - insertion_event_id = event_ids[0] - batch_event_id = event_ids[-1] - historical_event_ids = event_ids[1:-1] - - response_dict = { - "state_event_ids": state_event_ids_at_start, - "event_ids": historical_event_ids, - "next_batch_id": next_batch_id, - "insertion_event_id": insertion_event_id, - "batch_event_id": batch_event_id, - } - if base_insertion_event is not None: - response_dict["base_insertion_event_id"] = base_insertion_event.event_id - - return HTTPStatus.OK, response_dict - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - msc2716_enabled = hs.config.experimental.msc2716_enabled - - if msc2716_enabled: - RoomBatchSendEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py index 6a7792e18b..b1f6b5d1b7 100644 --- a/synapse/rest/client/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet): self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() + self._worker_lock_handler = hs.get_worker_locks_handler() async def on_POST( self, request: SynapseRequest, room_id: str @@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet): ) try: - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) + async with self._worker_lock_handler.acquire_read_write_lock( + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False + ): + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) except ShadowBanError: # Generate a random room ID. new_room_id = stringutils.random_string(18) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 03b0578945..42bdd3bb10 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet): else: try: filter_collection = await self.filtering.get_user_filter( - user.localpart, filter_id + user, filter_id ) except StoreError as err: if err.code != 404: @@ -205,6 +205,7 @@ class SyncRestServlet(RestServlet): context = await self.presence_handler.user_syncing( user.to_string(), + requester.device_id, affect_presence=affect_presence, presence_state=set_presence, ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 0d8a63d8be..3d814c404d 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -50,8 +50,6 @@ class HttpTransactionCache: # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) - self._msc3970_enabled = hs.config.experimental.msc3970_enabled - def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable: """A helper function which returns a transaction key that can be used with TransactionCache for idempotent requests. @@ -78,18 +76,20 @@ class HttpTransactionCache: elif requester.app_service is not None: return (path, "appservice", requester.app_service.id) - # With MSC3970, we use the user ID and device ID as the transaction key - elif self._msc3970_enabled: + # Use the user ID and device ID as the transaction key. + elif requester.device_id: assert requester.user, "Requester must have a user" assert requester.device_id, "Requester must have a device_id" return (path, "user", requester.user, requester.device_id) - # Otherwise, the pre-MSC3970 behaviour is to use the access token ID + # Some requsters don't have device IDs, these are mostly handled above + # (appservice and guest users), but does not cover access tokens minted + # by the admin API. Use the access token ID instead. else: assert ( requester.access_token_id is not None ), "Requester must have an access_token_id" - return (path, "user", requester.access_token_id) + return (path, "user_admin", requester.access_token_id) def fetch_or_execute_request( self, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 32df054f56..95400ba570 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -102,8 +102,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, - # Adds support for importing historical messages as per MSC2716 - "org.matrix.msc2716": self.config.experimental.msc2716_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Support for thread read receipts & notification counts. @@ -113,8 +111,8 @@ class VersionsRestServlet(RestServlet): "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds a ping endpoint for appservices to check HS->AS connection "fi.mau.msc2659.stable": True, # TODO: remove when "v1.7" is added above - # Adds support for login token requests as per MSC3882 - "org.matrix.msc3882": self.config.experimental.msc3882_enabled, + # TODO: this is no longer needed once unstable MSC3882 does not need to be supported: + "org.matrix.msc3882": self.config.auth.login_via_existing_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, # Adds support for filtering /messages by event relation. @@ -124,8 +122,6 @@ class VersionsRestServlet(RestServlet): is not None, # Adds support for relation-based redactions as per MSC3912. "org.matrix.msc3912": self.config.experimental.msc3912_enabled, - # Adds support for unstable "intentional mentions" behaviour. - "org.matrix.msc3952_intentional_mentions": self.config.experimental.msc3952_intentional_mentions, # Whether recursively provide relations is supported. "org.matrix.msc3981": self.config.experimental.msc3981_recurse_relations, # Adds support for deleting account data. diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 25f9ea285b..88d3ec1baf 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource): if u is None: raise NotFoundError("Unknown user") - has_consented = u["consent_version"] == version + has_consented = u.consent_version == version userhmac = userhmac_bytes.decode("ascii") try: diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 8f3865d412..48c47058db 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -14,7 +14,14 @@ import logging import re -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Extra, StrictInt, StrictStr +else: + from pydantic import StrictInt, StrictStr, Extra from signedjson.sign import sign_json @@ -24,9 +31,11 @@ from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, + parse_and_validate_json_object_from_request, parse_integer, - parse_json_object_from_request, ) +from synapse.rest.models import RequestBodyModel +from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -37,6 +46,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _KeyQueryCriteriaDataModel(RequestBodyModel): + class Config: + extra = Extra.allow + + minimum_valid_until_ts: Optional[StrictInt] + + class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported @@ -95,6 +111,9 @@ class RemoteKey(RestServlet): CATEGORY = "Federation requests" + class PostBody(RequestBodyModel): + server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]] + def __init__(self, hs: "HomeServer"): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main @@ -136,35 +155,48 @@ class RemoteKey(RestServlet): ) minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") - arguments = {} - if minimum_valid_until_ts is not None: - arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = { + server: { + key_id: _KeyQueryCriteriaDataModel( + minimum_valid_until_ts=minimum_valid_until_ts + ) + } + } else: query = {server: {}} return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) + content = parse_and_validate_json_object_from_request(request, self.PostBody) - query = content["server_keys"] + query = content.server_keys return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, query: JsonDict, query_remote_on_cache_miss: bool = False + self, + query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]], + query_remote_on_cache_miss: bool = False, ) -> JsonDict: logger.info("Handling query for keys %r", query) - store_queries = [] + server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): - if not key_ids: - key_ids = (None,) - for key_id in key_ids: - store_queries.append((server_name, key_id, None)) + if key_ids: + results: Mapping[ + str, Optional[FetchKeyResultForRemote] + ] = await self.store.get_server_keys_json_for_remote( + server_name, key_ids + ) + else: + results = await self.store.get_all_server_keys_json_for_remote( + server_name + ) - cached = await self.store.get_server_keys_json_for_remote(store_queries) + server_keys.update( + ((server_name, key_id), res) for key_id, res in results.items() + ) json_results: Set[bytes] = set() @@ -173,25 +205,24 @@ class RemoteKey(RestServlet): # Map server_name->key_id->int. Note that the value of the int is unused. # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} - for (server_name, key_id, _), key_results in cached.items(): - results = [(result["ts_added_ms"], result) for result in key_results] - - if key_id is None: + for (server_name, key_id), key_result in server_keys.items(): + if not query[server_name]: # all keys were requested. Just return what we have without worrying # about validity - for _, result in results: - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(result["key_json"])) + if key_result: + json_results.add(key_result.key_json) continue miss = False - if not results: + if key_result is None: miss = True else: - ts_added_ms, most_recent_result = max(results) - ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] - req_key = query.get(server_name, {}).get(key_id, {}) - req_valid_until = req_key.get("minimum_valid_until_ts") + ts_added_ms = key_result.added_ts + ts_valid_until_ms = key_result.valid_until_ts + req_key = query.get(server_name, {}).get( + key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None) + ) + req_valid_until = req_key.minimum_valid_until_ts if req_valid_until is not None: if ts_valid_until_ms < req_valid_until: logger.debug( @@ -235,8 +266,8 @@ class RemoteKey(RestServlet): ts_valid_until_ms, time_now_ms, ) - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(most_recent_result["key_json"])) + + json_results.add(key_result.key_json) if miss and query_remote_on_cache_miss: # only bother attempting to fetch keys from servers on our whitelist diff --git a/synapse/rest/media/config_resource.py b/synapse/rest/media/config_resource.py index a95804d327..dbf5133c72 100644 --- a/synapse/rest/media/config_resource.py +++ b/synapse/rest/media/config_resource.py @@ -14,17 +14,19 @@ # limitations under the License. # +import re from typing import TYPE_CHECKING -from synapse.http.server import DirectServeJsonResource, respond_with_json +from synapse.http.server import respond_with_json +from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest if TYPE_CHECKING: from synapse.server import HomeServer -class MediaConfigResource(DirectServeJsonResource): - isLeaf = True +class MediaConfigResource(RestServlet): + PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/config$")] def __init__(self, hs: "HomeServer"): super().__init__() @@ -33,9 +35,6 @@ class MediaConfigResource(DirectServeJsonResource): self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.media.max_upload_size} - async def _async_render_GET(self, request: SynapseRequest) -> None: + async def on_GET(self, request: SynapseRequest) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - - async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: - respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index 3c618ef60a..65b9ff52fa 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -13,16 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING +import re +from typing import TYPE_CHECKING, Optional -from synapse.http.server import ( - DirectServeJsonResource, - set_corp_headers, - set_cors_headers, -) -from synapse.http.servlet import parse_boolean +from synapse.http.server import set_corp_headers, set_cors_headers +from synapse.http.servlet import RestServlet, parse_boolean from synapse.http.site import SynapseRequest -from synapse.media._base import parse_media_id, respond_404 +from synapse.media._base import respond_404 +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.media.media_repository import MediaRepository @@ -31,15 +29,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class DownloadResource(DirectServeJsonResource): - isLeaf = True +class DownloadResource(RestServlet): + PATTERNS = [ + re.compile( + "/_matrix/media/(r0|v3|v1)/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$" + ) + ] def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self._is_mine_server_name = hs.is_mine_server_name - async def _async_render_GET(self, request: SynapseRequest) -> None: + async def on_GET( + self, + request: SynapseRequest, + server_name: str, + media_id: str, + file_name: Optional[str] = None, + ) -> None: + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + set_cors_headers(request) set_corp_headers(request) request.setHeader( @@ -58,9 +69,8 @@ class DownloadResource(DirectServeJsonResource): b"Referrer-Policy", b"no-referrer", ) - server_name, media_id, name = parse_media_id(request) if self._is_mine_server_name(server_name): - await self.media_repo.get_local_media(request, media_id, name) + await self.media_repo.get_local_media(request, media_id, file_name) else: allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: @@ -72,4 +82,6 @@ class DownloadResource(DirectServeJsonResource): respond_404(request) return - await self.media_repo.get_remote_media(request, server_name, media_id, name) + await self.media_repo.get_remote_media( + request, server_name, media_id, file_name + ) diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py index 5ebaa3b032..2089bb1029 100644 --- a/synapse/rest/media/media_repository_resource.py +++ b/synapse/rest/media/media_repository_resource.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING from synapse.config._base import ConfigError -from synapse.http.server import UnrecognizedRequestResource +from synapse.http.server import HttpServer, JsonResource from .config_resource import MediaConfigResource from .download_resource import DownloadResource @@ -27,7 +27,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer -class MediaRepositoryResource(UnrecognizedRequestResource): +class MediaRepositoryResource(JsonResource): """File uploading and downloading. Uploads are POSTed to a resource which returns a token which is used to GET @@ -70,6 +70,11 @@ class MediaRepositoryResource(UnrecognizedRequestResource): width and height are close to the requested size and the aspect matches the requested size. The client should scale the image if it needs to fit within a given rectangle. + + This gets mounted at various points under /_matrix/media, including: + * /_matrix/media/r0 + * /_matrix/media/v1 + * /_matrix/media/v3 """ def __init__(self, hs: "HomeServer"): @@ -77,17 +82,23 @@ class MediaRepositoryResource(UnrecognizedRequestResource): if not hs.config.media.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") - super().__init__() + JsonResource.__init__(self, hs, canonical_json=False) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: media_repo = hs.get_media_repository() - self.putChild(b"upload", UploadResource(hs, media_repo)) - self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild( - b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + # Note that many of these should not exist as v1 endpoints, but empirically + # a lot of traffic still goes to them. + + UploadResource(hs, media_repo).register(http_server) + DownloadResource(hs, media_repo).register(http_server) + ThumbnailResource(hs, media_repo, media_repo.media_storage).register( + http_server ) if hs.config.media.url_preview_enabled: - self.putChild( - b"preview_url", - PreviewUrlResource(hs, media_repo, media_repo.media_storage), + PreviewUrlResource(hs, media_repo, media_repo.media_storage).register( + http_server ) - self.putChild(b"config", MediaConfigResource(hs)) + MediaConfigResource(hs).register(http_server) diff --git a/synapse/rest/media/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py index 58513c4be4..c8acb65dca 100644 --- a/synapse/rest/media/preview_url_resource.py +++ b/synapse/rest/media/preview_url_resource.py @@ -13,24 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from typing import TYPE_CHECKING -from synapse.http.server import ( - DirectServeJsonResource, - respond_with_json, - respond_with_json_bytes, -) -from synapse.http.servlet import parse_integer, parse_string +from synapse.http.server import respond_with_json_bytes +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.media.media_storage import MediaStorage -from synapse.media.url_previewer import UrlPreviewer if TYPE_CHECKING: from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer -class PreviewUrlResource(DirectServeJsonResource): +class PreviewUrlResource(RestServlet): """ The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix @@ -48,7 +44,7 @@ class PreviewUrlResource(DirectServeJsonResource): * Matrix cannot be used to distribute the metadata between homeservers. """ - isLeaf = True + PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/preview_url$")] def __init__( self, @@ -62,14 +58,10 @@ class PreviewUrlResource(DirectServeJsonResource): self.clock = hs.get_clock() self.media_repo = media_repo self.media_storage = media_storage + assert self.media_repo.url_previewer is not None + self.url_previewer = self.media_repo.url_previewer - self._url_previewer = UrlPreviewer(hs, media_repo, media_storage) - - async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: - request.setHeader(b"Allow", b"OPTIONS, GET") - respond_with_json(request, 200, {}, send_cors=True) - - async def _async_render_GET(self, request: SynapseRequest) -> None: + async def on_GET(self, request: SynapseRequest) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) url = parse_string(request, "url", required=True) @@ -77,5 +69,5 @@ class PreviewUrlResource(DirectServeJsonResource): if ts is None: ts = self.clock.time_msec() - og = await self._url_previewer.preview(url, requester.user, ts) + og = await self.url_previewer.preview(url, requester.user, ts) respond_with_json_bytes(request, 200, og, send_cors=True) diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 661e604b85..85b6bdbe72 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -13,29 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import re +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import Codes, SynapseError, cs_error from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP -from synapse.http.server import ( - DirectServeJsonResource, - respond_with_json, - set_corp_headers, - set_cors_headers, -) -from synapse.http.servlet import parse_integer, parse_string +from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.media._base import ( FileInfo, ThumbnailInfo, - parse_media_id, respond_404, respond_with_file, respond_with_responder, ) from synapse.media.media_storage import MediaStorage +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.media.media_repository import MediaRepository @@ -44,8 +39,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ThumbnailResource(DirectServeJsonResource): - isLeaf = True +class ThumbnailResource(RestServlet): + PATTERNS = [ + re.compile( + "/_matrix/media/(r0|v3|v1)/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" + ) + ] def __init__( self, @@ -60,12 +59,17 @@ class ThumbnailResource(DirectServeJsonResource): self.media_storage = media_storage self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails self._is_mine_server_name = hs.is_mine_server_name + self._server_name = hs.hostname self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from - async def _async_render_GET(self, request: SynapseRequest) -> None: + async def on_GET( + self, request: SynapseRequest, server_name: str, media_id: str + ) -> None: + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + set_cors_headers(request) set_corp_headers(request) - server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") @@ -155,30 +159,24 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: - t_w = info["thumbnail_width"] == desired_width - t_h = info["thumbnail_height"] == desired_height - t_method = info["thumbnail_method"] == desired_method - t_type = info["thumbnail_type"] == desired_type + t_w = info.width == desired_width + t_h = info.height == desired_height + t_method = info.method == desired_method + t_type = info.type == desired_type if t_w and t_h and t_method and t_type: file_info = FileInfo( server_name=None, file_id=media_id, url_cache=media_info["url_cache"], - thumbnail=ThumbnailInfo( - width=info["thumbnail_width"], - height=info["thumbnail_height"], - type=info["thumbnail_type"], - method=info["thumbnail_method"], - ), + thumbnail=info, ) - t_type = file_info.thumbnail_type - t_length = info["thumbnail_length"] - responder = await self.media_storage.fetch_media(file_info) if responder: - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, info.type, info.length + ) return logger.debug("We don't have a thumbnail of that size. Generating") @@ -218,29 +216,23 @@ class ThumbnailResource(DirectServeJsonResource): file_id = media_info["filesystem_id"] for info in thumbnail_infos: - t_w = info["thumbnail_width"] == desired_width - t_h = info["thumbnail_height"] == desired_height - t_method = info["thumbnail_method"] == desired_method - t_type = info["thumbnail_type"] == desired_type + t_w = info.width == desired_width + t_h = info.height == desired_height + t_method = info.method == desired_method + t_type = info.type == desired_type if t_w and t_h and t_method and t_type: file_info = FileInfo( server_name=server_name, file_id=media_info["filesystem_id"], - thumbnail=ThumbnailInfo( - width=info["thumbnail_width"], - height=info["thumbnail_height"], - type=info["thumbnail_type"], - method=info["thumbnail_method"], - ), + thumbnail=info, ) - t_type = file_info.thumbnail_type - t_length = info["thumbnail_length"] - responder = await self.media_storage.fetch_media(file_info) if responder: - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, info.type, info.length + ) return logger.debug("We don't have a thumbnail of that size. Generating") @@ -300,7 +292,7 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos: List[Dict[str, Any]], + thumbnail_infos: List[ThumbnailInfo], media_id: str, file_id: str, url_cache: bool, @@ -315,7 +307,7 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: The desired height, the returned thumbnail may be larger than this. desired_method: The desired method used to generate the thumbnail. desired_type: The desired content-type of the thumbnail. - thumbnail_infos: A list of dictionaries of candidate thumbnails. + thumbnail_infos: A list of thumbnail info of candidate thumbnails. file_id: The ID of the media that a thumbnail is being requested for. url_cache: True if this is from a URL cache. server_name: The server name, if this is a remote thumbnail. @@ -418,13 +410,14 @@ class ThumbnailResource(DirectServeJsonResource): # `dynamic_thumbnails` is disabled. logger.info("Failed to find any generated thumbnails") + assert request.path is not None respond_with_json( request, 400, cs_error( - "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)" + "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)" % ( - request.postpath, + request.path.decode(), ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()), ), code=Codes.UNKNOWN, @@ -438,7 +431,7 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos: List[Dict[str, Any]], + thumbnail_infos: List[ThumbnailInfo], file_id: str, url_cache: bool, server_name: Optional[str], @@ -451,7 +444,7 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: The desired height, the returned thumbnail may be larger than this. desired_method: The desired method used to generate the thumbnail. desired_type: The desired content-type of the thumbnail. - thumbnail_infos: A list of dictionaries of candidate thumbnails. + thumbnail_infos: A list of thumbnail infos of candidate thumbnails. file_id: The ID of the media that a thumbnail is being requested for. url_cache: True if this is from a URL cache. server_name: The server name, if this is a remote thumbnail. @@ -469,21 +462,25 @@ class ThumbnailResource(DirectServeJsonResource): if desired_method == "crop": # Thumbnails that match equal or larger sizes of desired width/height. - crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] + crop_info_list: List[ + Tuple[int, int, int, bool, Optional[int], ThumbnailInfo] + ] = [] # Other thumbnails. - crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] + crop_info_list2: List[ + Tuple[int, int, int, bool, Optional[int], ThumbnailInfo] + ] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. - if info["thumbnail_method"] != "crop": + if info.method != "crop": continue - t_w = info["thumbnail_width"] - t_h = info["thumbnail_height"] + t_w = info.width + t_h = info.height aspect_quality = abs(d_w * t_h - d_h * t_w) min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] + type_quality = desired_type != info.type + length_quality = info.length if t_w >= d_w or t_h >= d_h: crop_info_list.append( ( @@ -508,7 +505,7 @@ class ThumbnailResource(DirectServeJsonResource): ) # Pick the most appropriate thumbnail. Some values of `desired_width` and # `desired_height` may result in a tie, in which case we avoid comparing on - # the thumbnail info dictionary and pick the thumbnail that appears earlier + # the thumbnail info and pick the thumbnail that appears earlier # in the list of candidates. if crop_info_list: thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1] @@ -516,20 +513,20 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1] elif desired_method == "scale": # Thumbnails that match equal or larger sizes of desired width/height. - info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = [] + info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = [] # Other thumbnails. - info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = [] + info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. - if info["thumbnail_method"] != "scale": + if info.method != "scale": continue - t_w = info["thumbnail_width"] - t_h = info["thumbnail_height"] + t_w = info.width + t_h = info.height size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] + type_quality = desired_type != info.type + length_quality = info.length if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) else: @@ -538,7 +535,7 @@ class ThumbnailResource(DirectServeJsonResource): ) # Pick the most appropriate thumbnail. Some values of `desired_width` and # `desired_height` may result in a tie, in which case we avoid comparing on - # the thumbnail info dictionary and pick the thumbnail that appears earlier + # the thumbnail info and pick the thumbnail that appears earlier # in the list of candidates. if info_list: thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1] @@ -550,13 +547,7 @@ class ThumbnailResource(DirectServeJsonResource): file_id=file_id, url_cache=url_cache, server_name=server_name, - thumbnail=ThumbnailInfo( - width=thumbnail_info["thumbnail_width"], - height=thumbnail_info["thumbnail_height"], - type=thumbnail_info["thumbnail_type"], - method=thumbnail_info["thumbnail_method"], - length=thumbnail_info["thumbnail_length"], - ), + thumbnail=thumbnail_info, ) # No matching thumbnail was found. diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 697348613b..949326d85d 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -14,11 +14,12 @@ # limitations under the License. import logging +import re from typing import IO, TYPE_CHECKING, Dict, List, Optional from synapse.api.errors import Codes, SynapseError -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_bytes_from_args +from synapse.http.server import respond_with_json +from synapse.http.servlet import RestServlet, parse_bytes_from_args from synapse.http.site import SynapseRequest from synapse.media.media_storage import SpamMediaException @@ -29,8 +30,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class UploadResource(DirectServeJsonResource): - isLeaf = True +class UploadResource(RestServlet): + PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")] def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() @@ -39,15 +40,11 @@ class UploadResource(DirectServeJsonResource): self.filepaths = media_repo.filepaths self.store = hs.get_datastores().main self.clock = hs.get_clock() - self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.media.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: - respond_with_json(request, 200, {}, send_cors=True) - - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: diff --git a/synapse/rest/models.py b/synapse/rest/models.py index ac39cda8e5..de354a2135 100644 --- a/synapse/rest/models.py +++ b/synapse/rest/models.py @@ -1,4 +1,24 @@ -from pydantic import BaseModel, Extra +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, Extra +else: + from pydantic import BaseModel, Extra class RequestBodyModel(BaseModel): diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index e55924f597..57335fb913 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -46,6 +46,12 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc "/_synapse/client/unsubscribe": UnsubscribeResource(hs), } + # Expose the JWKS endpoint if OAuth2 delegation is enabled + if hs.config.experimental.msc3861.enabled: + from synapse.rest.synapse.client.jwks import JwksResource + + resources["/_synapse/jwks"] = JwksResource(hs) + # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. if hs.config.oidc.oidc_enabled: diff --git a/synapse/rest/synapse/client/jwks.py b/synapse/rest/synapse/client/jwks.py new file mode 100644 index 0000000000..7c0a1223fb --- /dev/null +++ b/synapse/rest/synapse/client/jwks.py @@ -0,0 +1,70 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.http.server import DirectServeJsonResource +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class JwksResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + super().__init__(extract_context=True) + + # Parameters that are allowed to be exposed in the public key. + # This is done manually, because authlib's private to public key conversion + # is unreliable depending on the version. Instead, we just serialize the private + # key and only keep the public parameters. + # List from https://www.iana.org/assignments/jose/jose.xhtml#web-key-parameters + public_parameters = { + "kty", + "use", + "key_ops", + "alg", + "kid", + "x5u", + "x5c", + "x5t", + "x5t#S256", + "crv", + "x", + "y", + "n", + "e", + "ext", + } + + key = hs.config.experimental.msc3861.jwk + + if key is not None: + private_key = key.as_dict() + public_key = { + k: v for k, v in private_key.items() if k in public_parameters + } + keys = [public_key] + else: + keys = [] + + self.res = { + "keys": keys, + } + + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + return 200, self.res diff --git a/synapse/rest/synapse/client/unsubscribe.py b/synapse/rest/synapse/client/unsubscribe.py index 60321018f9..050fd7bba1 100644 --- a/synapse/rest/synapse/client/unsubscribe.py +++ b/synapse/rest/synapse/client/unsubscribe.py @@ -38,6 +38,10 @@ class UnsubscribeResource(DirectServeHtmlResource): self.macaroon_generator = hs.get_macaroon_generator() async def _async_render_GET(self, request: SynapseRequest) -> None: + """ + Handle a user opening an unsubscribe link in the browser, either via an + HTML/Text email or via the List-Unsubscribe header. + """ token = parse_string(request, "access_token", required=True) app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) @@ -62,3 +66,16 @@ class UnsubscribeResource(DirectServeHtmlResource): 200, UnsubscribeResource.SUCCESS_HTML, ) + + async def _async_render_POST(self, request: SynapseRequest) -> None: + """ + Handle a mail user agent POSTing to the unsubscribe URL via the + List-Unsubscribe & List-Unsubscribe-Post headers. + """ + + # TODO Assert that the body has a single field + + # Assert the body has form encoded key/value pair of + # List-Unsubscribe=One-Click. + + await self._async_render_GET(request) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index e2174fdfea..b8b4b5379b 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -44,6 +44,16 @@ class WellKnownBuilder: "base_url": self._config.registration.default_identity_server } + # We use the MSC3861 values as they are used by multiple MSCs + if self._config.experimental.msc3861.enabled: + result["org.matrix.msc2965.authentication"] = { + "issuer": self._config.experimental.msc3861.issuer + } + if self._config.experimental.msc3861.account_management_url is not None: + result["org.matrix.msc2965.authentication"][ + "account" + ] = self._config.experimental.msc3861.account_management_url + if self._config.server.extra_well_known_client_content: for ( key, diff --git a/synapse/server.py b/synapse/server.py index f6e245569c..71ead524d6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -31,6 +31,7 @@ from twisted.web.iweb import IPolicyForHTTPS from twisted.web.resource import Resource from synapse.api.auth import Auth +from synapse.api.auth.internal import InternalAuth from synapse.api.auth_blocking import AuthBlocking from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter @@ -90,7 +91,6 @@ from synapse.handlers.room import ( RoomShutdownHandler, TimestampLookupHandler, ) -from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import ( RoomForgetterHandler, @@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler +from synapse.handlers.worker_lock import WorkerLocksHandler from synapse.http.client import ( InsecureInterceptableContextFactory, ReplicationClient, @@ -141,6 +142,7 @@ from synapse.util.distributor import Distributor from synapse.util.macaroons import MacaroonGenerator from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import random_string +from synapse.util.task_scheduler import TaskScheduler logger = logging.getLogger(__name__) @@ -359,6 +361,7 @@ class HomeServer(metaclass=abc.ABCMeta): """ for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: getattr(self, "get_" + i + "_handler")() + self.get_task_scheduler() def get_reactor(self) -> ISynapseReactor: """ @@ -405,8 +408,7 @@ class HomeServer(metaclass=abc.ABCMeta): return Ratelimiter( store=self.get_datastores().main, clock=self.get_clock(), - rate_hz=self.config.ratelimiting.rc_registration.per_second, - burst_count=self.config.ratelimiting.rc_registration.burst_count, + cfg=self.config.ratelimiting.rc_registration, ) @cache_in_self @@ -427,7 +429,11 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_auth(self) -> Auth: - return Auth(self) + if self.config.experimental.msc3861.enabled: + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + return MSC3861DelegatedAuth(self) + return InternalAuth(self) @cache_in_self def get_auth_blocking(self) -> AuthBlocking: @@ -488,10 +494,6 @@ class HomeServer(metaclass=abc.ABCMeta): return RoomCreationHandler(self) @cache_in_self - def get_room_batch_handler(self) -> RoomBatchHandler: - return RoomBatchHandler(self) - - @cache_in_self def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) @@ -784,9 +786,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer( - msc3970_enabled=self.config.experimental.msc3970_enabled - ) + return EventClientSerializer() @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: @@ -864,22 +864,36 @@ class HomeServer(metaclass=abc.ABCMeta): # We only want to import redis module if we're using it, as we have # `txredisapi` as an optional dependency. - from synapse.replication.tcp.redis import lazyConnection + from synapse.replication.tcp.redis import lazyConnection, lazyUnixConnection - logger.info( - "Connecting to redis (host=%r port=%r) for external cache", - self.config.redis.redis_host, - self.config.redis.redis_port, - ) + if self.config.redis.redis_path is None: + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis.redis_host, + self.config.redis.redis_port, + ) - return lazyConnection( - hs=self, - host=self.config.redis.redis_host, - port=self.config.redis.redis_port, - dbid=self.config.redis.redis_dbid, - password=self.config.redis.redis_password, - reconnect=True, - ) + return lazyConnection( + hs=self, + host=self.config.redis.redis_host, + port=self.config.redis.redis_port, + dbid=self.config.redis.redis_dbid, + password=self.config.redis.redis_password, + reconnect=True, + ) + else: + logger.info( + "Connecting to redis (path=%r) for external cache", + self.config.redis.redis_path, + ) + + return lazyUnixConnection( + hs=self, + path=self.config.redis.redis_path, + dbid=self.config.redis.redis_dbid, + password=self.config.redis.redis_password, + reconnect=True, + ) def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" @@ -898,3 +912,11 @@ class HomeServer(metaclass=abc.ABCMeta): def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager: """Usage metrics shared between phone home stats and the prometheus exporter.""" return CommonUsageMetricsManager(self) + + @cache_in_self + def get_worker_locks_handler(self) -> WorkerLocksHandler: + return WorkerLocksHandler(self) + + @cache_in_self + def get_task_scheduler(self) -> TaskScheduler: + return TaskScheduler(self) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 94025ba41f..a879b6505e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -79,15 +79,15 @@ class ConsentServerNotices: if u is None: return - if u["is_guest"] and not self._send_to_guests: + if u.is_guest and not self._send_to_guests: # don't send to guests return - if u["consent_version"] == self._current_consent_version: + if u.consent_version == self._current_consent_version: # user has already consented return - if u["consent_server_notice_sent"] == self._current_consent_version: + if u.consent_server_notice_sent == self._current_consent_version: # we've already sent a notice to the user return diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6031095249..e977ed1044 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -20,7 +20,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, DefaultDict, Dict, FrozenSet, @@ -45,10 +44,11 @@ from synapse.events.snapshot import ( UnpersistedEventContextBase, ) from synapse.logging.context import ContextResourceUsage +from synapse.logging.opentracing import tag_args, trace from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import StateMap +from synapse.types import StateMap, StrCollection from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -196,7 +196,7 @@ class StateHandler: async def compute_state_after_events( self, room_id: str, - event_ids: Collection[str], + event_ids: StrCollection, state_filter: Optional[StateFilter] = None, await_full_state: bool = True, ) -> StateMap[str]: @@ -230,7 +230,7 @@ class StateHandler: return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: Collection[str] + self, room_id: str, latest_event_ids: StrCollection ) -> Set[str]: """ Get the users IDs who are currently in a room. @@ -255,7 +255,7 @@ class StateHandler: return await self.store.get_joined_user_ids_from_state(room_id, state) async def get_hosts_in_room_at_events( - self, room_id: str, event_ids: Collection[str] + self, room_id: str, event_ids: StrCollection ) -> FrozenSet[str]: """Get the hosts that were in a room at the given event ids @@ -267,9 +267,10 @@ class StateHandler: The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_hosts(room_id, state, entry) + return await self._state_storage_controller.get_joined_hosts(room_id, entry) + @trace + @tag_args async def calculate_context_info( self, event: EventBase, @@ -465,9 +466,10 @@ class StateHandler: return await unpersisted_context.persist(event) + @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Collection[str], await_full_state: bool = True + self, room_id: str, event_ids: StrCollection, await_full_state: bool = True ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -879,7 +881,7 @@ class StateResolutionStore: store: "DataStore" def get_events( - self, event_ids: Collection[str], allow_rejected: bool = False + self, event_ids: StrCollection, allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 500e384695..c76a2f082e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,7 +17,6 @@ import logging from typing import ( Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateMap, StrCollection logger = logging.getLogger(__name__) @@ -45,7 +44,7 @@ async def resolve_events_with_store( room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], + state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]], ) -> StateMap[str]: """ Args: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 1b9d7d8457..b2e63aed1e 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -19,12 +19,10 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, Generator, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -39,7 +37,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateMap, StrCollection logger = logging.getLogger(__name__) @@ -56,7 +54,7 @@ class StateResolutionStore(Protocol): # This is usually synapse.state.StateResolutionStore, but it's replaced with a # TestStateResolutionStore in tests. def get_events( - self, event_ids: Collection[str], allow_rejected: bool = False + self, event_ids: StrCollection, allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: ... @@ -270,7 +268,7 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( room_id: str, - state_sets: Sequence[Mapping[Any, str]], + state_sets: Sequence[StateMap[str]], unpersisted_events: Dict[str, EventBase], state_res_store: StateResolutionStore, ) -> Set[str]: @@ -366,7 +364,7 @@ async def _get_auth_chain_difference( union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - auth_difference_unpersisted_part: Collection[str] = union - intersection + auth_difference_unpersisted_part: StrCollection = union - intersection else: auth_difference_unpersisted_part = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] @@ -406,7 +404,7 @@ def _seperate( # mypy doesn't understand that discarding None above means that conflicted # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]]. - return unconflicted_state, conflicted_state # type: ignore + return unconflicted_state, conflicted_state # type: ignore[return-value] def _is_power_event(event: EventBase) -> bool: @@ -667,7 +665,7 @@ async def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): depth = await _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, event_map, state_res_store + clock, event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) @@ -682,6 +680,7 @@ async def _mainline_sort( async def _get_mainline_depth_for_event( + clock: Clock, event: EventBase, mainline_map: Dict[str, int], event_map: Dict[str, EventBase], @@ -704,6 +703,7 @@ async def _get_mainline_depth_for_event( # We do an iterative search, replacing `event with the power level in its # auth events (if any) + idx = 0 while tmp_event: depth = mainline_map.get(tmp_event.event_id) if depth is not None: @@ -720,6 +720,11 @@ async def _get_mainline_depth_for_event( tmp_event = aev break + idx += 1 + + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + # Didn't find a power level auth event, so we just return 0 return 0 diff --git a/synapse/static/index.html b/synapse/static/index.html index bf46df9097..297a7877f3 100644 --- a/synapse/static/index.html +++ b/synapse/static/index.html @@ -48,7 +48,7 @@ </div> <h1>It works! Synapse is running</h1> <p>Your Synapse server is listening on this port and is ready for messages.</p> - <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank" rel="noopener noreferrer">a Matrix client</a>. + <p>To use this server you'll need <a href="https://matrix.org/ecosystem/clients/" target="_blank" rel="noopener noreferrer">a Matrix client</a>. </p> <p>Welcome to the Matrix universe :)</p> <hr> diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 481fec72fe..fe4a763411 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -86,9 +86,14 @@ class SQLBaseStore(metaclass=ABCMeta): room_id: Room where state changed members_changed: The user_ids of members that have changed """ + + # XXX: If you add something to this function make sure you add it to + # `_invalidate_state_caches_all` as well. + # If there were any membership changes, purge the appropriate caches. for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) + self._attempt_to_invalidate_cache("is_host_invited", (room_id, host)) if members_changed: self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,)) @@ -117,6 +122,32 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) + def _invalidate_state_caches_all(self, room_id: str) -> None: + """Invalidates caches that are based on the current state, but does + not stream invalidations down replication. + + Same as `_invalidate_state_caches`, except that works when we don't know + which memberships have changed. + + Args: + room_id: Room where state changed + """ + self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) + self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("is_host_invited", None) + self._attempt_to_invalidate_cache("is_host_joined", None) + self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,)) + self._attempt_to_invalidate_cache("get_number_joined_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None) + self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", None + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", None) + self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) + def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] ) -> bool: diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index a99aea8926..12829d3d7d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging +from enum import Enum, IntEnum from types import TracebackType from typing import ( TYPE_CHECKING, @@ -23,18 +25,27 @@ from typing import ( Iterable, List, Optional, + Sequence, + Tuple, Type, ) import attr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection, Cursor from synapse.types import JsonDict from synapse.util import Clock, json_encoder from . import engines +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel +else: + from pydantic import BaseModel + if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingTransaction @@ -47,6 +58,81 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] +class Constraint(metaclass=abc.ABCMeta): + """Base class representing different constraints. + + Used by `register_background_validate_constraint_and_delete_rows`. + """ + + @abc.abstractmethod + def make_check_clause(self, table: str) -> str: + """Returns an SQL expression that checks the row passes the constraint.""" + + @abc.abstractmethod + def make_constraint_clause_postgres(self) -> str: + """Returns an SQL clause for creating the constraint. + + Only used on Postgres DBs + """ + + +@attr.s(auto_attribs=True) +class ForeignKeyConstraint(Constraint): + """A foreign key constraint. + + Attributes: + referenced_table: The "parent" table name. + columns: The list of mappings of columns from table to referenced table + deferred: Whether to defer checking of the constraint to the end of the + transaction. This is useful for e.g. backwards compatibility where + an older version inserted data in the wrong order. + """ + + referenced_table: str + columns: Sequence[Tuple[str, str]] + deferred: bool + + def make_check_clause(self, table: str) -> str: + join_clause = " AND ".join( + f"{col1} = {table}.{col2}" for col1, col2 in self.columns + ) + return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})" + + def make_constraint_clause_postgres(self) -> str: + column1_list = ", ".join(col1 for col1, col2 in self.columns) + column2_list = ", ".join(col2 for col1, col2 in self.columns) + defer_clause = " DEFERRABLE INITIALLY DEFERRED" if self.deferred else "" + return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list}) {defer_clause}" + + +@attr.s(auto_attribs=True) +class NotNullConstraint(Constraint): + """A NOT NULL column constraint""" + + column: str + + def make_check_clause(self, table: str) -> str: + return f"{self.column} IS NOT NULL" + + def make_constraint_clause_postgres(self) -> str: + return f"CHECK ({self.column} IS NOT NULL)" + + +class ValidateConstraintProgress(BaseModel): + """The format of the progress JSON for validate constraint background + updates. + + Used by `register_background_validate_constraint_and_delete_rows`. + """ + + class State(str, Enum): + check = "check" + validate = "validate" + + state: State = State.validate + lower_bound: Sequence[Any] = () + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _BackgroundUpdateHandler: """A handler for a given background update. @@ -136,6 +222,15 @@ class BackgroundUpdatePerformance: return float(self.total_item_count) / float(self.total_duration_ms) +class UpdaterStatus(IntEnum): + # Use negative values for error conditions. + ABORTED = -1 + DISABLED = 0 + NOT_STARTED = 1 + RUNNING_UPDATE = 2 + COMPLETE = 3 + + class BackgroundUpdater: """Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to @@ -146,6 +241,7 @@ class BackgroundUpdater: def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database + self.hs = hs self._database_name = database.name() @@ -158,11 +254,16 @@ class BackgroundUpdater: self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {} self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {} + # TODO: all these bool flags make me feel icky---can we combine into a status + # enum? self._all_done = False # Whether we're currently running updates self._running = False + # Marker to be set if we abort and halt all background updates. + self._aborted = False + # Whether background updates are enabled. This allows us to # enable/disable background updates via the admin API. self.enabled = True @@ -175,6 +276,20 @@ class BackgroundUpdater: self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms self.sleep_enabled = hs.config.background_updates.sleep_enabled + def get_status(self) -> UpdaterStatus: + """An integer summarising the updater status. Used as a metric.""" + if self._aborted: + return UpdaterStatus.ABORTED + # TODO: a status for "have seen at least one failure, but haven't aborted yet". + if not self.enabled: + return UpdaterStatus.DISABLED + + if self._all_done: + return UpdaterStatus.COMPLETE + if self._running: + return UpdaterStatus.RUNNING_UPDATE + return UpdaterStatus.NOT_STARTED + def register_update_controller_callbacks( self, on_update: ON_UPDATE_CALLBACK, @@ -293,13 +408,14 @@ class BackgroundUpdater: try: result = await self.do_next_background_update(sleep) back_to_back_failures = 0 - except Exception: + except Exception as e: + logger.exception("Error doing update: %s", e) back_to_back_failures += 1 if back_to_back_failures >= 5: + self._aborted = True raise RuntimeError( "5 back-to-back background update failures; aborting." ) - logger.exception("Error doing update") else: if result: logger.info( @@ -561,6 +677,50 @@ class BackgroundUpdater: updater, oneshot=True ) + def register_background_validate_constraint( + self, update_name: str, constraint_name: str, table: str + ) -> None: + """Helper for store classes to do a background validate constraint. + + This only applies on PostgreSQL. + + To use: + + 1. use a schema delta file to add a background update. Example: + INSERT INTO background_updates (update_name, progress_json) VALUES + ('validate_my_constraint', '{}'); + + 2. In the Store constructor, call this method + + Args: + update_name: update_name to register for + constraint_name: name of constraint to validate + table: table the constraint is applied to + """ + + def runner(conn: Connection) -> None: + c = conn.cursor() + + sql = f""" + ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}; + """ + logger.debug("[SQL] %s", sql) + c.execute(sql) + + async def updater(progress: JsonDict, batch_size: int) -> int: + assert isinstance( + self.db_pool.engine, engines.PostgresEngine + ), "validate constraint background update registered for non-Postres database" + + logger.info("Validating constraint %s to %s", constraint_name, table) + await self.db_pool.runWithConnection(runner) + await self._end_background_update(update_name) + return 1 + + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + updater, oneshot=True + ) + async def create_index_in_background( self, index_name: str, @@ -602,6 +762,11 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) + # override the global statement timeout to avoid accidentally squashing + # a long-running index creation process + timeout_sql = "SET SESSION statement_timeout = 0" + c.execute(timeout_sql) + sql = ( "CREATE %(unique)s INDEX CONCURRENTLY %(name)s" " ON %(table)s" @@ -622,6 +787,12 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) finally: + # mypy ignore - `statement_timeout` is defined on PostgresEngine + # reset the global timeout to the default + default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined] + undo_timeout_sql = f"SET statement_timeout = {default_timeout}" + conn.cursor().execute(undo_timeout_sql) + conn.set_session(autocommit=False) # type: ignore def create_index_sqlite(conn: Connection) -> None: @@ -666,6 +837,179 @@ class BackgroundUpdater: logger.info("Adding index %s to %s", index_name, table) await self.db_pool.runWithConnection(runner) + def register_background_validate_constraint_and_delete_rows( + self, + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + unique_columns: Sequence[str], + ) -> None: + """Helper for store classes to do a background validate constraint, and + delete rows that do not pass the constraint check. + + Note: This deletes rows that don't match the constraint. This may not be + appropriate in all situations, and so the suitability of using this + method should be considered on a case-by-case basis. + + This only applies on PostgreSQL. + + For SQLite the table gets recreated as part of the schema delta and the + data is copied over synchronously (or whatever the correct way to + describe it as). + + Args: + update_name: The name of the background update. + table: The table with the invalid constraint. + constraint_name: The name of the constraint + constraint: A `Constraint` object matching the type of constraint. + unique_columns: A sequence of columns that form a unique constraint + on the table. Used to iterate over the table. + """ + + assert isinstance( + self.db_pool.engine, engines.PostgresEngine + ), "validate constraint background update registered for non-Postres database" + + async def updater(progress: JsonDict, batch_size: int) -> int: + return await self.validate_constraint_and_delete_in_background( + update_name=update_name, + table=table, + constraint_name=constraint_name, + constraint=constraint, + unique_columns=unique_columns, + progress=progress, + batch_size=batch_size, + ) + + self._background_update_handlers[update_name] = _BackgroundUpdateHandler( + updater, oneshot=True + ) + + async def validate_constraint_and_delete_in_background( + self, + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + unique_columns: Sequence[str], + progress: JsonDict, + batch_size: int, + ) -> int: + """Validates a table constraint that has been marked as `NOT VALID`, + deleting rows that don't pass the constraint check. + + This will delete rows that do not meet the validation check. + + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + unique_columns: Sequence[str], + """ + + # We validate the constraint by: + # 1. Trying to validate the constraint as is. If this succeeds then + # we're done. + # 2. Otherwise, we manually scan the table to remove rows that don't + # match the constraint. + # 3. We try re-validating the constraint. + + parsed_progress = ValidateConstraintProgress.parse_obj(progress) + + if parsed_progress.state == ValidateConstraintProgress.State.check: + return_columns = ", ".join(unique_columns) + order_columns = ", ".join(unique_columns) + + where_clause = "" + args: List[Any] = [] + if parsed_progress.lower_bound: + where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})""" + args.extend(parsed_progress.lower_bound) + + args.append(batch_size) + + sql = f""" + SELECT + {return_columns}, + {constraint.make_check_clause(table)} AS check + FROM {table} + {where_clause} + ORDER BY {order_columns} + LIMIT ? + """ + + def validate_constraint_in_background_check( + txn: "LoggingTransaction", + ) -> None: + txn.execute(sql, args) + rows = txn.fetchall() + + new_progress = parsed_progress.copy() + + if not rows: + new_progress.state = ValidateConstraintProgress.State.validate + self._background_update_progress_txn( + txn, update_name, new_progress.dict() + ) + return + + new_progress.lower_bound = rows[-1][:-1] + + to_delete = [row[:-1] for row in rows if not row[-1]] + + if to_delete: + logger.warning( + "Deleting %d rows that do not pass new constraint", + len(to_delete), + ) + + self.db_pool.simple_delete_many_batch_txn( + txn, table=table, keys=unique_columns, values=to_delete + ) + + self._background_update_progress_txn( + txn, update_name, new_progress.dict() + ) + + await self.db_pool.runInteraction( + "validate_constraint_in_background_check", + validate_constraint_in_background_check, + ) + + return batch_size + + elif parsed_progress.state == ValidateConstraintProgress.State.validate: + sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}" + + def validate_constraint_in_background_validate( + txn: "LoggingTransaction", + ) -> None: + txn.execute(sql) + + try: + await self.db_pool.runInteraction( + "validate_constraint_in_background_validate", + validate_constraint_in_background_validate, + ) + + await self._end_background_update(update_name) + except self.db_pool.engine.module.IntegrityError as e: + # If we get an integrity error here, then we go back and recheck the table. + logger.warning("Integrity error when validating constraint: %s", e) + await self._background_update_progress( + update_name, + ValidateConstraintProgress( + state=ValidateConstraintProgress.State.check + ).dict(), + ) + + return batch_size + else: + raise Exception( + f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background" + ) + async def _end_background_update(self, update_name: str) -> None: """Removes a completed background update task from the queue. @@ -721,3 +1065,86 @@ class BackgroundUpdater: keyvalues={"update_name": update_name}, updatevalues={"progress_json": progress_json}, ) + + +def run_validate_constraint_and_delete_rows_schema_delta( + txn: "LoggingTransaction", + ordering: int, + update_name: str, + table: str, + constraint_name: str, + constraint: Constraint, + sqlite_table_name: str, + sqlite_table_schema: str, +) -> None: + """Runs a schema delta to add a constraint to the table. This should be run + in a schema delta file. + + For PostgreSQL the constraint is added and validated in the background. + + For SQLite the table is recreated and data copied across immediately. This + is done by the caller passing in a script to create the new table. Note that + table indexes and triggers are copied over automatically. + + There must be a corresponding call to + `register_background_validate_constraint_and_delete_rows` to register the + background update in one of the data store classes. + + Attributes: + txn ordering, update_name: For adding a row to background_updates table. + table: The table to add constraint to. constraint_name: The name of the + new constraint constraint: A `Constraint` object describing the + constraint sqlite_table_name: For SQLite the name of the empty copy of + table sqlite_table_schema: A SQL script for creating the above table. + """ + + if isinstance(txn.database_engine, PostgresEngine): + # For postgres we can just add the constraint and mark it as NOT VALID, + # and then insert a background update to go and check the validity in + # the background. + txn.execute( + f""" + ALTER TABLE {table} + ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()} + NOT VALID + """ + ) + + txn.execute( + "INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')", + (ordering, update_name), + ) + else: + # For SQLite, we: + # 1. fetch all indexes/triggers/etc related to the table + # 2. create an empty copy of the table + # 3. copy across the rows (that satisfy the check) + # 4. replace the old table with the new able. + # 5. add back all the indexes/triggers/etc + + # Fetch the indexes/triggers/etc. Note that `sql` column being null is + # due to indexes being auto created based on the class definition (e.g. + # PRIMARY KEY), and so don't need to be recreated. + txn.execute( + """ + SELECT sql FROM sqlite_master + WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL + """, + (table,), + ) + extras = [row[0] for row in txn] + + txn.execute(sqlite_table_schema) + + sql = f""" + INSERT INTO {sqlite_table_name} SELECT * FROM {table} + WHERE {constraint.make_check_clause(table)} + """ + + txn.execute(sql) + + txn.execute(f"DROP TABLE {table}") + txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}") + + for extra in extras: + txn.execute(extra) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index f1d2c71c91..f39ae2d635 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -19,6 +19,7 @@ import logging from collections import deque from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -45,6 +46,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.opentracing import ( SynapseTags, @@ -153,12 +155,13 @@ class _UpdateCurrentStateTask: _EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask] +_PersistResult = TypeVar("_PersistResult") @attr.s(auto_attribs=True, slots=True) -class _EventPersistQueueItem: +class _EventPersistQueueItem(Generic[_PersistResult]): task: _EventPersistQueueTask - deferred: ObservableDeferred + deferred: ObservableDeferred[_PersistResult] parent_opentracing_span_contexts: List = attr.ib(factory=list) """A list of opentracing spans waiting for this batch""" @@ -167,9 +170,6 @@ class _EventPersistQueueItem: """The opentracing span under which the persistence actually happened""" -_PersistResult = TypeVar("_PersistResult") - - class _EventPeristenceQueue(Generic[_PersistResult]): """Queues up tasks so that they can be processed with only one concurrent transaction per room. @@ -338,6 +338,7 @@ class EventsPersistenceStorageController: ) self._state_resolution_handler = hs.get_state_resolution_handler() self._state_controller = state_controller + self.hs = hs async def _process_event_persist_queue_task( self, @@ -350,15 +351,22 @@ class EventsPersistenceStorageController: A dictionary of event ID to event ID we didn't persist as we already had another event persisted with the same TXN ID. """ - if isinstance(task, _PersistEventsTask): - return await self._persist_event_batch(room_id, task) - elif isinstance(task, _UpdateCurrentStateTask): - await self._update_current_state(room_id, task) - return {} - else: - raise AssertionError( - f"Found an unexpected task type in event persistence queue: {task}" - ) + + # Ensure that the room can't be deleted while we're persisting events to + # it. We might already have taken out the lock, but since this is just a + # "read" lock its inherently reentrant. + async with self.hs.get_worker_locks_handler().acquire_read_write_lock( + NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False + ): + if isinstance(task, _PersistEventsTask): + return await self._persist_event_batch(room_id, task) + elif isinstance(task, _UpdateCurrentStateTask): + await self._update_current_state(room_id, task) + return {} + else: + raise AssertionError( + f"Found an unexpected task type in event persistence queue: {task}" + ) @trace async def persist_events( @@ -611,7 +619,7 @@ class EventsPersistenceStorageController: ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = set( + latest_event_ids = ( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( @@ -733,7 +741,7 @@ class EventsPersistenceStorageController: self, room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], - latest_event_ids: Collection[str], + latest_event_ids: AbstractSet[str], ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. @@ -751,8 +759,6 @@ class EventsPersistenceStorageController: and not event.internal_metadata.is_soft_failed() ] - latest_event_ids = set(latest_event_ids) - # start with the existing forward extremities result = set(latest_event_ids) @@ -791,7 +797,7 @@ class EventsPersistenceStorageController: self, room_id: str, events_context: List[Tuple[EventBase, EventContext]], - old_latest_event_ids: Set[str], + old_latest_event_ids: AbstractSet[str], new_latest_event_ids: Set[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: """Calculate the current state dict after adding some new events to @@ -839,9 +845,8 @@ class EventsPersistenceStorageController: "group" % (ev.event_id,) ) continue - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + if ctx.state_group_deltas: + state_group_deltas.update(ctx.state_group_deltas) # We need to map the event_ids to their state groups. First, let's # check if the event is one we're persisting, in which case we can diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 9d7a8a792f..46957723a1 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -12,22 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from itertools import chain from typing import ( TYPE_CHECKING, AbstractSet, Any, - Awaitable, Callable, Collection, Dict, + FrozenSet, Iterable, List, Mapping, Optional, Tuple, + Union, ) -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo @@ -35,14 +37,21 @@ from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) -from synapse.types import MutableStateMap, StateMap +from synapse.synapse_rust.acl import ServerAclEvaluator +from synapse.types import MutableStateMap, StateMap, get_domain_from_id from synapse.types.state import StateFilter +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached from synapse.util.cancellation import cancellable +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.state import _StateCacheEntry from synapse.storage.databases import Databases + logger = logging.getLogger(__name__) @@ -53,10 +62,15 @@ class StateStorageController: def __init__(self, hs: "HomeServer", stores: "Databases"): self._is_mine_id = hs.is_mine_id + self._clock = hs.get_clock() self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) + # Used by `_get_joined_hosts` to ensure only one thing mutates the cache + # at a time. Keyed by room_id. + self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + def notify_event_un_partial_stated(self, event_id: str) -> None: self._partial_state_events_tracker.notify_un_partial_stated(event_id) @@ -67,6 +81,8 @@ class StateStorageController: """ self._partial_state_room_tracker.notify_un_partial_stated(room_id) + @trace + @tag_args async def get_state_group_delta( self, state_group: int ) -> Tuple[Optional[int], Optional[StateMap[str]]]: @@ -84,6 +100,8 @@ class StateStorageController: state_group_delta = await self.stores.state.get_state_group_delta(state_group) return state_group_delta.prev_group, state_group_delta.delta_ids + @trace + @tag_args async def get_state_groups_ids( self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True ) -> Dict[int, MutableStateMap[str]]: @@ -114,6 +132,8 @@ class StateStorageController: return group_to_state + @trace + @tag_args async def get_state_ids_for_group( self, state_group: int, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: @@ -130,6 +150,8 @@ class StateStorageController: return group_to_state[state_group] + @trace + @tag_args async def get_state_groups( self, room_id: str, event_ids: Collection[str] ) -> Dict[int, List[EventBase]]: @@ -165,9 +187,11 @@ class StateStorageController: for group, event_id_map in group_to_ids.items() } - def _get_state_groups_from_groups( + @trace + @tag_args + async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ) -> Awaitable[Dict[int, StateMap[str]]]: + ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -180,9 +204,12 @@ class StateStorageController: Dict of state group to state map. """ - return self.stores.state._get_state_groups_from_groups(groups, state_filter) + return await self.stores.state._get_state_groups_from_groups( + groups, state_filter + ) @trace + @tag_args async def get_state_for_events( self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None ) -> Dict[str, StateMap[EventBase]]: @@ -280,6 +307,8 @@ class StateStorageController: return {event: event_to_state[event] for event in event_ids} + @trace + @tag_args async def get_state_for_event( self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[EventBase]: @@ -303,6 +332,7 @@ class StateStorageController: return state_map[event_id] @trace + @tag_args async def get_state_ids_for_event( self, event_id: str, @@ -333,9 +363,11 @@ class StateStorageController: ) return state_map[event_id] - def get_state_for_groups( + @trace + @tag_args + async def get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None - ) -> Awaitable[Dict[int, MutableStateMap[str]]]: + ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -347,7 +379,7 @@ class StateStorageController: Returns: Dict of state group to state map. """ - return self.stores.state._get_state_for_groups( + return await self.stores.state._get_state_for_groups( groups, state_filter or StateFilter.all() ) @@ -402,6 +434,8 @@ class StateStorageController: event_id, room_id, prev_group, delta_ids, current_state_ids ) + @trace + @tag_args @cancellable async def get_current_state_ids( self, @@ -442,6 +476,8 @@ class StateStorageController: room_id, on_invalidate=on_invalidate ) + @trace + @tag_args async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: """Get canonical alias for room, if any @@ -464,8 +500,35 @@ class StateStorageController: if not event: return None - return event.content.get("canonical_alias") + return event.content.get("alias") + + @cached() + async def get_server_acl_for_room( + self, room_id: str + ) -> Optional[ServerAclEvaluator]: + """Get the server ACL evaluator for room, if any + + This does up-front parsing of the content to ignore bad data and pre-compile + regular expressions. + + Args: + room_id: The room ID + + Returns: + The server ACL evaluator, if any + """ + + acl_event = await self.get_current_state_event( + room_id, EventTypes.ServerACL, "" + ) + + if not acl_event: + return None + + return server_acl_evaluator_from_event(acl_event) + @trace + @tag_args async def get_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: @@ -500,6 +563,7 @@ class StateStorageController: ) @trace + @tag_args async def get_current_state( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[EventBase]: @@ -516,6 +580,8 @@ class StateStorageController: return state_map + @trace + @tag_args async def get_current_state_event( self, room_id: str, event_type: str, state_key: str ) -> Optional[EventBase]: @@ -527,6 +593,8 @@ class StateStorageController: ) return state_map.get(key) + @trace + @tag_args async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state. @@ -538,7 +606,9 @@ class StateStorageController: return await self.stores.main.get_current_hosts_in_room(room_id) - async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: + @trace + @tag_args + async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms @@ -553,6 +623,8 @@ class StateStorageController: return await self.stores.main.get_current_hosts_in_room_ordered(room_id) + @trace + @tag_args async def get_current_hosts_in_room_or_partial_state_approximation( self, room_id: str ) -> Collection[str]: @@ -582,6 +654,8 @@ class StateStorageController: return hosts + @trace + @tag_args async def get_users_in_room_with_profiles( self, room_id: str ) -> Mapping[str, ProfileInfo]: @@ -593,3 +667,155 @@ class StateStorageController: await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_users_in_room_with_profiles(room_id) + + async def get_joined_hosts( + self, room_id: str, state_entry: "_StateCacheEntry" + ) -> FrozenSet[str]: + state_group: Union[object, int] = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + assert state_group is not None + with Measure(self._clock, "get_joined_hosts"): + return await self._get_joined_hosts( + room_id, state_group, state_entry=state_entry + ) + + @cached(num_args=2, max_entries=10000, iterable=True) + async def _get_joined_hosts( + self, + room_id: str, + state_group: Union[object, int], + state_entry: "_StateCacheEntry", + ) -> FrozenSet[str]: + # We don't use `state_group`, it's there so that we can cache based on + # it. However, its important that its never None, since two + # current_state's with a state_group of None are likely to be different. + # + # The `state_group` must match the `state_entry.state_group` (if not None). + assert state_group is not None + assert state_entry.state_group is None or state_entry.state_group == state_group + + # We use a secondary cache of previous work to allow us to build up the + # joined hosts for the given state group based on previous state groups. + # + # We cache one object per room containing the results of the last state + # group we got joined hosts for. The idea is that generally + # `get_joined_hosts` is called with the "current" state group for the + # room, and so consecutive calls will be for consecutive state groups + # which point to the previous state group. + cache = await self.stores.main._get_joined_hosts_cache(room_id) + + # If the state group in the cache matches, we already have the data we need. + if state_entry.state_group == cache.state_group: + return frozenset(cache.hosts_to_joined_users) + + # Since we'll mutate the cache we need to lock. + async with self._joined_host_linearizer.queue(room_id): + if state_entry.state_group == cache.state_group: + # Same state group, so nothing to do. We've already checked for + # this above, but the cache may have changed while waiting on + # the lock. + pass + elif state_entry.prev_group == cache.state_group: + # The cached work is for the previous state group, so we work out + # the delta. + assert state_entry.delta_ids is not None + for (typ, state_key), event_id in state_entry.delta_ids.items(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = cache.hosts_to_joined_users.setdefault(host, set()) + + event = await self.stores.main.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + cache.hosts_to_joined_users.pop(host, None) + else: + # The cache doesn't match the state group or prev state group, + # so we calculate the result from first principles. + # + # We need to fetch all hosts joined to the room according to `state` by + # inspecting all join memberships in `state`. However, if the `state` is + # relatively recent then many of its events are likely to be held in + # the current state of the room, which is easily available and likely + # cached. + # + # We therefore compute the set of `state` events not in the + # current state and only fetch those. + current_memberships = ( + await self.stores.main._get_approximate_current_memberships_in_room( + room_id + ) + ) + unknown_state_events = {} + joined_users_in_current_state = [] + + state = await state_entry.get_state( + self, StateFilter.from_types([(EventTypes.Member, None)]) + ) + + for (type, state_key), event_id in state.items(): + if event_id not in current_memberships: + unknown_state_events[type, state_key] = event_id + elif current_memberships[event_id] == Membership.JOIN: + joined_users_in_current_state.append(state_key) + + joined_user_ids = await self.stores.main.get_joined_user_ids_from_state( + room_id, unknown_state_events + ) + + cache.hosts_to_joined_users = {} + for user_id in chain(joined_user_ids, joined_users_in_current_state): + host = intern_string(get_domain_from_id(user_id)) + cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + cache.state_group = state_entry.state_group + else: + cache.state_group = object() + + return frozenset(cache.hosts_to_joined_users) + + +def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator": + """ + Create a ServerAclEvaluator from a m.room.server_acl event's content. + + This does up-front parsing of the content to ignore bad data. It then creates + the ServerAclEvaluator which will pre-compile regular expressions from the globs. + """ + + # first of all, parse if literal IPs are blocked. + allow_ip_literals = acl_event.content.get("allow_ip_literals", True) + if not isinstance(allow_ip_literals, bool): + logger.warning("Ignoring non-bool allow_ip_literals flag") + allow_ip_literals = True + + # next, parse the deny list by ignoring any non-strings. + deny = acl_event.content.get("deny", []) + if not isinstance(deny, (list, tuple)): + logger.warning("Ignoring non-list deny ACL %s", deny) + deny = [] + else: + deny = [s for s in deny if isinstance(s, str)] + + # then the allow list. + allow = acl_event.content.get("allow", []) + if not isinstance(allow, (list, tuple)): + logger.warning("Ignoring non-list allow ACL %s", allow) + allow = [] + else: + allow = [s for s in allow if isinstance(s, str)] + + return ServerAclEvaluator(allow_ip_literals, allow, deny) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bdaa508dbe..7d8af5c610 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,6 +31,7 @@ from typing import ( Iterator, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -54,7 +55,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) -from synapse.metrics import register_threadpool +from synapse.metrics import LaterGauge, register_threadpool from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine @@ -358,7 +359,9 @@ class LoggingTransaction: return self.txn.rowcount @property - def description(self) -> Any: + def description( + self, + ) -> Optional[Sequence[Any]]: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: @@ -407,10 +410,11 @@ class LoggingTransaction: return self._do_execute( # TODO: is it safe for values to be Iterable[Iterable[Any]] here? # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence] - lambda the_sql: execute_values( - self.txn, the_sql, values, template=template, fetch=fetch + lambda the_sql, the_values: execute_values( + self.txn, the_sql, the_values, template=template, fetch=fetch ), sql, + values, ) def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: @@ -547,6 +551,12 @@ class DatabasePool: self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) + LaterGauge( + "synapse_background_update_status", + "Background update status", + [], + self.updates.get_status, + ) self._previous_txn_total_time = 0.0 self._current_txn_total_time = 0.0 @@ -1171,6 +1181,7 @@ class DatabasePool: keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, desc: str = "simple_upsert", ) -> bool: """Insert a row with values + insertion_values; on conflict, update with values. @@ -1221,6 +1232,7 @@ class DatabasePool: keyvalues: The unique key columns and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. desc: description of the transaction, for logging and metrics Returns: Returns True if a row was inserted or updated (i.e. if `values` is @@ -1241,6 +1253,7 @@ class DatabasePool: keyvalues, values, insertion_values, + where_clause, db_autocommit=autocommit, ) except self.engine.module.IntegrityError as e: @@ -1523,7 +1536,7 @@ class DatabasePool: # Lock the table just once, to prevent it being done once per row. # Note that, according to Postgres' documentation, once obtained, # the lock is held for the remainder of the current transaction. - self.engine.lock_table(txn, "user_ips") + self.engine.lock_table(txn, table) for keyv, valv in zip(key_values, value_values): _keys = dict(zip(key_names, keyv)) @@ -2307,6 +2320,43 @@ class DatabasePool: return txn.rowcount + @staticmethod + def simple_delete_many_batch_txn( + txn: LoggingTransaction, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + ) -> None: + """Executes a DELETE query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + txn: The transaction to use. + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + """ + + if isinstance(txn.database_engine, PostgresEngine): + # We use `execute_values` as it can be a lot faster than `execute_batch`, + # but it's only available on postgres. + sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % ( + table, + ", ".join(k for k in keys), + ) + + txn.execute_values(sql, values, fetch=False) + else: + sql = "DELETE FROM %s WHERE (%s) = (%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute_batch(sql, values) + def get_cache_dict( self, db_conn: LoggingDatabaseConnection, @@ -2368,7 +2418,7 @@ class DatabasePool: keyvalues: Optional[Dict[str, Any]] = None, exclude_keyvalues: Optional[Dict[str, Any]] = None, order_direction: str = "ASC", - ) -> List[Dict[str, Any]]: + ) -> List[Tuple[Any, ...]]: """ Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, @@ -2397,7 +2447,7 @@ class DatabasePool: order_direction: Whether the results should be ordered "ASC" or "DESC". Returns: - The result as a list of dictionaries. + The result as a list of tuples. """ if order_direction not in ["ASC", "DESC"]: raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") @@ -2424,7 +2474,7 @@ class DatabasePool: ) txn.execute(sql, arg_list + [limit, start]) - return cls.cursor_to_dict(txn) + return txn.fetchall() async def simple_search_list( self, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0032a92f49..dfcbf0a175 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig +from synapse.storage._base import make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -61,7 +62,6 @@ from .registration import RegistrationStore from .rejections import RejectionsStore from .relations import RelationsStore from .room import RoomStore -from .room_batch import RoomBatchStore from .roommember import RoomMemberStore from .search import SearchStore from .session import SessionStore @@ -70,6 +70,7 @@ from .state import StateStore from .stats import StatsStore from .stream import StreamWorkerStore from .tags import TagsStore +from .task_scheduler import TaskSchedulerWorkerStore from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore @@ -87,7 +88,6 @@ class DataStore( DeviceStore, RoomMemberStore, RoomStore, - RoomBatchStore, RegistrationStore, ProfileStore, PresenceStore, @@ -128,6 +128,7 @@ class DataStore( CacheInvalidationWorkerStore, LockStore, SessionStore, + TaskSchedulerWorkerStore, ): def __init__( self, @@ -141,26 +142,6 @@ class DataStore( super().__init__(database, db_conn, hs) - async def get_users(self) -> List[JsonDict]: - """Function to retrieve a list of users in users table. - - Returns: - A list of dictionaries representing users. - """ - return await self.db_pool.simple_select_list( - table="users", - keyvalues={}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "user_type", - "deactivated", - ], - desc="get_users", - ) - async def get_users_paginate( self, start: int, @@ -169,9 +150,12 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, + admins: Optional[bool] = None, order_by: str = UserSortOrder.NAME.value, direction: Direction = Direction.FORWARDS, approved: bool = True, + not_user_types: Optional[List[str]] = None, + locked: bool = False, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -184,9 +168,14 @@ class DataStore( name: search for local part of user_id or display name guests: whether to in include guest users deactivated: whether to include deactivated users + admins: Optional flag to filter admins. If true, only admins are queried. + if false, admins are excluded from the query. When it is + none (the default), both admins and none-admins are queried. order_by: the sort order of the returned list direction: sort ascending or descending approved: whether to include approved users + not_user_types: list of user types to exclude + locked: whether to include locked users Returns: A tuple of a list of mappings from user to information and a count of total users. """ @@ -195,7 +184,7 @@ class DataStore( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: filters = [] - args = [self.hs.config.server.server_name] + args: list = [] # Set ordering order_by_column = UserSortOrder(order_by).value @@ -219,17 +208,64 @@ class DataStore( if not deactivated: filters.append("deactivated = 0") + if not locked: + filters.append("locked IS FALSE") + + if admins is not None: + if admins: + filters.append("admin = 1") + else: + filters.append("admin = 0") + if not approved: # We ignore NULL values for the approved flag because these should only # be already existing users that we consider as already approved. filters.append("approved IS FALSE") + if not_user_types: + if len(not_user_types) == 1 and not_user_types[0] == "": + # Only exclude NULL type users + filters.append("user_type IS NOT NULL") + else: + not_user_types_has_empty = False + not_user_types_without_empty = [] + + for not_user_type in not_user_types: + if not_user_type == "": + not_user_types_has_empty = True + else: + not_user_types_without_empty.append(not_user_type) + + not_user_type_clause, not_user_type_args = make_in_list_sql_clause( + self.database_engine, + "u.user_type", + not_user_types_without_empty, + ) + + if not_user_types_has_empty: + # NULL values should be excluded. + # They evaluate to false > nothing to do here. + filters.append("NOT %s" % (not_user_type_clause)) + else: + # NULL values should *not* be excluded. + # Add a special predicate to the query. + filters.append( + "(NOT %s OR %s IS NULL)" + % (not_user_type_clause, "u.user_type") + ) + + args.extend(not_user_type_args) + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" sql_base = f""" FROM users as u - LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + LEFT JOIN profiles AS p ON u.name = p.full_user_id LEFT JOIN erased_users AS eu ON u.name = eu.user_id + LEFT JOIN ( + SELECT user_id, MAX(last_seen) AS last_seen_ts + FROM user_ips GROUP BY user_id + ) ls ON u.name = ls.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -239,7 +275,7 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, - eu.user_id is not null as erased + eu.user_id is not null as erased, last_seen_ts, locked {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 8f7bdbc61a..16c284807a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -119,7 +119,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached() async def get_global_account_data_for_user( self, user_id: str - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """ Get all the global client account_data for a user. @@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) sql += " AND content != '{}'" txn.execute(sql, (user_id,)) - rows = self.db_pool.cursor_to_dict(txn) return { - row["account_data_type"]: db_to_json(row["content"]) for row in rows + account_data_type: db_to_json(content) + for account_data_type, content in txn } return await self.db_pool.runInteraction( @@ -164,7 +164,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached() async def get_room_account_data_for_user( self, user_id: str - ) -> Mapping[str, Mapping[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonMapping]]: """ Get all of the per-room client account_data for a user. @@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) sql += " AND content != '{}'" txn.execute(sql, (user_id,)) - rows = self.db_pool.cursor_to_dict(txn) by_room: Dict[str, Dict[str, JsonDict]] = {} - for row in rows: - room_data = by_room.setdefault(row["room_id"], {}) + for room_id, account_data_type, content in txn: + room_data = by_room.setdefault(room_id, {}) - room_data[row["account_data_type"]] = db_to_json(row["content"]) + room_data[account_data_type] = db_to_json(content) return by_room @@ -213,7 +212,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=2, max_entries=5000, tree=True) async def get_global_account_data_by_type_for_user( self, user_id: str, data_type: str - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """ Returns: The account data. @@ -265,7 +264,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get all the client account_data for a user for a room. Args: @@ -296,7 +295,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=3, max_entries=5000, tree=True) async def get_account_data_for_room_and_type( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Get the client account_data of given type for a user for a room. Args: @@ -394,7 +393,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_updated_global_account_data_for_user( self, user_id: str, stream_id: int - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get all the global account_data that's changed for a user. Args: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 484db175d0..073a99cd84 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -14,17 +14,7 @@ # limitations under the License. import logging import re -from typing import ( - TYPE_CHECKING, - Any, - Dict, - List, - Optional, - Pattern, - Sequence, - Tuple, - cast, -) +from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast from synapse.appservice import ( ApplicationService, @@ -45,7 +35,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import DeviceListUpdates, JsonDict +from synapse.types import DeviceListUpdates, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -268,8 +258,8 @@ class ApplicationServiceTransactionWorkerStore( self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, @@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore( def _get_oldest_unsent_txn( txn: LoggingTransaction, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[int, str]]: # Monotonically increasing txn ids, so just select the smallest # one in the txns table (we delete them when they are sent) txn.execute( - "SELECT * FROM application_services_txns WHERE as_id=?" + "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?" " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.db_pool.cursor_to_dict(txn) - if not rows: - return None - - entry = rows[0] - - return entry + return cast(Optional[Tuple[int, str]], txn.fetchone()) entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn @@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore( if not entry: return None - event_ids = db_to_json(entry["event_ids"]) + txn_id, event_ids_str = entry + event_ids = db_to_json(event_ids_str) events = await self.get_events_as_list(event_ids) # TODO: to-device messages, one-time key counts, device list summaries and unused @@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore( # We likely want to populate those for reliability. return AppServiceTransaction( service=service, - id=entry["txn_id"], + id=txn_id, events=events, ephemeral=[], to_device_messages=[], diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 46fa0a73f9..2fbd389c71 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -18,6 +18,8 @@ import logging from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes +from synapse.config._base import Config +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams.events import ( EventsStream, @@ -46,6 +48,27 @@ logger = logging.getLogger(__name__) # based on the current state when notifying workers over replication. CURRENT_STATE_CACHE_NAME = "cs_cache_fake" +# As above, but for invalidating event caches on history deletion +PURGE_HISTORY_CACHE_NAME = "ph_cache_fake" + +# As above, but for invalidating room caches on room deletion +DELETE_ROOM_CACHE_NAME = "dr_cache_fake" + +# How long between cache invalidation table cleanups, once we have caught up +# with the backlog. +REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h") + +# How long between cache invalidation table cleanups, before we have caught +# up with the backlog. +CATCH_UP_CLEANUP_INTERVAL_MS = Config.parse_duration("1m") + +# Maximum number of cache invalidation rows to delete at once. +CLEAN_UP_MAX_BATCH_SIZE = 20_000 + +# Keep cache invalidations for 7 days +# (This is likely to be quite excessive.) +RETENTION_PERIOD_OF_CACHE_INVALIDATIONS_MS = Config.parse_duration("7d") + class CacheInvalidationWorkerStore(SQLBaseStore): def __init__( @@ -92,6 +115,18 @@ class CacheInvalidationWorkerStore(SQLBaseStore): else: self._cache_id_gen = None + # Occasionally clean up the cache invalidations stream table by deleting + # old rows. + # This is only applicable when Postgres is in use; this table is unused + # and not populated at all when SQLite is the active database engine. + if hs.config.worker.run_background_tasks and isinstance( + self.database_engine, PostgresEngine + ): + self.hs.get_clock().call_later( + CATCH_UP_CLEANUP_INTERVAL_MS / 1000, + self._clean_up_cache_invalidation_wrapper, + ) + async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: @@ -175,6 +210,23 @@ class CacheInvalidationWorkerStore(SQLBaseStore): room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) + elif row.cache_func == PURGE_HISTORY_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for 'purge history' cache" + ) + + room_id = row.keys[0] + self._invalidate_caches_for_room_events(room_id) + elif row.cache_func == DELETE_ROOM_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for 'delete room' cache" + ) + + room_id = row.keys[0] + self._invalidate_caches_for_room_events(room_id) + self._invalidate_caches_for_room(room_id) else: self._attempt_to_invalidate_cache(row.cache_func, row.keys) @@ -226,6 +278,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): relates_to: Optional[str], backfilled: bool, ) -> None: + # XXX: If you add something to this function make sure you add it to + # `_invalidate_caches_for_room_events` as well. + # This invalidates any local in-memory cached event objects, the original # process triggering the invalidation is responsible for clearing any external # cached objects. @@ -263,6 +318,17 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) + self._attempt_to_invalidate_cache( + "did_forget", + ( + state_key, + room_id, + ), + ) + self._attempt_to_invalidate_cache( + "get_forgotten_rooms_for_user", (state_key,) + ) + if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) @@ -271,6 +337,108 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) self._attempt_to_invalidate_cache("get_threads", (room_id,)) + def _invalidate_caches_for_room_events_and_stream( + self, txn: LoggingTransaction, room_id: str + ) -> None: + """Invalidate caches associated with events in a room, and stream to + replication. + + Used when we delete events a room, but don't know which events we've + deleted. + """ + + self._send_invalidation_to_replication(txn, PURGE_HISTORY_CACHE_NAME, [room_id]) + txn.call_after(self._invalidate_caches_for_room_events, room_id) + + def _invalidate_caches_for_room_events(self, room_id: str) -> None: + """Invalidate caches associated with events in a room, and stream to + replication. + + Used when we delete events in a room, but don't know which events we've + deleted. + """ + + self._invalidate_local_get_event_cache_all() # type: ignore[attr-defined] + + self._attempt_to_invalidate_cache("have_seen_event", (room_id,)) + self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) + self._attempt_to_invalidate_cache( + "get_unread_event_push_actions_by_room_for_user", (room_id,) + ) + + self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) + self._attempt_to_invalidate_cache("get_relations_for_event", None) + self._attempt_to_invalidate_cache("get_applicable_edit", None) + self._attempt_to_invalidate_cache("get_thread_id", None) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None) + self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", None + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", None) + self._attempt_to_invalidate_cache("did_forget", None) + self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) + self._attempt_to_invalidate_cache("get_references_for_event", None) + self._attempt_to_invalidate_cache("get_thread_summary", None) + self._attempt_to_invalidate_cache("get_thread_participated", None) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) + + self._attempt_to_invalidate_cache("_get_state_group_for_event", None) + + self._attempt_to_invalidate_cache("get_event_ordering", None) + self._attempt_to_invalidate_cache("is_partial_state_event", None) + self._attempt_to_invalidate_cache("_get_joined_profile_from_event_id", None) + + def _invalidate_caches_for_room_and_stream( + self, txn: LoggingTransaction, room_id: str + ) -> None: + """Invalidate caches associated with rooms, and stream to replication. + + Used when we delete rooms. + """ + + self._send_invalidation_to_replication(txn, DELETE_ROOM_CACHE_NAME, [room_id]) + txn.call_after(self._invalidate_caches_for_room, room_id) + + def _invalidate_caches_for_room(self, room_id: str) -> None: + """Invalidate caches associated with rooms. + + Used when we delete rooms. + """ + + # If we've deleted the room then we also need to purge all event caches. + self._invalidate_caches_for_room_events(room_id) + + self._attempt_to_invalidate_cache("get_account_data_for_room", None) + self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None) + self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,)) + self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) + self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None) + self._attempt_to_invalidate_cache( + "get_unread_event_push_actions_by_room_for_user", (room_id,) + ) + self._attempt_to_invalidate_cache( + "_get_linearized_receipts_for_room", (room_id,) + ) + self._attempt_to_invalidate_cache("is_room_blocked", (room_id,)) + self._attempt_to_invalidate_cache("get_retention_policy_for_room", (room_id,)) + self._attempt_to_invalidate_cache( + "_get_partial_state_servers_at_join", (room_id,) + ) + self._attempt_to_invalidate_cache("is_partial_state_room", (room_id,)) + self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None) + self._attempt_to_invalidate_cache( + "get_current_hosts_in_room_ordered", (room_id,) + ) + self._attempt_to_invalidate_cache("did_forget", None) + self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) + self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) + self._attempt_to_invalidate_cache("get_room_version_id", (room_id,)) + + # And delete state caches. + + self._invalidate_state_caches_all(room_id) + async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] ) -> None: @@ -377,6 +545,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "Can't stream invalidate all with magic current state cache" ) + if cache_name == PURGE_HISTORY_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic purge history cache" + ) + + if cache_name == DELETE_ROOM_CACHE_NAME and keys is None: + raise Exception("Can't stream invalidate all with magic delete room cache") + if isinstance(self.database_engine, PostgresEngine): assert self._cache_id_gen is not None @@ -407,3 +583,104 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return self._cache_id_gen.get_current_token_for_writer(instance_name) else: return 0 + + @wrap_as_background_process("clean_up_old_cache_invalidations") + async def _clean_up_cache_invalidation_wrapper(self) -> None: + """ + Clean up cache invalidation stream table entries occasionally. + If we are behind (i.e. there are entries old enough to + be deleted but too many of them to be deleted in one go), + then we run slightly more frequently. + """ + delete_up_to: int = ( + self.hs.get_clock().time_msec() - RETENTION_PERIOD_OF_CACHE_INVALIDATIONS_MS + ) + + in_backlog = await self._clean_up_batch_of_old_cache_invalidations(delete_up_to) + + # Vary how long we wait before calling again depending on whether we + # are still sifting through backlog or we have caught up. + if in_backlog: + next_interval = CATCH_UP_CLEANUP_INTERVAL_MS + else: + next_interval = REGULAR_CLEANUP_INTERVAL_MS + + self.hs.get_clock().call_later( + next_interval / 1000, self._clean_up_cache_invalidation_wrapper + ) + + async def _clean_up_batch_of_old_cache_invalidations( + self, delete_up_to_millisec: int + ) -> bool: + """ + Remove old rows from the `cache_invalidation_stream_by_instance` table automatically (this table is unused in SQLite). + + Up to `CLEAN_UP_BATCH_SIZE` rows will be deleted at once. + + Returns true if and only if we were limited by batch size (i.e. we are in backlog: + there are more things to clean up). + """ + + def _clean_up_batch_of_old_cache_invalidations_txn( + txn: LoggingTransaction, + ) -> bool: + # First get the earliest stream ID + txn.execute( + """ + SELECT stream_id FROM cache_invalidation_stream_by_instance + ORDER BY stream_id ASC + LIMIT 1 + """ + ) + row = txn.fetchone() + if row is None: + return False + earliest_stream_id: int = row[0] + + # Then find the last stream ID of the range we will delete + txn.execute( + """ + SELECT stream_id FROM cache_invalidation_stream_by_instance + WHERE stream_id <= ? AND invalidation_ts <= ? + ORDER BY stream_id DESC + LIMIT 1 + """, + (earliest_stream_id + CLEAN_UP_MAX_BATCH_SIZE, delete_up_to_millisec), + ) + row = txn.fetchone() + if row is None: + return False + cutoff_stream_id: int = row[0] + + # Determine whether we are caught up or still catching up + txn.execute( + """ + SELECT invalidation_ts FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? + ORDER BY stream_id ASC + LIMIT 1 + """, + (cutoff_stream_id,), + ) + row = txn.fetchone() + if row is None: + in_backlog = False + else: + # We are in backlog if the next row could have been deleted + # if we didn't have such a small batch size + in_backlog = row[0] <= delete_up_to_millisec + + txn.execute( + """ + DELETE FROM cache_invalidation_stream_by_instance + WHERE ? <= stream_id AND stream_id <= ? + """, + (earliest_stream_id, cutoff_stream_id), + ) + + return in_backlog + + return await self.db_pool.runInteraction( + "clean_up_old_cache_invalidations", + _clean_up_batch_of_old_cache_invalidations_txn, + ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 0df160d2b0..7da47c3dd7 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -579,6 +579,11 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke device_id: Optional[str], now: Optional[int] = None, ) -> None: + # The sync proxy continuously triggers /sync even if the user is not + # present so should be excluded from user_ips entries. + if user_agent == "sync-v3-proxy-": + return + if not now: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) @@ -759,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke } return list(results.values()) + + async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]: + """Get the last seen timestamp for a user, if we have it.""" + + return await self.db_pool.simple_select_one_onecol( + table="user_ips", + keyvalues={"user_id": user_id}, + retcol="MAX(last_seen)", + allow_none=True, + desc="get_last_seen_for_user_id", + ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index b471fcb064..744e98c6d0 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -349,7 +349,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): table="devices", column="user_id", iterable=user_ids_to_query, - keyvalues={"user_id": user_id, "hidden": False}, + keyvalues={"hidden": False}, retcols=("device_id",), ) @@ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore): @trace async def delete_messages_for_device( - self, user_id: str, device_id: Optional[str], up_to_stream_id: int + self, + user_id: str, + device_id: Optional[str], + up_to_stream_id: int, + limit: int, ) -> int: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. up_to_stream_id: Where to delete messages up to. + limit: maximum number of messages to delete Returns: The number of messages deleted. @@ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": "No changes in cache since last check"}) return 0 + ROW_ID_NAME = self.database_engine.row_id_name + def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: - sql = ( - "DELETE FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND stream_id <= ?" - ) + sql = f""" + DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN ( + SELECT {ROW_ID_NAME} FROM device_inbox + WHERE user_id = ? AND device_id = ? AND stream_id <= ? + LIMIT {limit} + ) + """ txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount @@ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): log_kv({"message": f"deleted {count} messages for device", "count": count}) + # In this case we don't know if we hit the limit or the delete is complete + # so let's not update the cache. + if count == limit: + return count + # Update the cache, ensuring that we only ever increase the value updated_last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0 diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a67fdb3c22..9f3804a504 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ from typing import ( cast, ) +from canonicaljson import encode_canonical_json from typing_extensions import Literal from synapse.api.constants import EduTypes @@ -54,7 +55,12 @@ from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key +from synapse.types import ( + JsonDict, + JsonMapping, + StrCollection, + get_verify_key_from_cross_signing_key, +) from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -745,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cancellable async def get_user_devices_from_cache( self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] - ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]: """Get the devices (and keys if any) for remote users from the cache. Args: @@ -758,28 +764,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): mapping of user_id -> device_id -> device_info. """ unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids} - user_map = await self.get_device_list_last_stream_id_for_remotes( - list(unique_user_ids) - ) - # We go and check if any of the users need to have their device lists - # resynced. If they do then we remove them from the cached list. - users_needing_resync = await self.get_user_ids_requiring_device_list_resync( + user_ids_in_cache = await self.get_users_whose_devices_are_cached( unique_user_ids ) - user_ids_in_cache = { - user_id for user_id, stream_id in user_map.items() if stream_id - } - users_needing_resync user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. - results: Dict[str, Mapping[str, JsonDict]] = {} + results: Dict[str, Mapping[str, JsonMapping]] = {} for user_id in user_ids: if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) # Then fetch all device-specific requests, but skip users we've already # fetched all devices for. - device_specific_results: Dict[str, Dict[str, JsonDict]] = {} + device_specific_results: Dict[str, Dict[str, JsonMapping]] = {} for user_id, device_id in user_and_device_ids: if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) @@ -791,8 +789,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return user_ids_not_in_cache, results + async def get_users_whose_devices_are_cached( + self, user_ids: StrCollection + ) -> Set[str]: + """Checks which of the given users we have cached the devices for.""" + user_map = await self.get_device_list_last_stream_id_for_remotes(user_ids) + + # We go and check if any of the users need to have their device lists + # resynced. If they do then we remove them from the cached list. + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( + user_ids + ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync + return user_ids_in_cache + @cached(num_args=2, tree=True) - async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + async def _get_cached_user_device( + self, user_id: str, device_id: str + ) -> JsonMapping: content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -802,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: + async def get_cached_devices_for_user( + self, user_id: str + ) -> Mapping[str, JsonMapping]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, @@ -1033,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) async def get_device_list_last_stream_id_for_remotes( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[str]]: + ) -> Mapping[str, Optional[str]]: rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -1188,8 +1206,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) def _store_dehydrated_device_txn( - self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + device_data: str, + time: int, + keys: Optional[JsonDict] = None, ) -> Optional[str]: + # TODO: make keys non-optional once support for msc2697 is dropped + if keys: + device_keys = keys.get("device_keys", None) + if device_keys: + # Type ignore - this function is defined on EndToEndKeyStore which we do + # have access to due to hs.get_datastore() "magic" + self._set_e2e_device_keys_txn( # type: ignore[attr-defined] + txn, user_id, device_id, time, device_keys + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append( + ( + algorithm, + key_id, + encode_canonical_json(key_obj).decode("ascii"), + ) + ) + self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list) + + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys: + self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys) + old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, table="dehydrated_devices", @@ -1203,10 +1255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): keyvalues={"user_id": user_id}, values={"device_id": device_id, "device_data": device_data}, ) + return old_device_id async def store_dehydrated_device( - self, user_id: str, device_id: str, device_data: JsonDict + self, + user_id: str, + device_id: str, + device_data: JsonDict, + time_now: int, + keys: Optional[dict] = None, ) -> Optional[str]: """Store a dehydrated device for a user. @@ -1214,15 +1272,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_id: the user that we are storing the device for device_id: the ID of the dehydrated device device_data: the dehydrated device information + time_now: current time at the request in milliseconds + keys: keys for the dehydrated device + Returns: device id of the user's previous dehydrated device, if any """ + return await self.db_pool.runInteraction( "store_dehydrated_device_txn", self._store_dehydrated_device_txn, user_id, device_id, json_encoder.encode(device_data), + time_now, + keys, ) async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: @@ -1349,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_devices_not_accessed_since_txn( txn: LoggingTransaction, - ) -> List[Dict[str, str]]: + ) -> List[Tuple[str, str]]: sql = """ SELECT user_id, device_id FROM devices WHERE last_seen < ? AND hidden = FALSE """ txn.execute(sql, (since_ms,)) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, str]], txn.fetchall()) rows = await self.db_pool.runInteraction( "get_devices_not_accessed_since", @@ -1363,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) devices: Dict[str, List[str]] = {} - for row in rows: + for user_id, device_id in rows: # Remote devices are never stale from our point of view. - if self.hs.is_mine_id(row["user_id"]): - user_devices = devices.setdefault(row["user_id"], []) - user_devices.append(row["device_id"]) + if self.hs.is_mine_id(user_id): + user_devices = devices.setdefault(user_id, []) + user_devices.append(device_id) return devices @@ -1719,14 +1783,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_delete_many_txn( txn, - table="device_inbox", - column="device_id", - values=device_ids, - keyvalues={"user_id": user_id}, - ) - - self.db_pool.simple_delete_many_txn( - txn, table="device_auth_providers", column="device_id", values=device_ids, @@ -1941,17 +1997,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id, stream_ids[-1], ) + txn.call_after( + self._get_e2e_device_keys_for_federation_query_inner.invalidate, + (user_id,), + ) min_stream_id = stream_ids[0] # Delete older entries in the table, as we really only care about # when the latest change happened. - txn.execute_batch( - """ + cleanup_obsolete_stmt = """ DELETE FROM device_lists_stream - WHERE user_id = ? AND device_id = ? AND stream_id < ? - """, - [(user_id, device_id, min_stream_id) for device_id in device_ids], + WHERE user_id = ? AND stream_id < ? AND %s + """ + device_ids_clause, device_ids_args = make_in_list_sql_clause( + txn.database_engine, "device_id", device_ids + ) + txn.execute( + cleanup_obsolete_stmt % (device_ids_clause,), + [user_id, min_stream_id] + device_ids_args, ) self.db_pool.simple_insert_many_txn( diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index d01f28cc80..bc7c6a6346 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): "message": "Set room key", "room_id": room_id, "session_id": session_id, - StreamKeyType.ROOM: room_key, + StreamKeyType.ROOM.value: room_key, } ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 4bc391f213..749ae54e20 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -16,6 +16,7 @@ import abc from typing import ( TYPE_CHECKING, + Any, Collection, Dict, Iterable, @@ -39,6 +40,7 @@ from synapse.appservice import ( TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.replication.tcp.streams._base import DeviceListsStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( DatabasePool, @@ -50,7 +52,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable @@ -104,9 +106,26 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker self.hs.config.federation.allow_device_name_lookup_over_federation ) + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == DeviceListsStream.NAME: + for row in rows: + assert isinstance(row, DeviceListsStream.DeviceListsStreamRow) + if row.entity.startswith("@"): + self._get_e2e_device_keys_for_federation_query_inner.invalidate( + (row.entity,) + ) + + super().process_replication_rows(stream_name, instance_name, token, rows) + async def get_e2e_device_keys_for_federation_query( self, user_id: str - ) -> Tuple[int, List[JsonDict]]: + ) -> Tuple[int, Sequence[JsonMapping]]: """Get all devices (with any device keys) for a user Returns: @@ -114,6 +133,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ now_stream_id = self.get_device_stream_token() + # We need to be careful with the caching here, as we need to always + # return *all* persisted devices, however there may be a lag between a + # new device being persisted and the cache being invalidated. + cached_results = ( + self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate( + user_id, None + ) + ) + if cached_results is not None: + # Check that there have been no new devices added by another worker + # after the cache. This should be quick as there should be few rows + # with a higher stream ordering. + # + # Note that we invalidate based on the device stream, so we only + # have to check for potential invalidations after the + # `now_stream_id`. + sql = """ + SELECT user_id FROM device_lists_stream + WHERE stream_id >= ? AND user_id = ? + """ + rows = await self.db_pool.execute( + "get_e2e_device_keys_for_federation_query_check", + None, + sql, + now_stream_id, + user_id, + ) + if not rows: + # No new rows, so cache is still valid. + return now_stream_id, cached_results + + # There has, so let's invalidate the cache and run the query. + self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,)) + + results = await self._get_e2e_device_keys_for_federation_query_inner(user_id) + + return now_stream_id, results + + @cached(iterable=True) + async def _get_e2e_device_keys_for_federation_query_inner( + self, user_id: str + ) -> Sequence[JsonMapping]: + """Get all devices (with any device keys) for a user""" + devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) if devices: @@ -134,9 +197,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker results.append(result) - return now_stream_id, results + return results - return now_stream_id, [] + return [] @trace @cancellable @@ -459,42 +522,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", str(new_keys)) - # We are protected from race between lookup and insertion due to - # a unique constraint. If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. - self.db_pool.simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - keys=( - "user_id", - "device_id", - "algorithm", - "key_id", - "ts_added_ms", - "key_json", - ), - values=[ - (user_id, device_id, algorithm, key_id, time_now, json_bytes) - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - await self.db_pool.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + "add_e2e_one_time_keys_insert", + self._add_e2e_one_time_keys_txn, + user_id, + device_id, + time_now, + new_keys, + ) + + def _add_e2e_one_time_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: + """Insert some new one time keys for a device. Errors if any of the keys already exist. + + Args: + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note + that the key JSON must be in canonical JSON form + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", str(new_keys)) + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self.db_pool.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + keys=( + "user_id", + "device_id", + "algorithm", + "key_id", + "ts_added_ms", + "key_json", + ), + values=[ + (user_id, device_id, algorithm, key_id, time_now, json_bytes) + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) @cached(max_entries=10000) async def count_e2e_one_time_keys( self, user_id: str, device_id: str - ) -> Dict[str, int]: + ) -> Mapping[str, int]: """Count the number of one time keys the server has for a device Returns: A mapping from algorithm to number of keys for that algorithm. @@ -660,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker device_id: str, fallback_keys: JsonDict, ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded @@ -720,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Returns a user's cross-signing key. Args: @@ -741,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys( + self, user_id: str + ) -> Mapping[str, JsonMapping]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -754,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -768,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker their user ID will map to None. """ - result = await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, ) - # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) - def _get_bare_e2e_cross_signing_keys_bulk_txn( self, txn: LoggingTransaction, @@ -830,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker } txn.execute(sql, params) - rows = self.db_pool.cursor_to_dict(txn) - for row in rows: - user_id = row["user_id"] - key_type = row["keytype"] - key = db_to_json(row["keydata"]) + for user_id, key_type, key_data, _ in txn: user_keys = result.setdefault(user_id, {}) - user_keys[key_type] = key + user_keys[key_type] = db_to_json(key_data) return result @@ -897,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker query_params.extend(item) txn.execute(sql, query_params) - rows = self.db_pool.cursor_to_dict(txn) # and add the signatures to the appropriate keys - for row in rows: - key_id: str = row["key_id"] - target_user_id: str = row["target_user_id"] - target_device_id: str = row["target_device_id"] + for target_user_id, target_device_id, key_id, signature in txn: key_type = devices[(target_user_id, target_device_id)] # We need to copy everything, because the result may have come # from the cache. dict.copy only does a shallow copy, so we @@ -921,20 +1004,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ].copy() if from_user_id in signatures: user_sigs = signatures[from_user_id] = signatures[from_user_id] - user_sigs[key_id] = row["signature"] + user_sigs[key_id] = signature else: - signatures[from_user_id] = {key_id: row["signature"]} + signatures[from_user_id] = {key_id: signature} else: - target_user_key["signatures"] = { - from_user_id: {key_id: row["signature"]} - } + target_user_key["signatures"] = {from_user_id: {key_id: signature}} return keys @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: """Returns the cross-signing keys for a set of users. Args: @@ -951,7 +1032,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if from_user_id: result = cast( - Dict[str, Optional[Mapping[str, JsonDict]]], + Dict[str, Optional[Mapping[str, JsonMapping]]], await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, @@ -1241,42 +1322,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. + + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store """ - def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", str(device_keys)) + return await self.db_pool.runInteraction( + "set_e2e_device_keys", + self._set_e2e_device_keys_txn, + user_id, + device_id, + time_now, + device_keys, + ) - old_key_json = self.db_pool.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) + def _set_e2e_device_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + device_keys: JsonDict, + ) -> bool: + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", str(device_keys)) - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False - return await self.db_pool.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, ) + log_kv({"message": "Device keys stored."}) + return True async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ac19de183c..afffa54985 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -19,6 +19,7 @@ from typing import ( TYPE_CHECKING, Collection, Dict, + FrozenSet, Iterable, List, Optional, @@ -31,13 +32,14 @@ from typing import ( import attr from prometheus_client import Counter, Gauge -from synapse.api.constants import MAX_DEPTH, EventTypes +from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.logging.opentracing import tag_args, trace from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.background_updates import ForeignKeyConstraint from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -46,7 +48,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -140,6 +142,17 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) + if isinstance(self.database_engine, PostgresEngine): + self.db_pool.updates.register_background_validate_constraint_and_delete_rows( + update_name="event_forward_extremities_event_id_foreign_key_constraint_update", + table="event_forward_extremities", + constraint_name="event_forward_extremities_event_id", + constraint=ForeignKeyConstraint( + "events", [("event_id", "event_id")], deferred=True + ), + unique_columns=("event_id", "room_id"), + ) + async def get_auth_chain( self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: @@ -440,33 +453,56 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # sets. seen_chains: Set[int] = set() - sql = """ - SELECT event_id, chain_id, sequence_number - FROM event_auth_chains - WHERE %s - """ - for batch in batch_iter(initial_events, 1000): - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", batch - ) - txn.execute(sql % (clause,), args) + # Fetch the chain cover index for the initial set of events we're + # considering. + def fetch_chain_info(events_to_fetch: Collection[str]) -> None: + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(events_to_fetch, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) - for event_id, chain_id, sequence_number in txn: - chain_info[event_id] = (chain_id, sequence_number) - seen_chains.add(chain_id) - chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id + for event_id, chain_id, sequence_number in txn: + chain_info[event_id] = (chain_id, sequence_number) + seen_chains.add(chain_id) + chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id + + fetch_chain_info(initial_events) # Check that we actually have a chain ID for all the events. events_missing_chain_info = initial_events.difference(chain_info) + + # The result set to return, i.e. the auth chain difference. + result: Set[str] = set() + if events_missing_chain_info: - # This can happen due to e.g. downgrade/upgrade of the server. We - # raise an exception and fall back to the previous algorithm. - logger.info( - "Unexpectedly found that events don't have chain IDs in room %s: %s", + # For some reason we have events we haven't calculated the chain + # index for, so we need to handle those separately. This should only + # happen for older rooms where the server doesn't have all the auth + # events. + result = self._fixup_auth_chain_difference_sets( + txn, room_id, - events_missing_chain_info, + state_sets=state_sets, + events_missing_chain_info=events_missing_chain_info, + events_that_have_chain_index=chain_info, ) - raise _NoChainCoverIndex(room_id) + + # We now need to refetch any events that we have added to the state + # sets. + new_events_to_fetch = { + event_id + for state_set in state_sets + for event_id in state_set + if event_id not in initial_events + } + + fetch_chain_info(new_events_to_fetch) # Corresponds to `state_sets`, except as a map from chain ID to max # sequence number reachable from the state set. @@ -475,8 +511,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas chains: Dict[int, int] = {} set_to_chain.append(chains) - for event_id in state_set: - chain_id, seq_no = chain_info[event_id] + for state_id in state_set: + chain_id, seq_no = chain_info[state_id] chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) @@ -520,7 +556,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # from *any* state set and the minimum sequence number reachable from # *all* state sets. Events in that range are in the auth chain # difference. - result = set() # Mapping from chain ID to the range of sequence numbers that should be # pulled from the database. @@ -576,6 +611,122 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return result + def _fixup_auth_chain_difference_sets( + self, + txn: LoggingTransaction, + room_id: str, + state_sets: List[Set[str]], + events_missing_chain_info: Set[str], + events_that_have_chain_index: Collection[str], + ) -> Set[str]: + """Helper for `_get_auth_chain_difference_using_cover_index_txn` to + handle the case where we haven't calculated the chain cover index for + all events. + + This modifies `state_sets` so that they only include events that have a + chain cover index, and returns a set of event IDs that are part of the + auth difference. + """ + + # This works similarly to the handling of unpersisted events in + # `synapse.state.v2_get_auth_chain_difference`. We uses the observation + # that if you can split the set of events into two classes X and Y, + # where no events in Y have events in X in their auth chain, then we can + # calculate the auth difference by considering X and Y separately. + # + # We do this in three steps: + # 1. Compute the set of events without chain cover index belonging to + # the auth difference. + # 2. Replacing the un-indexed events in the state_sets with their auth + # events, recursively, until the state_sets contain only indexed + # events. We can then calculate the auth difference of those state + # sets using the chain cover index. + # 3. Add the results of 1 and 2 together. + + # By construction we know that all events that we haven't persisted the + # chain cover index for are contained in + # `event_auth_chain_to_calculate`, so we pull out the events from those + # rather than doing recursive queries to walk the auth chain. + # + # We pull out those events with their auth events, which gives us enough + # information to construct the auth chain of an event up to auth events + # that have the chain cover index. + sql = """ + SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL + FROM event_auth_chain_to_calculate AS tc + LEFT JOIN event_auth AS ea USING (event_id) + LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id) + WHERE tc.room_id = ? + """ + txn.execute(sql, (room_id,)) + event_to_auth_ids: Dict[str, Set[str]] = {} + events_that_have_chain_index = set(events_that_have_chain_index) + for event_id, auth_id, auth_id_has_chain in txn: + s = event_to_auth_ids.setdefault(event_id, set()) + if auth_id is not None: + s.add(auth_id) + if auth_id_has_chain: + events_that_have_chain_index.add(auth_id) + + if events_missing_chain_info - event_to_auth_ids.keys(): + # Uh oh, we somehow haven't correctly done the chain cover index, + # bail and fall back to the old method. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info - event_to_auth_ids.keys(), + ) + raise _NoChainCoverIndex(room_id) + + # Create a map from event IDs we care about to their partial auth chain. + event_id_to_partial_auth_chain: Dict[str, Set[str]] = {} + for event_id, auth_ids in event_to_auth_ids.items(): + if not any(event_id in state_set for state_set in state_sets): + continue + + processing = set(auth_ids) + to_add = set() + while processing: + auth_id = processing.pop() + to_add.add(auth_id) + + sub_auth_ids = event_to_auth_ids.get(auth_id) + if sub_auth_ids is None: + continue + + processing.update(sub_auth_ids - to_add) + + event_id_to_partial_auth_chain[event_id] = to_add + + # Now we do two things: + # 1. Update the state sets to only include indexed events; and + # 2. Create a new list containing the auth chains of the un-indexed + # events + unindexed_state_sets: List[Set[str]] = [] + for state_set in state_sets: + unindexed_state_set = set() + for event_id, auth_chain in event_id_to_partial_auth_chain.items(): + if event_id not in state_set: + continue + + unindexed_state_set.add(event_id) + + state_set.discard(event_id) + state_set.difference_update(auth_chain) + for auth_id in auth_chain: + if auth_id in events_that_have_chain_index: + state_set.add(auth_id) + else: + unindexed_state_set.add(auth_id) + + unindexed_state_sets.append(unindexed_state_set) + + # Calculate and return the auth difference of the un-indexed events. + union = unindexed_state_sets[0].union(*unindexed_state_sets[1:]) + intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:]) + + return union - intersection + def _get_auth_chain_difference_txn( self, txn: LoggingTransaction, state_sets: List[Set[str]] ) -> Set[str]: @@ -831,7 +982,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas * because the schema change is in a background update, it's not * necessarily safe to assume that it will have been completed. */ - AND edge.is_state is ? /* False */ + AND edge.is_state is FALSE /** * We only want backwards extremities that are older than or at * the same position of the given `current_depth` (where older @@ -874,7 +1025,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas sql, ( room_id, - False, current_depth, self._clock.time_msec(), BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, @@ -891,124 +1041,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - @trace - async def get_insertion_event_backward_extremities_in_room( - self, - room_id: str, - current_depth: int, - limit: int, - ) -> List[Tuple[str, int]]: - """ - Get the insertion events we know about that we haven't backfilled yet - along with the approximate depth. Only returns insertion events that are - at a depth lower than or equal to the `current_depth`. Sorted by depth, - highest to lowest (descending) so the closest events to the - `current_depth` are first in the list. - - We ignore insertion events that are newer than the user's current scroll - position (ie, those with depth greater than `current_depth`) as: - 1. we don't really care about getting events that have happened - after our current position; and - 2. by the nature of paginating and scrolling back, we have likely - previously tried and failed to backfill from that insertion event, so - to avoid getting "stuck" requesting the same backfill repeatedly - we drop those insertion event. - - Args: - room_id: Room where we want to find the oldest events - current_depth: The depth at the user's current scrollback position - limit: The max number of insertion event extremities to return - - Returns: - List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) so the closest events to the `current_depth` are first - in the list. - """ - - def get_insertion_event_backward_extremities_in_room_txn( - txn: LoggingTransaction, room_id: str - ) -> List[Tuple[str, int]]: - if isinstance(self.database_engine, PostgresEngine): - least_function = "LEAST" - elif isinstance(self.database_engine, Sqlite3Engine): - least_function = "MIN" - else: - raise RuntimeError("Unknown database engine") - - sql = f""" - SELECT - insertion_event_extremity.event_id, event.depth - /* We only want insertion events that are also marked as backwards extremities */ - FROM insertion_event_extremities AS insertion_event_extremity - /* Get the depth of the insertion event from the events table */ - INNER JOIN events AS event USING (event_id) - /** - * We use this info to make sure we don't retry to use a backfill point - * if we've already attempted to backfill from it recently. - */ - LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info - ON - failed_backfill_attempt_info.room_id = insertion_event_extremity.room_id - AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id - WHERE - insertion_event_extremity.room_id = ? - /** - * We only want extremities that are older than or at - * the same position of the given `current_depth` (where older - * means less than the given depth) because we're looking backwards - * from the `current_depth` when backfilling. - * - * current_depth (ignore events that come after this, ignore 2-4) - * | - * â–¼ - * <oldest-in-time> [0]<--[1]<--[2]<--[3]<--[4] <newest-in-time> - */ - AND event.depth <= ? /* current_depth */ - /** - * Exponential back-off (up to the upper bound) so we don't retry the - * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc - * - * We use `1 << n` as a power of 2 equivalent for compatibility - * with older SQLites. The left shift equivalent only works with - * powers of 2 because left shift is a binary operation (base-2). - * Otherwise, we would use `power(2, n)` or the power operator, `2^n`. - */ - AND ( - failed_backfill_attempt_info.event_id IS NULL - OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + ( - (1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */)) - * ? /* step */ - ) - ) - /** - * Sort from highest (closest to the `current_depth`) to the lowest depth - * because the closest are most relevant to backfill from first. - * Then tie-break on alphabetical order of the event_ids so we get a - * consistent ordering which is nice when asserting things in tests. - */ - ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC - LIMIT ? - """ - - txn.execute( - sql, - ( - room_id, - current_depth, - self._clock.time_msec(), - BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, - BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, - limit, - ), - ) - return cast(List[Tuple[str, int]], txn.fetchall()) - - return await self.db_pool.runInteraction( - "get_insertion_event_backward_extremities_in_room", - get_insertion_event_backward_extremities_in_room_txn, - room_id, - ) - async def get_max_depth_of( self, event_ids: Collection[str] ) -> Tuple[Optional[str], int]: @@ -1148,13 +1180,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: - return await self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]: + event_ids = await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", desc="get_latest_event_ids_in_room", ) + return frozenset(event_ids) async def get_min_depth(self, room_id: str) -> Optional[int]: """For the given room, get the minimum depth we have seen for it.""" @@ -1280,50 +1313,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return event_ids - def _get_connected_batch_event_backfill_results_txn( - self, txn: LoggingTransaction, insertion_event_id: str, limit: int - ) -> List[BackfillQueueNavigationItem]: - """ - Find any batch connections of a given insertion event. - A batch event points at a insertion event via: - batch_event.content[MSC2716_BATCH_ID] -> insertion_event.content[MSC2716_NEXT_BATCH_ID] - - Args: - txn: The database transaction to use - insertion_event_id: The event ID to navigate from. We will find - batch events that point back at this insertion event. - limit: Max number of event ID's to query for and return - - Returns: - List of batch events that the backfill queue can process - """ - batch_connection_query = """ - SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i - /* Find the batch that connects to the given insertion event */ - INNER JOIN batch_events AS c - ON i.next_batch_id = c.batch_id - /* Get the depth of the batch start event from the events table */ - INNER JOIN events AS e ON c.event_id = e.event_id - /* Find an insertion event which matches the given event_id */ - WHERE i.event_id = ? - LIMIT ? - """ - - # Find any batch connections for the given insertion event - txn.execute( - batch_connection_query, - (insertion_event_id, limit), - ) - return [ - BackfillQueueNavigationItem( - depth=row[0], - stream_ordering=row[1], - event_id=row[2], - type=row[3], - ) - for row in txn - ] - def _get_connected_prev_event_backfill_results_txn( self, txn: LoggingTransaction, event_id: str, limit: int ) -> List[BackfillQueueNavigationItem]: @@ -1472,40 +1461,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id_results.add(event_id) - # Try and find any potential historical batches of message history. - if self.hs.config.experimental.msc2716_enabled: - # We need to go and try to find any batch events connected - # to a given insertion event (by batch_id). If we find any, we'll - # add them to the queue and navigate up the DAG like normal in the - # next iteration of the loop. - if event_type == EventTypes.MSC2716_INSERTION: - # Find any batch connections for the given insertion event - connected_batch_event_backfill_results = ( - self._get_connected_batch_event_backfill_results_txn( - txn, event_id, limit - len(event_id_results) - ) - ) - logger.debug( - "_get_backfill_events(room_id=%s): connected_batch_event_backfill_results=%s", - room_id, - connected_batch_event_backfill_results, - ) - for ( - connected_batch_event_backfill_item - ) in connected_batch_event_backfill_results: - if ( - connected_batch_event_backfill_item.event_id - not in event_id_results - ): - queue.put( - ( - -connected_batch_event_backfill_item.depth, - -connected_batch_event_backfill_item.stream_ordering, - connected_batch_event_backfill_item.event_id, - connected_batch_event_backfill_item.type, - ) - ) - # Now we just look up the DAG by prev_events as normal connected_prev_event_backfill_results = ( self._get_connected_prev_event_backfill_results_txn( @@ -1584,6 +1539,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) @trace + async def get_event_ids_with_failed_pull_attempts( + self, event_ids: StrCollection + ) -> Set[str]: + """ + Filter the given list of `event_ids` and return events which have any failed + pull attempts. + + Args: + event_ids: A list of events to filter down. + + Returns: + A filtered down list of `event_ids` that have previous failed pull attempts. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id",), + desc="get_event_ids_with_failed_pull_attempts", + ) + event_ids_with_failed_pull_attempts: Set[str] = { + row["event_id"] for row in rows + } + + return event_ids_with_failed_pull_attempts + + @trace async def get_event_ids_to_not_pull_from_backoff( self, room_id: str, @@ -1719,19 +1703,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) - @trace - async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: - await self.db_pool.simple_upsert( - table="insertion_event_extremities", - keyvalues={"event_id": event_id}, - values={ - "event_id": event_id, - "room_id": room_id, - }, - insertion_values={}, - desc="insert_insertion_extremity", - ) - async def insert_received_event_to_staging( self, origin: str, event: EventBase ) -> None: diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 6fdb1e292e..39556481ff 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -182,6 +182,7 @@ class UserPushAction(EmailPushAction): profile_tag: str +# TODO This is used as a cached value and is mutable. @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ @@ -193,7 +194,7 @@ class NotifCounts: highlight_count: int = 0 -@attr.s(slots=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class RoomNotifCounts: """ The per-user, per-room count of notifications. Used by sync and push. @@ -201,7 +202,7 @@ class RoomNotifCounts: main_timeline: NotifCounts # Map of thread ID to the notification counts. - threads: Dict[str, NotifCounts] + threads: Mapping[str, NotifCounts] @staticmethod def empty() -> "RoomNotifCounts": @@ -289,179 +290,52 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas unique=True, ) - self.db_pool.updates.register_background_update_handler( - "event_push_backfill_thread_id", - self._background_backfill_thread_id, + self.db_pool.updates.register_background_validate_constraint( + "event_push_actions_staging_thread_id", + constraint_name="event_push_actions_staging_thread_id", + table="event_push_actions_staging", ) - - # Indexes which will be used to quickly make the thread_id column non-null. - self.db_pool.updates.register_background_index_update( - "event_push_actions_thread_id_null", - index_name="event_push_actions_thread_id_null", + self.db_pool.updates.register_background_validate_constraint( + "event_push_actions_thread_id", + constraint_name="event_push_actions_thread_id", table="event_push_actions", - columns=["thread_id"], - where_clause="thread_id IS NULL", ) - self.db_pool.updates.register_background_index_update( - "event_push_summary_thread_id_null", - index_name="event_push_summary_thread_id_null", + self.db_pool.updates.register_background_validate_constraint( + "event_push_summary_thread_id", + constraint_name="event_push_summary_thread_id", table="event_push_summary", - columns=["thread_id"], - where_clause="thread_id IS NULL", ) - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates the event_push_actions and event_push_summary tables. - self._clock.call_later(0.0, self._check_event_push_backfill_thread_id) - self._event_push_backfill_thread_id_done = False - - @wrap_as_background_process("check_event_push_backfill_thread_id") - async def _check_event_push_backfill_thread_id(self) -> None: - """ - Has thread_id finished backfilling? - - If not, we need to just-in-time update it so the queries work. - """ - done = await self.db_pool.updates.has_completed_background_update( - "event_push_backfill_thread_id" + self.db_pool.updates.register_background_update_handler( + "event_push_drop_null_thread_id_indexes", + self._background_drop_null_thread_id_indexes, ) - if done: - self._event_push_backfill_thread_id_done = True - else: - # Reschedule to run. - self._clock.call_later(15.0, self._check_event_push_backfill_thread_id) - - async def _background_backfill_thread_id( + async def _background_drop_null_thread_id_indexes( self, progress: JsonDict, batch_size: int ) -> int: """ - Fill in the thread_id field for event_push_actions and event_push_summary. - - This is preparatory so that it can be made non-nullable in the future. - - Because all current (null) data is done in an unthreaded manner this - simply assumes it is on the "main" timeline. Since event_push_actions - are periodically cleared it is not possible to correctly re-calculate - the thread_id. + Drop the indexes used to find null thread_ids for event_push_actions and + event_push_summary. """ - event_push_actions_done = progress.get("event_push_actions_done", False) - - def add_thread_id_txn( - txn: LoggingTransaction, start_stream_ordering: int - ) -> int: - sql = """ - SELECT stream_ordering - FROM event_push_actions - WHERE - thread_id IS NULL - AND stream_ordering > ? - ORDER BY stream_ordering - LIMIT ? - """ - txn.execute(sql, (start_stream_ordering, batch_size)) - - # No more rows to process. - rows = txn.fetchall() - if not rows: - progress["event_push_actions_done"] = True - self.db_pool.updates._background_update_progress_txn( - txn, "event_push_backfill_thread_id", progress - ) - return 0 - - # Update the thread ID for any of those rows. - max_stream_ordering = rows[-1][0] - - sql = """ - UPDATE event_push_actions - SET thread_id = 'main' - WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL - """ - txn.execute( - sql, - ( - start_stream_ordering, - max_stream_ordering, - ), - ) - - # Update progress. - processed_rows = txn.rowcount - progress["max_event_push_actions_stream_ordering"] = max_stream_ordering - self.db_pool.updates._background_update_progress_txn( - txn, "event_push_backfill_thread_id", progress - ) - - return processed_rows - - def add_thread_id_summary_txn(txn: LoggingTransaction) -> int: - min_user_id = progress.get("max_summary_user_id", "") - min_room_id = progress.get("max_summary_room_id", "") - - # Slightly overcomplicated query for getting the Nth user ID / room - # ID tuple, or the last if there are less than N remaining. - sql = """ - SELECT user_id, room_id FROM ( - SELECT user_id, room_id FROM event_push_summary - WHERE (user_id, room_id) > (?, ?) - AND thread_id IS NULL - ORDER BY user_id, room_id - LIMIT ? - ) AS e - ORDER BY user_id DESC, room_id DESC - LIMIT 1 - """ - txn.execute(sql, (min_user_id, min_room_id, batch_size)) - row = txn.fetchone() - if not row: - return 0 - - max_user_id, max_room_id = row - - sql = """ - UPDATE event_push_summary - SET thread_id = 'main' - WHERE - (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?) - AND thread_id IS NULL - """ - txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id)) - processed_rows = txn.rowcount + def drop_null_thread_id_indexes_txn(txn: LoggingTransaction) -> None: + sql = "DROP INDEX IF EXISTS event_push_actions_thread_id_null" + logger.debug("[SQL] %s", sql) + txn.execute(sql) - progress["max_summary_user_id"] = max_user_id - progress["max_summary_room_id"] = max_room_id - self.db_pool.updates._background_update_progress_txn( - txn, "event_push_backfill_thread_id", progress - ) - - return processed_rows + sql = "DROP INDEX IF EXISTS event_push_summary_thread_id_null" + logger.debug("[SQL] %s", sql) + txn.execute(sql) - # First update the event_push_actions table, then the event_push_summary table. - # - # Note that the event_push_actions_staging table is ignored since it is - # assumed that items in that table will only exist for a short period of - # time. - if not event_push_actions_done: - result = await self.db_pool.runInteraction( - "event_push_backfill_thread_id", - add_thread_id_txn, - progress.get("max_event_push_actions_stream_ordering", 0), - ) - else: - result = await self.db_pool.runInteraction( - "event_push_backfill_thread_id", - add_thread_id_summary_txn, - ) - - # Only done after the event_push_summary table is done. - if not result: - await self.db_pool.updates._end_background_update( - "event_push_backfill_thread_id" - ) - - return result + await self.db_pool.runInteraction( + "drop_null_thread_id_indexes_txn", + drop_null_thread_id_indexes_txn, + ) + await self.db_pool.updates._end_background_update( + "event_push_drop_null_thread_id_indexes" + ) + return 0 async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]: """Get the notification count by room for a user. Only considers notifications, @@ -610,7 +484,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return room_to_count - @cached(tree=True, max_entries=5000, iterable=True) + @cached(tree=True, max_entries=5000, iterable=True) # type: ignore[synapse-@cached-mutable] async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, @@ -711,25 +585,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) - # First ensure that the existing rows have an updated thread_id field. - if not self._event_push_backfill_thread_id_done: - txn.execute( - """ - UPDATE event_push_summary - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - (MAIN_TIMELINE, room_id, user_id), - ) - txn.execute( - """ - UPDATE event_push_actions - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - (MAIN_TIMELINE, room_id, user_id), - ) - # First we pull the counts from the summary table. # # We check that `last_receipt_stream_ordering` matches the stream ordering of the @@ -1545,25 +1400,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (room_id, user_id, stream_ordering, *thread_args), ) - # First ensure that the existing rows have an updated thread_id field. - if not self._event_push_backfill_thread_id_done: - txn.execute( - """ - UPDATE event_push_summary - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - (MAIN_TIMELINE, room_id, user_id), - ) - txn.execute( - """ - UPDATE event_push_actions - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - (MAIN_TIMELINE, room_id, user_id), - ) - # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. unread_counts = self._get_notif_unread_count_for_user_room( @@ -1698,19 +1534,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas rotate_to_stream_ordering: The new maximum event stream ordering to summarise. """ - # Ensure that any new actions have an updated thread_id. - if not self._event_push_backfill_thread_id_done: - txn.execute( - """ - UPDATE event_push_actions - SET thread_id = ? - WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL - """, - (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - # XXX Do we need to update summaries here too? - # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, thread_id, @@ -1773,28 +1596,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) - # Ensure that any updated threads have the proper thread_id. - if not self._event_push_backfill_thread_id_done: - txn.execute_batch( - """ - UPDATE event_push_summary - SET thread_id = ? - WHERE room_id = ? AND user_id = ? AND thread_id is NULL - """, - [ - (MAIN_TIMELINE, room_id, user_id) - for user_id, room_id, _ in summaries - ], - ) - self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", key_names=("user_id", "room_id", "thread_id"), - key_values=[ - (user_id, room_id, thread_id) - for user_id, room_id, thread_id in summaries - ], + key_values=list(summaries), value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ ( @@ -1932,42 +1738,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # We sleep to ensure that we don't overwhelm the DB. await self._clock.sleep(1.0) - -class EventPushActionsStore(EventPushActionsWorkerStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self.db_pool.updates.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.db_pool.updates.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1", - ) - - # Add index to make deleting old push actions faster. - self.db_pool.updates.register_background_index_update( - "event_push_actions_stream_highlight_index", - index_name="event_push_actions_stream_highlight_index", - table="event_push_actions", - columns=["highlight", "stream_ordering"], - where_clause="highlight=0", - ) - async def get_push_actions_for_user( self, user_id: str, @@ -2026,6 +1796,42 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ] +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.db_pool.updates.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1", + ) + + # Add index to make deleting old push actions faster. + self.db_pool.updates.register_background_index_update( + "event_push_actions_stream_highlight_index", + index_name="event_push_actions_stream_highlight_index", + table="event_push_actions", + columns=["highlight", "stream_ordering"], + where_clause="highlight=0", + ) + + def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool: for action in actions: if not isinstance(action, dict): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index e2e6eb479f..d4dcdb898c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -127,8 +127,6 @@ class PersistEventsStore: self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen - self._msc3970_enabled = hs.config.experimental.msc3970_enabled - @trace async def _persist_events_and_state_updates( self, @@ -224,7 +222,7 @@ class PersistEventsStore: for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) + (room_id,), frozenset(latest_event_ids) ) async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: @@ -415,12 +413,6 @@ class PersistEventsStore: backfilled=False, ) - self._update_forward_extremities_txn( - txn, - new_forward_extremities=new_forward_extremities, - max_stream_order=max_stream_order, - ) - # Ensure that we don't have the same event twice. events_and_contexts = self._filter_events_and_contexts_for_duplicates( events_and_contexts @@ -439,6 +431,12 @@ class PersistEventsStore: self._store_event_txn(txn, events_and_contexts=events_and_contexts) + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremities, + max_stream_order=max_stream_order, + ) + self._persist_transaction_ids_txn(txn, events_and_contexts) # Insert into event_to_state_groups. @@ -829,15 +827,7 @@ class PersistEventsStore: "target_chain_id", "target_sequence_number", ), - values=[ - (source_id, source_seq, target_id, target_seq) - for ( - source_id, - source_seq, - target_id, - target_seq, - ) in chain_links.get_additions() - ], + values=list(chain_links.get_additions()), ) @staticmethod @@ -980,26 +970,12 @@ class PersistEventsStore: """Persist the mapping from transaction IDs to event IDs (if defined).""" inserted_ts = self._clock.time_msec() - to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = [] to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = [] for event, _ in events_and_contexts: txn_id = getattr(event.internal_metadata, "txn_id", None) - token_id = getattr(event.internal_metadata, "token_id", None) device_id = getattr(event.internal_metadata, "device_id", None) if txn_id is not None: - if token_id is not None: - to_insert_token_id.append( - ( - event.event_id, - event.room_id, - event.sender, - token_id, - txn_id, - inserted_ts, - ) - ) - if device_id is not None: to_insert_device_id.append( ( @@ -1012,28 +988,8 @@ class PersistEventsStore: ) ) - # Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events. - # Since this is an experimental flag, we still store the mapping even if the - # flag is disabled. - if to_insert_token_id: - self.db_pool.simple_insert_many_txn( - txn, - table="event_txn_id", - keys=( - "event_id", - "room_id", - "user_id", - "token_id", - "txn_id", - "inserted_ts", - ), - values=to_insert_token_id, - ) - - # With MSC3970, we rely on the device_id instead to scope the txn_id for events. - # We're only inserting if MSC3970 is *enabled*, because else the pre-MSC3970 - # behaviour would allow for a UNIQUE constraint violation on this table - if to_insert_device_id and self._msc3970_enabled: + # Synapse relies on the device_id to scope transactions for events.. + if to_insert_device_id: self.db_pool.simple_insert_many_txn( txn, table="event_txn_id_device_id", @@ -1455,8 +1411,8 @@ class PersistEventsStore: }, ) - sql = "UPDATE events SET outlier = ? WHERE event_id = ?" - txn.execute(sql, (False, event.event_id)) + sql = "UPDATE events SET outlier = FALSE WHERE event_id = ?" + txn.execute(sql, (event.event_id,)) # Update the event_backward_extremities table now that this # event isn't an outlier any more. @@ -1549,13 +1505,13 @@ class PersistEventsStore: for event, _ in events_and_contexts if not event.internal_metadata.is_redacted() ] - sql = "UPDATE redactions SET have_censored = ? WHERE " + sql = "UPDATE redactions SET have_censored = FALSE WHERE " clause, args = make_in_list_sql_clause( self.database_engine, "redacts", unredacted_events, ) - txn.execute(sql + clause, [False] + args) + txn.execute(sql + clause, args) self.db_pool.simple_insert_many_txn( txn, @@ -1664,9 +1620,6 @@ class PersistEventsStore: self._handle_event_relations(txn, event) - self._handle_insertion_event(txn, event) - self._handle_batch_event(txn, event) - # Store the labels for this event. labels = event.content.get(EventContentFields.LABELS) if labels: @@ -1677,7 +1630,7 @@ class PersistEventsStore: if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is int and not event.is_state(): + if type(expiry_ts) is int and not event.is_state(): # noqa: E721 self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -1701,8 +1654,6 @@ class PersistEventsStore: ) -> None: to_prefill = [] - rows = [] - ev_map = {e.event_id: e for e, _ in events_and_contexts} if not ev_map: return @@ -1723,19 +1674,27 @@ class PersistEventsStore: ) txn.execute(sql + clause, args) - rows = self.db_pool.cursor_to_dict(txn) - for row in rows: - event = ev_map[row["event_id"]] - if not row["rejects"] and not row["redacts"]: + for event_id, redacts, rejects in txn: + event = ev_map[event_id] + if not rejects and not redacts: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - async def prefill() -> None: + async def external_prefill() -> None: + for cache_entry in to_prefill: + await self.store._get_event_cache.set_external( + (cache_entry.event.event_id,), cache_entry + ) + + def local_prefill() -> None: for cache_entry in to_prefill: - await self.store._get_event_cache.set( + self.store._get_event_cache.set_local( (cache_entry.event.event_id,), cache_entry ) - txn.async_call_after(prefill) + # The order these are called here is not as important as knowing that after the + # transaction is finished, the async_call_after will run before the call_after. + txn.async_call_after(external_prefill) + txn.call_after(local_prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: assert event.redacts is not None @@ -1918,128 +1877,6 @@ class PersistEventsStore: ), ) - def _handle_insertion_event( - self, txn: LoggingTransaction, event: EventBase - ) -> None: - """Handles keeping track of insertion events and edges/connections. - Part of MSC2716. - - Args: - txn: The database transaction object - event: The event to process - """ - - if event.type != EventTypes.MSC2716_INSERTION: - # Not a insertion event - return - - # Skip processing an insertion event if the room version doesn't - # support it or the event is not from the room creator. - room_version = self.store.get_room_version_txn(txn, event.room_id) - room_creator = self.db_pool.simple_select_one_onecol_txn( - txn, - table="rooms", - keyvalues={"room_id": event.room_id}, - retcol="creator", - allow_none=True, - ) - if not room_version.msc2716_historical and ( - not self.hs.config.experimental.msc2716_enabled - or event.sender != room_creator - ): - return - - next_batch_id = event.content.get(EventContentFields.MSC2716_NEXT_BATCH_ID) - if next_batch_id is None: - # Invalid insertion event without next batch ID - return - - logger.debug( - "_handle_insertion_event (next_batch_id=%s) %s", next_batch_id, event - ) - - # Keep track of the insertion event and the batch ID - self.db_pool.simple_insert_txn( - txn, - table="insertion_events", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "next_batch_id": next_batch_id, - }, - ) - - # Insert an edge for every prev_event connection - for prev_event_id in event.prev_event_ids(): - self.db_pool.simple_insert_txn( - txn, - table="insertion_event_edges", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "insertion_prev_event_id": prev_event_id, - }, - ) - - def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None: - """Handles inserting the batch edges/connections between the batch event - and an insertion event. Part of MSC2716. - - Args: - txn: The database transaction object - event: The event to process - """ - - if event.type != EventTypes.MSC2716_BATCH: - # Not a batch event - return - - # Skip processing a batch event if the room version doesn't - # support it or the event is not from the room creator. - room_version = self.store.get_room_version_txn(txn, event.room_id) - room_creator = self.db_pool.simple_select_one_onecol_txn( - txn, - table="rooms", - keyvalues={"room_id": event.room_id}, - retcol="creator", - allow_none=True, - ) - if not room_version.msc2716_historical and ( - not self.hs.config.experimental.msc2716_enabled - or event.sender != room_creator - ): - return - - batch_id = event.content.get(EventContentFields.MSC2716_BATCH_ID) - if batch_id is None: - # Invalid batch event without a batch ID - return - - logger.debug("_handle_batch_event batch_id=%s %s", batch_id, event) - - # Keep track of the insertion event and the batch ID - self.db_pool.simple_insert_txn( - txn, - table="batch_events", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "batch_id": batch_id, - }, - ) - - # When we receive an event with a `batch_id` referencing the - # `next_batch_id` of the insertion event, we can remove it from the - # `insertion_event_extremities` table. - sql = """ - DELETE FROM insertion_event_extremities WHERE event_id IN ( - SELECT event_id FROM insertion_events - WHERE next_batch_id = ? - ) - """ - - txn.execute(sql, (batch_id,)) - def _handle_redact_relations( self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: @@ -2158,10 +1995,10 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and type(event.content["min_lifetime"]) is not int + and type(event.content["min_lifetime"]) is not int # noqa: E721 ) or ( "max_lifetime" in event.content - and type(event.content["max_lifetime"]) is not int + and type(event.content["max_lifetime"]) is not int # noqa: E721 ): # Ignore the event if one of the value isn't an integer. return @@ -2434,14 +2271,14 @@ class PersistEventsStore: " SELECT 1 FROM events" " LEFT JOIN event_edges edge" " ON edge.event_id = events.event_id" - " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)" + " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)" " )" ) txn.execute_batch( query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id) for ev in events for e_id in ev.prev_event_ids() if not ev.internal_metadata.is_outlier() diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a39bc90974..b788d70fc5 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -24,6 +24,7 @@ from typing import ( Dict, Iterable, List, + Mapping, MutableMapping, Optional, Set, @@ -883,7 +884,7 @@ class EventsWorkerStore(SQLBaseStore): async def _invalidate_async_get_event_cache(self, event_id: str) -> None: """ - Invalidates an event in the asyncronous get event cache, which may be remote. + Invalidates an event in the asynchronous get event cache, which may be remote. Arguments: event_id: the event ID to invalidate @@ -903,6 +904,15 @@ class EventsWorkerStore(SQLBaseStore): self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) + def _invalidate_local_get_event_cache_all(self) -> None: + """Clears the in-memory get event caches. + + Used when we purge room history. + """ + self._get_event_cache.clear() + self._event_ref.clear() + self._current_event_fetches.clear() + async def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: @@ -1624,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore): self, room_id: str, event_ids: Collection[str], - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Helper for have_seen_events Returns: @@ -2013,25 +2023,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - async def get_event_id_from_transaction_id_and_token_id( - self, room_id: str, user_id: str, token_id: int, txn_id: str - ) -> Optional[str]: - """Look up if we have already persisted an event for the transaction ID, - returning the event ID if so. - """ - return await self.db_pool.simple_select_one_onecol( - table="event_txn_id", - keyvalues={ - "room_id": room_id, - "user_id": user_id, - "token_id": token_id, - "txn_id": txn_id, - }, - retcol="event_id", - allow_none=True, - desc="get_event_id_from_transaction_id_and_token_id", - ) - async def get_event_id_from_transaction_id_and_device_id( self, room_id: str, user_id: str, device_id: str, txn_id: str ) -> Optional[str]: @@ -2063,29 +2054,35 @@ class EventsWorkerStore(SQLBaseStore): """ mapping = {} - txn_id_to_event: Dict[Tuple[str, int, str], str] = {} + txn_id_to_event: Dict[Tuple[str, str, str, str], str] = {} for event in events: - token_id = getattr(event.internal_metadata, "token_id", None) + device_id = getattr(event.internal_metadata, "device_id", None) txn_id = getattr(event.internal_metadata, "txn_id", None) - if token_id and txn_id: + if device_id and txn_id: # Check if this is a duplicate of an event in the given events. - existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) + existing = txn_id_to_event.get( + (event.room_id, event.sender, device_id, txn_id) + ) if existing: mapping[event.event_id] = existing continue # Check if this is a duplicate of an event we've already # persisted. - existing = await self.get_event_id_from_transaction_id_and_token_id( - event.room_id, event.sender, token_id, txn_id + existing = await self.get_event_id_from_transaction_id_and_device_id( + event.room_id, event.sender, device_id, txn_id ) if existing: mapping[event.event_id] = existing - txn_id_to_event[(event.room_id, token_id, txn_id)] = existing + txn_id_to_event[ + (event.room_id, event.sender, device_id, txn_id) + ] = existing else: - txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id + txn_id_to_event[ + (event.room_id, event.sender, device_id, txn_id) + ] = event.event_id return mapping @@ -2329,7 +2326,7 @@ class EventsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Checks which of the given events have partial state Args: diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py index cf3226ae5a..654f924019 100644 --- a/synapse/storage/databases/main/experimental_features.py +++ b/synapse/storage/databases/main/experimental_features.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, FrozenSet from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main import CacheInvalidationWorkerStore -from synapse.types import StrCollection from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -34,7 +33,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): super().__init__(database, db_conn, hs) @cached() - async def list_enabled_features(self, user_id: str) -> StrCollection: + async def list_enabled_features(self, user_id: str) -> FrozenSet[str]: """ Checks to see what features are enabled for a given user Args: @@ -49,7 +48,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): ["feature"], ) - return [feature["feature"] for feature in enabled] + return frozenset(feature["feature"] for feature in enabled) async def set_features_for_user( self, diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index da31eb44dc..7d94685caf 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json -from typing_extensions import TYPE_CHECKING from synapse.api.errors import Codes, StoreError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -26,7 +25,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -71,7 +70,7 @@ class FilteringWorkerStore(SQLBaseStore): SELECT user_id FROM user_filters WHERE user_id > ? ORDER BY user_id - LIMIT 1 OFFSET 50 + LIMIT 1 OFFSET 1000 """ txn.execute(sql, (lower_bound_id,)) res = txn.fetchone() @@ -145,8 +144,8 @@ class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( - self, user_localpart: str, filter_id: Union[int, str] - ) -> JsonDict: + self, user_id: UserID, filter_id: Union[int, str] + ) -> JsonMapping: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -156,7 +155,7 @@ class FilteringWorkerStore(SQLBaseStore): def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", - keyvalues={"user_id": user_localpart, "filter_id": filter_id}, + keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id}, retcol="filter_json", allow_none=False, desc="get_user_filter", @@ -172,15 +171,15 @@ class FilteringWorkerStore(SQLBaseStore): def _do_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT filter_id FROM user_filters " - "WHERE user_id = ? AND filter_json = ?" + "WHERE full_user_id = ? AND filter_json = ?" ) - txn.execute(sql, (user_id.localpart, bytearray(def_json))) + txn.execute(sql, (user_id.to_string(), bytearray(def_json))) filter_id_response = txn.fetchone() if filter_id_response is not None: return filter_id_response[0] - sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" - txn.execute(sql, (user_id.localpart,)) + sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?" + txn.execute(sql, (user_id.to_string(),)) max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_id is None: filter_id = 0 diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 1666e3c43b..889c578b9c 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,15 +16,17 @@ import itertools import json import logging -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Dict, Iterable, Mapping, Optional, Tuple +from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction -from synapse.storage.keys import FetchKeyResult +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -34,162 +36,89 @@ logger = logging.getLogger(__name__) db_binary_type = memoryview -class KeyStore(SQLBaseStore): +class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" - @cached() - def _get_server_signature_key( - self, server_name_and_key_id: Tuple[str, str] - ) -> FetchKeyResult: - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_signature_key", - list_name="server_name_and_key_ids", - ) - async def get_server_signature_keys( - self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: - """ - Args: - server_name_and_key_ids: - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - A map from (server_name, key_id) -> FetchKeyResult, or None if the - key is unknown - """ - keys = {} - - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, verify_key, ts_valid_until_ms - FROM server_signature_keys WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - keys[(server_name, key_id)] = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) - - async def store_server_signature_keys( + async def store_server_keys_response( self, + server_name: str, from_server: str, ts_added_ms: int, - verify_keys: Mapping[Tuple[str, str], FetchKeyResult], + verify_keys: Dict[str, FetchKeyResult], + response_json: JsonDict, ) -> None: - """Stores NACL verification keys for remote servers. + """Stores the keys for the given server that we got from `from_server`. + Args: - from_server: Where the verification keys were looked up - ts_added_ms: The time to record that the key was added - verify_keys: - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). + server_name: The owner of the keys + from_server: Which server we got the keys from + ts_added_ms: When we're adding the keys + verify_keys: The decoded keys + response_json: The full *signed* response JSON that contains the keys. """ - key_values = [] - value_values = [] - invalidations = [] - for (server_name, key_id), fetch_result in verify_keys.items(): - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) + + key_json_bytes = encode_canonical_json(response_json) + + def store_server_keys_response_txn(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=[(server_name, key_id) for key_id in verify_keys], + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=[ + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + for fetch_result in verify_keys.values() + ], ) - # invalidate takes a tuple corresponding to the params of - # _get_server_signature_key. _get_server_signature_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - await self.db_pool.simple_upsert_many( - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - desc="store_server_signature_keys", - ) + self.db_pool.simple_upsert_many_txn( + txn, + table="server_keys_json", + key_names=("server_name", "key_id", "from_server"), + key_values=[ + (server_name, key_id, from_server) for key_id in verify_keys + ], + value_names=( + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + value_values=[ + ( + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(key_json_bytes), + ) + for fetch_result in verify_keys.values() + ], + ) - invalidate = self._get_server_signature_key.invalidate - for i in invalidations: - invalidate((i,)) + # invalidate takes a tuple corresponding to the params of + # _get_server_keys_json. _get_server_keys_json only takes one + # param, which is itself the 2-tuple (server_name, key_id). + for key_id in verify_keys: + self._invalidate_cache_and_stream( + txn, self._get_server_keys_json, ((server_name, key_id),) + ) + self._invalidate_cache_and_stream( + txn, self.get_server_key_json_for_remote, (server_name, key_id) + ) - async def store_server_keys_json( - self, - server_name: str, - key_id: str, - from_server: str, - ts_now_ms: int, - ts_expires_ms: int, - key_json_bytes: bytes, - ) -> None: - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name: The name of the server. - key_id: The identifier of the key this JSON is for. - from_server: The server this JSON was fetched from. - ts_now_ms: The time now in milliseconds. - ts_valid_until_ms: The time when this json stops being valid. - key_json_bytes: The encoded JSON. - """ - await self.db_pool.simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", + await self.db_pool.runInteraction( + "store_server_keys_response", store_server_keys_response_txn ) - # invalidate takes a tuple corresponding to the params of - # _get_server_keys_json. _get_server_keys_json only takes one - # param, which is itself the 2-tuple (server_name, key_id). - self._get_server_keys_json.invalidate((((server_name, key_id),))) - @cached() def _get_server_keys_json( self, server_name_and_key_id: Tuple[str, str] @@ -201,7 +130,7 @@ class KeyStore(SQLBaseStore): ) async def get_server_keys_json( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: + ) -> Mapping[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -217,12 +146,17 @@ class KeyStore(SQLBaseStore): """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, key_json, ts_valid_until_ms - FROM server_keys_json WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) + where_clause = " OR (server_name=? AND key_id=?)" * len(batch) + + # `server_keys_json` can have multiple entries per server (one per + # remote server we fetched from, if using perspectives). Order by + # `ts_added_ms` so the most recently fetched one always wins. + sql = f""" + SELECT server_name, key_id, key_json, ts_valid_until_ms + FROM server_keys_json WHERE 1=0 + {where_clause} + ORDER BY ts_added_ms + """ txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) @@ -253,47 +187,87 @@ class KeyStore(SQLBaseStore): return await self.db_pool.runInteraction("get_server_keys_json", _txn) - async def get_server_keys_json_for_remote( - self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: - """Retrieve the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. + @cached() + def get_server_key_json_for_remote( + self, + server_name: str, + key_id: str, + ) -> Optional[FetchKeyResultForRemote]: + raise NotImplementedError() - Args: - server_keys: List of (server_name, key_id, source) triplets. + @cachedList( + cached_method_name="get_server_key_json_for_remote", list_name="key_ids" + ) + async def get_server_keys_json_for_remote( + self, server_name: str, key_ids: Iterable[str] + ) -> Mapping[str, Optional[FetchKeyResultForRemote]]: + """Fetch the cached keys for the given server/key IDs. - Returns: - A mapping from (server_name, key_id, source) triplets to a list of dicts + If we have multiple entries for a given key ID, returns the most recent. """ + rows = await self.db_pool.simple_select_many_batch( + table="server_keys_json", + column="key_id", + iterable=key_ids, + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", + ) - def _get_server_keys_json_txn( - txn: LoggingTransaction, - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self.db_pool.simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results + if not rows: + return {} + + # We sort the rows so that the most recently added entry is picked up. + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } - return await self.db_pool.runInteraction( - "get_server_keys_json", _get_server_keys_json_txn + async def get_all_server_keys_json_for_remote( + self, + server_name: str, + ) -> Dict[str, FetchKeyResultForRemote]: + """Fetch the cached keys for the given server. + + If we have multiple entries for a given key ID, returns the most recent. + """ + rows = await self.db_pool.simple_select_list( + table="server_keys_json", + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ) + + if not rows: + return {} + + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 7270ef09da..5a01ec2137 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from contextlib import AsyncExitStack from types import TracebackType -from typing import TYPE_CHECKING, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type from weakref import WeakValueDictionary -from twisted.internet.interfaces import IReactorCore +from twisted.internet.task import LoopingCall from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -25,6 +26,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.types import ISynapseReactor from synapse.util import Clock from synapse.util.stringutils import random_string @@ -68,12 +70,20 @@ class LockStore(SQLBaseStore): self._reactor = hs.get_reactor() self._instance_name = hs.get_instance_id() - # A map from `(lock_name, lock_key)` to the token of any locks that we - # think we currently hold. - self._live_tokens: WeakValueDictionary[ + # A map from `(lock_name, lock_key)` to lock that we think we + # currently hold. + self._live_lock_tokens: WeakValueDictionary[ Tuple[str, str], Lock ] = WeakValueDictionary() + # A map from `(lock_name, lock_key, token)` to read/write lock that we + # think we currently hold. For a given lock_name/lock_key, there can be + # multiple read locks at a time but only one write lock (no mixing read + # and write locks at the same time). + self._live_read_write_lock_tokens: WeakValueDictionary[ + Tuple[str, str, str], Lock + ] = WeakValueDictionary() + # When we shut down we want to remove the locks. Technically this can # lead to a race, as we may drop the lock while we are still processing. # However, a) it should be a small window, b) the lock is best effort @@ -86,16 +96,22 @@ class LockStore(SQLBaseStore): self._acquiring_locks: Set[Tuple[str, str]] = set() + self._clock.looping_call( + self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0 + ) + @wrap_as_background_process("LockStore._on_shutdown") async def _on_shutdown(self) -> None: """Called when the server is shutting down""" logger.info("Dropping held locks due to shutdown") - # We need to take a copy of the tokens dict as dropping the locks will - # cause the dictionary to change. - locks = dict(self._live_tokens) + # We need to take a copy of the locks as dropping the locks will cause + # the dictionary to change. + locks = list(self._live_lock_tokens.values()) + list( + self._live_read_write_lock_tokens.values() + ) - for lock in locks.values(): + for lock in locks: await lock.release() logger.info("Dropped locks due to shutdown") @@ -122,7 +138,7 @@ class LockStore(SQLBaseStore): """ # Check if this process has taken out a lock and if it's still valid. - lock = self._live_tokens.get((lock_name, lock_key)) + lock = self._live_lock_tokens.get((lock_name, lock_key)) if lock and await lock.is_still_valid(): return None @@ -176,61 +192,148 @@ class LockStore(SQLBaseStore): self._reactor, self._clock, self, + read_write=False, lock_name=lock_name, lock_key=lock_key, token=token, ) - self._live_tokens[(lock_name, lock_key)] = lock + self._live_lock_tokens[(lock_name, lock_key)] = lock return lock - async def _is_lock_still_valid( - self, lock_name: str, lock_key: str, token: str - ) -> bool: - """Checks whether this instance still holds the lock.""" - last_renewed_ts = await self.db_pool.simple_select_one_onecol( - table="worker_locks", - keyvalues={ - "lock_name": lock_name, - "lock_key": lock_key, - "token": token, - }, - retcol="last_renewed_ts", - allow_none=True, - desc="is_lock_still_valid", - ) - return ( - last_renewed_ts is not None - and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts - ) + async def try_acquire_read_write_lock( + self, + lock_name: str, + lock_key: str, + write: bool, + ) -> Optional["Lock"]: + """Try to acquire a lock for the given name/key. Will return an async + context manager if the lock is successfully acquired, which *must* be + used (otherwise the lock will leak). + """ - async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None: - """Attempt to renew the lock if we still hold it.""" - await self.db_pool.simple_update( - table="worker_locks", - keyvalues={ + try: + lock = await self.db_pool.runInteraction( + "try_acquire_read_write_lock", + self._try_acquire_read_write_lock_txn, + lock_name, + lock_key, + write, + db_autocommit=True, + ) + except self.database_engine.module.IntegrityError: + return None + + return lock + + def _try_acquire_read_write_lock_txn( + self, + txn: LoggingTransaction, + lock_name: str, + lock_key: str, + write: bool, + ) -> "Lock": + # We attempt to acquire the lock by inserting into + # `worker_read_write_locks` and seeing if that fails any + # constraints. If it doesn't then we have acquired the lock, + # otherwise we haven't. + + now = self._clock.time_msec() + token = random_string(6) + + self.db_pool.simple_insert_txn( + txn, + table="worker_read_write_locks", + values={ "lock_name": lock_name, "lock_key": lock_key, + "write_lock": write, + "instance_name": self._instance_name, "token": token, + "last_renewed_ts": now, }, - updatevalues={"last_renewed_ts": self._clock.time_msec()}, - desc="renew_lock", ) - async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None: - """Attempt to drop the lock, if we still hold it""" - await self.db_pool.simple_delete( - table="worker_locks", - keyvalues={ - "lock_name": lock_name, - "lock_key": lock_key, - "token": token, - }, - desc="drop_lock", + lock = Lock( + self._reactor, + self._clock, + self, + read_write=True, + lock_name=lock_name, + lock_key=lock_key, + token=token, ) - self._live_tokens.pop((lock_name, lock_key), None) + def set_lock() -> None: + self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock + + txn.call_after(set_lock) + + return lock + + async def try_acquire_multi_read_write_lock( + self, + lock_names: Collection[Tuple[str, str]], + write: bool, + ) -> Optional[AsyncExitStack]: + """Try to acquire multiple locks for the given names/keys. Will return + an async context manager if the locks are successfully acquired, which + *must* be used (otherwise the lock will leak). + + If only a subset of the locks can be acquired then it will immediately + drop them and return `None`. + """ + try: + locks = await self.db_pool.runInteraction( + "try_acquire_multi_read_write_lock", + self._try_acquire_multi_read_write_lock_txn, + lock_names, + write, + ) + except self.database_engine.module.IntegrityError: + return None + + stack = AsyncExitStack() + + for lock in locks: + await stack.enter_async_context(lock) + + return stack + + def _try_acquire_multi_read_write_lock_txn( + self, + txn: LoggingTransaction, + lock_names: Collection[Tuple[str, str]], + write: bool, + ) -> Collection["Lock"]: + locks = [] + + for lock_name, lock_key in lock_names: + lock = self._try_acquire_read_write_lock_txn( + txn, lock_name, lock_key, write + ) + locks.append(lock) + + return locks + + @wrap_as_background_process("_reap_stale_read_write_locks") + async def _reap_stale_read_write_locks(self) -> None: + delete_sql = """ + DELETE FROM worker_read_write_locks + WHERE last_renewed_ts < ? + """ + + def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None: + txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,)) + if txn.rowcount: + logger.info("Reaped %d stale locks", txn.rowcount) + + await self.db_pool.runInteraction( + "_reap_stale_read_write_locks", + reap_stale_read_write_locks_txn, + db_autocommit=True, + ) class Lock: @@ -256,9 +359,10 @@ class Lock: def __init__( self, - reactor: IReactorCore, + reactor: ISynapseReactor, clock: Clock, store: LockStore, + read_write: bool, lock_name: str, lock_key: str, token: str, @@ -266,21 +370,39 @@ class Lock: self._reactor = reactor self._clock = clock self._store = store + self._read_write = read_write self._lock_name = lock_name self._lock_key = lock_key self._token = token - self._looping_call = clock.looping_call( - self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token - ) + self._table = "worker_read_write_locks" if read_write else "worker_locks" + + # We might be called from a non-main thread, so we defer setting up the + # looping call. + self._looping_call: Optional[LoopingCall] = None + reactor.callFromThread(self._setup_looping_call) self._dropped = False + def _setup_looping_call(self) -> None: + self._looping_call = self._clock.looping_call( + self._renew, + _RENEWAL_INTERVAL_MS, + self._store, + self._clock, + self._read_write, + self._lock_name, + self._lock_key, + self._token, + ) + @staticmethod @wrap_as_background_process("Lock._renew") async def _renew( store: LockStore, + clock: Clock, + read_write: bool, lock_name: str, lock_key: str, token: str, @@ -291,12 +413,34 @@ class Lock: don't end up with a reference to `self` in the reactor, which would stop this from being cleaned up if we dropped the context manager. """ - await store._renew_lock(lock_name, lock_key, token) + table = "worker_read_write_locks" if read_write else "worker_locks" + await store.db_pool.simple_update( + table=table, + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + updatevalues={"last_renewed_ts": clock.time_msec()}, + desc="renew_lock", + ) async def is_still_valid(self) -> bool: """Check if the lock is still held by us""" - return await self._store._is_lock_still_valid( - self._lock_name, self._lock_key, self._token + last_renewed_ts = await self._store.db_pool.simple_select_one_onecol( + table=self._table, + keyvalues={ + "lock_name": self._lock_name, + "lock_key": self._lock_key, + "token": self._token, + }, + retcol="last_renewed_ts", + allow_none=True, + desc="is_lock_still_valid", + ) + return ( + last_renewed_ts is not None + and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts ) async def __aenter__(self) -> None: @@ -322,10 +466,26 @@ class Lock: if self._dropped: return - if self._looping_call.running: + if self._looping_call and self._looping_call.running: self._looping_call.stop() - await self._store._drop_lock(self._lock_name, self._lock_key, self._token) + await self._store.db_pool.simple_delete( + table=self._table, + keyvalues={ + "lock_name": self._lock_name, + "lock_key": self._lock_key, + "token": self._token, + }, + desc="drop_lock", + ) + + if self._read_write: + self._store._live_read_write_lock_tokens.pop( + (self._lock_name, self._lock_key, self._token), None + ) + else: + self._store._live_lock_tokens.pop((self._lock_name, self._lock_key), None) + self._dropped = True def __del__(self) -> None: @@ -333,8 +493,9 @@ class Lock: # We should not be dropped without the lock being released (unless # we're shutting down), but if we are then let's at least stop # renewing the lock. - if self._looping_call.running: - self._looping_call.stop() + if self._looping_call and self._looping_call.running: + # We might be called from a non-main thread. + self._reactor.callFromThread(self._looping_call.stop) if self._reactor.running: logger.error( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index fa8be214ce..2e6b176bd2 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -27,6 +27,8 @@ from typing import ( ) from synapse.api.constants import Direction +from synapse.logging.opentracing import trace +from synapse.media._base import ThumbnailInfo from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -328,6 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_local_media_ids", _get_local_media_ids_txn ) + @trace async def store_local_media( self, media_id: str, @@ -433,8 +436,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_url_cache", ) - async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: - return await self.db_pool.simple_select_list( + async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]: + rows = await self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -446,7 +449,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ), desc="get_local_media_thumbnails", ) + return [ + ThumbnailInfo( + width=row["thumbnail_width"], + height=row["thumbnail_height"], + method=row["thumbnail_method"], + type=row["thumbnail_type"], + length=row["thumbnail_length"], + ) + for row in rows + ] + @trace async def store_local_thumbnail( self, media_id: str, @@ -553,8 +567,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_thumbnails( self, origin: str, media_id: str - ) -> List[Dict[str, Any]]: - return await self.db_pool.simple_select_list( + ) -> List[ThumbnailInfo]: + rows = await self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -563,11 +577,21 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "thumbnail_method", "thumbnail_type", "thumbnail_length", - "filesystem_id", ), desc="get_remote_media_thumbnails", ) + return [ + ThumbnailInfo( + width=row["thumbnail_width"], + height=row["thumbnail_height"], + method=row["thumbnail_method"], + type=row["thumbnail_type"], + length=row["thumbnail_length"], + ) + for row in rows + ] + @trace async def get_remote_media_thumbnail( self, origin: str, @@ -599,6 +623,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnail", ) + @trace async def store_remote_media_thumbnail( self, origin: str, diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 14294a0bb8..595e22982e 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -248,89 +248,6 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = cast(Tuple[int], txn.fetchone()) return count - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. - """ - - def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]: - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COUNT(*) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COUNT(*) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = cast(Tuple[int], txn.fetchone()) - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - async def count_r30v2_users(self) -> Dict[str, int]: """ Counts the number of 30 day retained users, defined as users that: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index beb210f8ee..519f05fb60 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -11,8 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, + cast, +) from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -24,6 +34,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine +from synapse.storage.engines._base import IsolationLevel from synapse.storage.types import Connection from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, @@ -115,11 +126,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) ) async with stream_ordering_manager as stream_orderings: + # Run the interaction with an isolation level of READ_COMMITTED to avoid + # serialization errors(and rollbacks) in the database. This way it will + # ignore new rows during the DELETE, but will pick them up the next time + # this is run. Currently, that is between 5-60 seconds. await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, presence_states, + isolation_level=IsolationLevel.READ_COMMITTED, ) return stream_orderings[-1], self._presence_id_gen.get_current_token() @@ -244,7 +260,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) ) async def get_presence_for_users( self, user_ids: Iterable[str] - ) -> Dict[str, UserPresenceState]: + ) -> Mapping[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -370,28 +386,47 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) limit = 100 offset = 0 while True: - rows = await self.db_pool.runInteraction( - "get_presence_for_all_users", - self.db_pool.simple_select_list_paginate_txn, - "presence_stream", - orderby="stream_id", - start=offset, - limit=limit, - exclude_keyvalues=exclude_keyvalues, - retcols=( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", + rows = cast( + List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], + await self.db_pool.runInteraction( + "get_presence_for_all_users", + self.db_pool.simple_select_list_paginate_txn, + "presence_stream", + orderby="stream_id", + start=offset, + limit=limit, + exclude_keyvalues=exclude_keyvalues, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + order_direction="ASC", ), - order_direction="ASC", ) - for row in rows: - users_to_state[row["user_id"]] = UserPresenceState(**row) + for ( + user_id, + state, + last_active_ts, + last_federation_update_ts, + last_user_sync_ts, + status_msg, + currently_active, + ) in rows: + users_to_state[user_id] = UserPresenceState( + user_id=user_id, + state=state, + last_active_ts=last_active_ts, + last_federation_update_ts=last_federation_update_ts, + last_user_sync_ts=last_user_sync_ts, + status_msg=status_msg, + currently_active=bool(currently_active), + ) # We've run out of updates to query if len(rows) < limit: @@ -419,13 +454,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() txn.close() - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] + return [ + UserPresenceState( + user_id=user_id, + state=state, + last_active_ts=last_active_ts, + last_federation_update_ts=last_federation_update_ts, + last_user_sync_ts=last_user_sync_ts, + status_msg=status_msg, + currently_active=bool(currently_active), + ) + for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows + ] def take_presence_startup_info(self) -> List[UserPresenceState]: active_on_startup = self._presence_on_startup diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 65c92bef51..3ba9cc8853 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -65,7 +65,7 @@ class ProfileWorkerStore(SQLBaseStore): SELECT user_id FROM profiles WHERE user_id > ? ORDER BY user_id - LIMIT 1 OFFSET 50 + LIMIT 1 OFFSET 1000 """ txn.execute(sql, (lower_bound_id,)) res = txn.fetchone() @@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore): return 50 - async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: + async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", - keyvalues={"user_id": user_localpart}, + keyvalues={"full_user_id": user_id.to_string()}, retcols=("displayname", "avatar_url"), desc="get_profileinfo", ) @@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: + async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", - keyvalues={"user_id": user_localpart}, + keyvalues={"full_user_id": user_id.to_string()}, retcol="displayname", desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: + async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", - keyvalues={"user_id": user_localpart}, + keyvalues={"full_user_id": user_id.to_string()}, retcol="avatar_url", desc="get_profile_avatar_url", ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index efbd3e75d9..dea0e0458c 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -249,12 +249,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # Mark all state and own events as outliers logger.info("[purge] marking remaining events as outliers") txn.execute( - "UPDATE events SET outlier = ?" + "UPDATE events SET outlier = TRUE" " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")" ) # synapse tries to take out an exclusive lock on room_depth whenever it @@ -308,6 +307,8 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): logger.info("[purge] done") + self._invalidate_caches_for_room_events_and_stream(txn, room_id) + return referenced_state_groups async def purge_room(self, room_id: str) -> List[int]: @@ -449,10 +450,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "e2e_room_keys", "event_push_summary", "pusher_throttle", - "insertion_events", - "insertion_event_extremities", - "insertion_event_edges", - "batch_events", "room_account_data", "room_tags", # "rooms" happens last, to keep the foreign keys in the other tables @@ -485,10 +482,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # index on them. In any case we should be clearing out 'stream' tables # periodically anyway (#5888) - # TODO: we could probably usefully do a bunch more cache invalidation here - - # XXX: as with purge_history, this is racy, but no worse than other races - # that already exist. - self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,)) + self._invalidate_caches_for_room_and_stream(txn, room_id) return state_groups diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 9f862f00c1..923166974c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -88,8 +88,7 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, - msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions, - msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, + msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, ) return filtered_rules @@ -218,7 +217,7 @@ class PushRulesWorkerStore( @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] - ) -> Dict[str, FilteredPushRules]: + ) -> Mapping[str, FilteredPushRules]: if not user_ids: return {} @@ -561,19 +560,19 @@ class PushRuleStore(PushRulesWorkerStore): if isinstance(self.database_engine, PostgresEngine): sql = """ INSERT INTO push_rules_enable (id, user_name, rule_id, enabled) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, 1) ON CONFLICT DO NOTHING """ elif isinstance(self.database_engine, Sqlite3Engine): sql = """ INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, 1) """ else: raise RuntimeError("Unknown database engine") new_enable_id = self._push_rules_enable_id_gen.get_next() - txn.execute(sql, (new_enable_id, user_id, rule_id, 1)) + txn.execute(sql, (new_enable_id, user_id, rule_id)) async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 87e28e22d3..c7eb7fc478 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -47,6 +47,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# The type of a row in the pushers table. +PusherRow = Tuple[ + int, # id + str, # user_name + Optional[int], # access_token + str, # profile_tag + str, # kind + str, # app_id + str, # app_display_name + str, # device_display_name + str, # pushkey + int, # ts + str, # lang + str, # data + int, # last_stream_ordering + int, # last_success + int, # failing_since + bool, # enabled + str, # device_id +] + class PusherWorkerStore(SQLBaseStore): def __init__( @@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore): self._remove_deleted_email_pushers, ) - def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: + def _decode_pushers_rows( + self, + rows: Iterable[PusherRow], + ) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table Drops any rows whose data cannot be decoded """ - for r in rows: - data_json = r["data"] + for ( + id, + user_name, + access_token, + profile_tag, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + ts, + lang, + data, + last_stream_ordering, + last_success, + failing_since, + enabled, + device_id, + ) in rows: try: - r["data"] = db_to_json(data_json) + data_json = db_to_json(data) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", - r["id"], - data_json, + id, + data, e.args[0], ) continue - # If we're using SQLite, then boolean values are integers. This is - # troublesome since some code using the return value of this method might - # expect it to be a boolean, or will expose it to clients (in responses). - r["enabled"] = bool(r["enabled"]) - - yield PusherConfig(**r) + yield PusherConfig( + id=id, + user_name=user_name, + profile_tag=profile_tag, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + ts=ts, + lang=lang, + data=data_json, + last_stream_ordering=last_stream_ordering, + last_success=last_success, + failing_since=failing_since, + # If we're using SQLite, then boolean values are integers. This is + # troublesome since some code using the return value of this method might + # expect it to be a boolean, or will expose it to clients (in responses). + enabled=bool(enabled), + device_id=device_id, + access_token=access_token, + ) def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() @@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): The pushers for which the given columns have the given values. """ - def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: + def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]: # We could technically use simple_select_list here, but we need to call # COALESCE on the 'enabled' column. While it is technically possible to give # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it @@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore): txn.execute(sql, list(keyvalues.values())) - return self.db_pool.cursor_to_dict(txn) + return cast(List[PusherRow], txn.fetchall()) ret = await self.db_pool.runInteraction( desc="get_pushers_by", @@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore): return self._decode_pushers_rows(ret) async def get_enabled_pushers(self) -> Iterator[PusherConfig]: - def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: - txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") - rows = self.db_pool.cursor_to_dict(txn) - - return self._decode_pushers_rows(rows) + def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]: + txn.execute( + """ + SELECT id, user_name, access_token, profile_tag, kind, app_id, + app_display_name, device_display_name, pushkey, ts, lang, data, + last_stream_ordering, last_success, failing_since, + enabled, device_id + FROM pushers WHERE COALESCE(enabled, TRUE) + """ + ) + return cast(List[PusherRow], txn.fetchall()) - return await self.db_pool.runInteraction( - "get_enabled_pushers", get_enabled_pushers_txn + return self._decode_pushers_rows( + await self.db_pool.runInteraction( + "get_enabled_pushers", get_enabled_pushers_txn + ) ) async def get_all_updated_pushers_rows( @@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore): ) async def get_throttle_params_by_room( - self, pusher_id: str + self, pusher_id: int ) -> Dict[str, ThrottleParams]: res = await self.db_pool.simple_select_list( "pusher_throttle", @@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore): return params_by_room async def set_throttle_params( - self, pusher_id: str, room_id: str, params: ThrottleParams + self, pusher_id: int, room_id: str, params: ThrottleParams ) -> None: await self.db_pool.simple_upsert( "pusher_throttle", @@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): (last_pusher_id, batch_size), ) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() if len(rows) == 0: return 0 @@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): txn=txn, table="pushers", key_names=("id",), - key_values=[(row["pusher_id"],) for row in rows], + key_values=[row[0] for row in rows], value_names=("device_id", "access_token"), # If there was already a device_id on the pusher, we only want to clear # the access_token column, so we keep the existing device_id. Otherwise, # we set the device_id we got from joining the access_tokens table. value_values=[ - (row["pusher_device_id"] or row["token_device_id"], None) - for row in rows + (pusher_device_id or token_device_id, None) + for _, pusher_device_id, token_device_id in rows ], ) self.db_pool.updates._background_update_progress_txn( - txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]} + txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]} ) return len(rows) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 5ee5c7ad9f..b2645ab43c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached() async def _get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str - ) -> JsonDict: + ) -> JsonMapping: """ Fetch receipts for all rooms that the given user is joined to. @@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_rooms( self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonMapping]: """Get receipts for multiple rooms for sending to clients. Args: @@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[JsonDict]: + ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. Args: @@ -310,28 +310,28 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[JsonDict]: + ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" - def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]: if from_key: sql = ( - "SELECT * FROM receipts_linearized WHERE" + "SELECT receipt_type, user_id, event_id, data" + " FROM receipts_linearized WHERE" " room_id = ? AND stream_id > ? AND stream_id <= ?" ) txn.execute(sql, (room_id, from_key, to_key)) else: sql = ( - "SELECT * FROM receipts_linearized WHERE" + "SELECT receipt_type, user_id, event_id, data" + " FROM receipts_linearized WHERE" " room_id = ? AND stream_id <= ?" ) txn.execute(sql, (room_id, to_key)) - rows = self.db_pool.cursor_to_dict(txn) - - return rows + return cast(List[Tuple[str, str, str, str]], txn.fetchall()) rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) @@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return [] content: JsonDict = {} - for row in rows: - content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ - row["user_id"] - ] = db_to_json(row["data"]) + for receipt_type, user_id, event_id, data in rows: + content.setdefault(event_id, {}).setdefault(receipt_type, {})[ + user_id + ] = db_to_json(data) return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] @@ -353,14 +353,17 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, Sequence[JsonDict]]: + ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} - def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: + def f( + txn: LoggingTransaction, + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: if from_key: sql = """ - SELECT * FROM receipts_linearized WHERE + SELECT room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? AND """ clause, args = make_in_list_sql_clause( @@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [from_key, to_key] + list(args)) else: sql = """ - SELECT * FROM receipts_linearized WHERE + SELECT room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized WHERE stream_id <= ? AND """ @@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key] + list(args)) - return self.db_pool.cursor_to_dict(txn) + return cast( + List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall() + ) txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) results: JsonDict = {} - for row in txn_results: + for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( - row["room_id"], - {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, + room_id, + {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}}, ) # The content is of the form: # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } - event_entry = room_event["content"].setdefault(row["event_id"], {}) - receipt_type = event_entry.setdefault(row["receipt_type"], {}) + event_entry = room_event["content"].setdefault(event_id, {}) + receipt_type_dict = event_entry.setdefault(receipt_type, {}) - receipt_type[row["user_id"]] = db_to_json(row["data"]) - if row["thread_id"]: - receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] + receipt_type_dict[user_id] = db_to_json(data) + if thread_id: + receipt_type_dict[user_id]["thread_id"] = thread_id results = { room_id: [results[room_id]] if room_id in results else [] @@ -415,7 +421,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. @@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore): A dictionary of roomids to a list of receipts. """ - def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]: if from_key: sql = """ - SELECT * FROM receipts_linearized WHERE + SELECT room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id DESC LIMIT 100 @@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, [from_key, to_key]) else: sql = """ - SELECT * FROM receipts_linearized WHERE + SELECT room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized WHERE stream_id <= ? ORDER BY stream_id DESC LIMIT 100 @@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, [to_key]) - return self.db_pool.cursor_to_dict(txn) + return cast(List[Tuple[str, str, str, str, str]], txn.fetchall()) txn_results = await self.db_pool.runInteraction( "get_linearized_receipts_for_all_rooms", f ) results: JsonDict = {} - for row in txn_results: + for room_id, receipt_type, user_id, event_id, data in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( - row["room_id"], - {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, + room_id, + {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}}, ) # The content is of the form: # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } - event_entry = room_event["content"].setdefault(row["event_id"], {}) - receipt_type = event_entry.setdefault(row["receipt_type"], {}) + event_entry = room_event["content"].setdefault(event_id, {}) + receipt_type_dict = event_entry.setdefault(receipt_type, {}) - receipt_type[row["user_id"]] = db_to_json(row["data"]) + receipt_type_dict[user_id] = db_to_json(data) return results @@ -742,7 +750,7 @@ class ReceiptsWorkerStore(SQLBaseStore): event_ids: List[str], thread_id: Optional[str], data: dict, - ) -> Optional[Tuple[int, int]]: + ) -> Optional[int]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph @@ -795,9 +803,7 @@ class ReceiptsWorkerStore(SQLBaseStore): now - event_ts, ) - await self.db_pool.runInteraction( - "insert_graph_receipt", - self._insert_graph_receipt_txn, + await self._insert_graph_receipt( room_id, receipt_type, user_id, @@ -806,13 +812,10 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id + return stream_id - def _insert_graph_receipt_txn( + async def _insert_graph_receipt( self, - txn: LoggingTransaction, room_id: str, receipt_type: str, user_id: str, @@ -822,13 +825,6 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> None: assert self._can_write_to_receipts - txn.call_after( - self._get_receipts_for_user_with_orderings.invalidate, - (user_id, receipt_type), - ) - # FIXME: This shouldn't invalidate the whole cache - txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) - keyvalues = { "room_id": room_id, "receipt_type": receipt_type, @@ -840,8 +836,8 @@ class ReceiptsWorkerStore(SQLBaseStore): else: keyvalues["thread_id"] = thread_id - self.db_pool.simple_upsert_txn( - txn, + await self.db_pool.simple_upsert( + desc="insert_graph_receipt", table="receipts_graph", keyvalues=keyvalues, values={ @@ -851,6 +847,11 @@ class ReceiptsWorkerStore(SQLBaseStore): where_clause=where_clause, ) + self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) + + # FIXME: This shouldn't invalidate the whole cache + self._get_linearized_receipts_for_room.invalidate((room_id,)) + class ReceiptsBackgroundUpdateStore(SQLBaseStore): POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" @@ -939,11 +940,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): receipts.""" def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: - if isinstance(self.database_engine, PostgresEngine): - ROW_ID_NAME = "ctid" - else: - ROW_ID_NAME = "rowid" - + ROW_ID_NAME = self.database_engine.row_id_name # Identify any duplicate receipts arising from # https://github.com/matrix-org/synapse/issues/14406. # The following query takes less than a minute on matrix.org. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 676d03bb7e..64a2c31a5d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -192,78 +192,68 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: - """Deprecated: use get_userinfo_by_id instead""" + async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: + """Returns info about the user account, if it exists.""" - def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]: # We could technically use simple_select_one here, but it would not perform # the COALESCEs (unless hacked into the column names), which could yield # confusing results. txn.execute( """ SELECT - name, password_hash, is_guest, admin, consent_version, consent_ts, + name, is_guest, admin, consent_version, consent_ts, consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, - COALESCE(approved, TRUE) AS approved + COALESCE(approved, TRUE) AS approved, + COALESCE(locked, FALSE) AS locked FROM users WHERE name = ? """, (user_id,), ) - rows = self.db_pool.cursor_to_dict(txn) - - if len(rows) == 0: + row = txn.fetchone() + if not row: return None - return rows[0] + ( + name, + is_guest, + admin, + consent_version, + consent_ts, + consent_server_notice_sent, + appservice_id, + creation_ts, + user_type, + deactivated, + shadow_banned, + approved, + locked, + ) = row + + return UserInfo( + appservice_id=appservice_id, + consent_server_notice_sent=consent_server_notice_sent, + consent_version=consent_version, + consent_ts=consent_ts, + creation_ts=creation_ts, + is_admin=bool(admin), + is_deactivated=bool(deactivated), + is_guest=bool(is_guest), + is_shadow_banned=bool(shadow_banned), + user_id=UserID.from_string(name), + user_type=user_type, + approved=bool(approved), + locked=bool(locked), + ) - row = await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( desc="get_user_by_id", func=get_user_by_id_txn, ) - if row is not None: - # If we're using SQLite our boolean values will be integers. Because we - # present some of this data as is to e.g. server admins via REST APIs, we - # want to make sure we're returning the right type of data. - # Note: when adding a column name to this list, be wary of NULLable columns, - # since NULL values will be turned into False. - boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"] - for column in boolean_columns: - if not isinstance(row[column], bool): - row[column] = bool(row[column]) - - return row - - async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: - """Get a UserInfo object for a user by user ID. - - Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, - this method should be cached. - - Args: - user_id: The user to fetch user info for. - Returns: - `UserInfo` object if user found, otherwise `None`. - """ - user_data = await self.get_user_by_id(user_id) - if not user_data: - return None - return UserInfo( - appservice_id=user_data["appservice_id"], - consent_server_notice_sent=user_data["consent_server_notice_sent"], - consent_version=user_data["consent_version"], - creation_ts=user_data["creation_ts"], - is_admin=bool(user_data["admin"]), - is_deactivated=bool(user_data["deactivated"]), - is_guest=bool(user_data["is_guest"]), - is_shadow_banned=bool(user_data["shadow_banned"]), - user_id=UserID.from_string(user_data["name"]), - user_type=user_data["user_type"], - ) - async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config or the @@ -279,10 +269,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): now = self._clock.time_msec() days = self.config.server.mau_appservice_trial_days.get( - info["appservice_id"], self.config.server.mau_trial_days + info.appservice_id, self.config.server.mau_trial_days ) trial_duration_ms = days * 24 * 60 * 60 * 1000 - is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + is_trial = (now - info.creation_ts * 1000) < trial_duration_ms return is_trial @cached() @@ -454,9 +444,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) -> List[Tuple[str, int]]: sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" - " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + " WHERE email_sent = FALSE AND (expiration_ts_ms - ?) <= ?" ) - values = [False, now_ms, renew_at] + values = [now_ms, renew_at] txn.execute(sql, values) return cast(List[Tuple[str, int]], txn.fetchall()) @@ -600,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """ txn.execute(sql, (token,)) - rows = self.db_pool.cursor_to_dict(txn) - - if rows: - row = rows[0] - - # This field is nullable, ensure it comes out as a boolean - if row["token_used"] is None: - row["token_used"] = False + row = txn.fetchone() - return TokenLookupResult(**row) + if row: + ( + user_id, + is_guest, + shadow_banned, + token_id, + device_id, + valid_until_ms, + token_owner, + token_used, + ) = row + + return TokenLookupResult( + user_id=user_id, + is_guest=is_guest, + shadow_banned=shadow_banned, + token_id=token_id, + device_id=device_id, + valid_until_ms=valid_until_ms, + token_owner=token_owner, + # This field is nullable, ensure it comes out as a boolean + token_used=bool(token_used), + ) return None @@ -854,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Counts all users registered on the homeserver.""" def _count_users(txn: LoggingTransaction) -> int: - txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.db_pool.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 + txn.execute("SELECT COUNT(*) FROM users") + row = txn.fetchone() + assert row is not None + return row[0] return await self.db_pool.runInteraction("count_users", _count_users) @@ -912,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn: LoggingTransaction) -> int: - txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.db_pool.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 + txn.execute("SELECT COUNT(*) FROM users where user_type is null") + row = txn.fetchone() + assert row is not None + return row[0] return await self.db_pool.runInteraction("count_real_users", _count_users) @@ -1116,6 +1119,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Convert the integer into a boolean. return res == 1 + @cached() + async def get_user_locked_status(self, user_id: str) -> bool: + """Retrieve the value for the `locked` property for the provided user. + + Args: + user_id: The ID of the user to retrieve the status for. + + Returns: + True if the user was locked, false if the user is still active. + """ + + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="locked", + desc="get_user_locked_status", + ) + + # Convert the potential integer into a boolean. + return bool(res) + async def get_threepid_validation_session( self, medium: Optional[str], @@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) txn.execute(sql, []) - res = self.db_pool.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) + for (name,) in txn.fetchall(): + self.set_expiration_date_for_user_txn(txn, name, use_delta=True) await self.db_pool.runInteraction( "get_users_with_no_expiration_date", @@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): (user_id,), ) - rows = self.db_pool.cursor_to_dict(txn) + row = txn.fetchone() + assert row is not None # We cast to bool because the value returned by the database engine might # be an integer if we're using SQLite. - return bool(rows[0]["approved"]) + return bool(row[0]) return await self.db_pool.runInteraction( desc="is_user_pending_approval", @@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): (last_user, batch_size), ) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() if not rows: return True, 0 rows_processed_nb = 0 - for user in rows: - if not user["count_tokens"] and not user["count_threepids"]: - self.set_user_deactivated_status_txn(txn, user["name"], True) + for name, count_tokens, count_threepids in rows: + if not count_tokens and not count_threepids: + self.set_user_deactivated_status_txn(txn, name, True) rows_processed_nb += 1 logger.info("Marked %d rows as deactivated", rows_processed_nb) self.db_pool.updates._background_update_progress_txn( - txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} + txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]} ) if batch_size > len(rows): @@ -2111,6 +2132,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: + """Set the `locked` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + locked: The value to set for `locked`. + """ + + await self.db_pool.runInteraction( + "set_user_locked_status", + self.set_user_locked_status_txn, + user_id, + locked, + ) + + def set_user_locked_status_txn( + self, txn: LoggingTransaction, user_id: str, locked: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"locked": locked}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + def update_user_approval_status_txn( self, txn: LoggingTransaction, user_id: str, approved: bool ) -> None: @@ -2253,6 +2301,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): return next_id + async def set_device_for_refresh_token( + self, user_id: str, old_device_id: str, device_id: str + ) -> None: + """Moves refresh tokens from old device to current device + + Args: + user_id: The user of the devices. + old_device_id: The old device. + device_id: The new device ID. + Returns: + None + """ + + await self.db_pool.simple_update( + "refresh_tokens", + keyvalues={"user_id": user_id, "device_id": old_device_id}, + updatevalues={"device_id": device_id}, + desc="set_device_for_refresh_token", + ) + def _set_device_for_access_token_txn( self, txn: LoggingTransaction, token: str, device_id: str ) -> str: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 4a6c6c724d..9246b418f5 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -365,6 +365,36 @@ class RelationsWorkerStore(SQLBaseStore): func=get_all_relation_ids_for_event_with_types_txn, ) + async def get_all_relations_for_event( + self, + event_id: str, + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event. + + Args: + event_id: The event for which to look for related events. + + Returns: + A list of the IDs of the events that relate to the given event. + """ + + def get_all_relation_ids_for_event_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_list_txn( + txn=txn, + table="event_relations", + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event", + func=get_all_relation_ids_for_event_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. @@ -428,14 +458,14 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached() + @cached() # type: ignore[synapse-@cached-mutable] async def get_references_for_event(self, event_id: str) -> List[JsonDict]: raise NotImplementedError() @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") async def get_references_for_events( self, event_ids: Collection[str] - ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]: """Get a list of references to the given events. Args: @@ -482,14 +512,15 @@ class RelationsWorkerStore(SQLBaseStore): "_get_references_for_events_txn", _get_references_for_events_txn ) - @cached() + @cached() # type: ignore[synapse-@cached-mutable] def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() - @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") + # TODO: This returns a mutable object, which is generally bad. + @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # type: ignore[synapse-@cached-mutable] async def get_applicable_edits( self, event_ids: Collection[str] - ) -> Dict[str, Optional[EventBase]]: + ) -> Mapping[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given events. @@ -568,14 +599,15 @@ class RelationsWorkerStore(SQLBaseStore): for original_event_id in event_ids } - @cached() + @cached() # type: ignore[synapse-@cached-mutable] def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: raise NotImplementedError() - @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") + # TODO: This returns a mutable object, which is generally bad. + @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # type: ignore[synapse-@cached-mutable] async def get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase]]]: + ) -> Mapping[str, Optional[Tuple[int, EventBase]]]: """Get the number of threaded replies and the latest reply (if any) for the given events. Args: @@ -749,7 +781,7 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") async def get_threads_participated( self, event_ids: Collection[str], user_id: str - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Get whether the requesting user participated in the given threads. This is separate from get_thread_summaries since that can be cached across @@ -901,7 +933,7 @@ class RelationsWorkerStore(SQLBaseStore): room_id: str, limit: int = 5, from_token: Optional[ThreadsNextBatch] = None, - ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + ) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]: """Get a list of thread IDs, ordered by topological ordering of their latest reply. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4247b88adf..3972d70358 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -848,7 +848,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_retention_policy_for_room_txn( txn: LoggingTransaction, - ) -> List[Dict[str, Optional[int]]]: + ) -> Optional[Tuple[Optional[int], Optional[int]]]: txn.execute( """ SELECT min_lifetime, max_lifetime FROM room_retention @@ -858,7 +858,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): (room_id,), ) - return self.db_pool.cursor_to_dict(txn) + return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone()) ret = await self.db_pool.runInteraction( "get_retention_policy_for_room", @@ -873,8 +873,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): max_lifetime=self.config.retention.retention_default_max_lifetime, ) - min_lifetime = ret[0]["min_lifetime"] - max_lifetime = ret[0]["max_lifetime"] + min_lifetime, max_lifetime = ret # If one of the room's policy's attributes isn't defined, use the matching # attribute from the default policy. @@ -953,11 +952,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): JOIN event_json USING (room_id, event_id) WHERE room_id = ? %(where_clause)s - AND contains_url = ? AND outlier = ? + AND contains_url = TRUE AND outlier = FALSE ORDER BY stream_ordering DESC LIMIT ? """ - txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) + txn.execute(sql % {"where_clause": ""}, (room_id, 100)) local_media_mxcs = [] remote_media_mxcs = [] @@ -993,7 +992,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn.execute( sql % {"where_clause": "AND stream_ordering < ?"}, - (room_id, next_token, True, False, 100), + (room_id, next_token, 100), ) return local_media_mxcs, remote_media_mxcs @@ -1103,9 +1102,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): # set quarantine if quarantined_by is not None: - sql += "AND safe_from_quarantine = ?" + sql += "AND safe_from_quarantine = FALSE" txn.executemany( - sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] + sql, [(quarantined_by, media_id) for media_id in local_mxcs] ) # remove from quarantine else: @@ -1179,14 +1178,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, args) - rows = self.db_pool.cursor_to_dict(txn) - rooms_dict = {} - - for row in rows: - rooms_dict[row["room_id"]] = RetentionPolicy( - min_lifetime=row["min_lifetime"], - max_lifetime=row["max_lifetime"], + rooms_dict = { + room_id: RetentionPolicy( + min_lifetime=min_lifetime, + max_lifetime=max_lifetime, ) + for room_id, min_lifetime, max_lifetime in txn + } if include_null: # If required, do a second query that retrieves all of the rooms we know @@ -1195,13 +1193,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql) - rows = self.db_pool.cursor_to_dict(txn) - # If a room isn't already in the dict (i.e. it doesn't have a retention # policy in its state), add it with a null policy. - for row in rows: - if row["room_id"] not in rooms_dict: - rooms_dict[row["room_id"]] = RetentionPolicy() + for (room_id,) in txn: + if room_id not in rooms_dict: + rooms_dict[room_id] = RetentionPolicy() return rooms_dict @@ -1720,24 +1716,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore): (last_room, batch_size), ) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() if not rows: return True - for row in rows: - if not row["json"]: + for room_id, event_id, json in rows: + if not json: retention_policy = {} else: - ev = db_to_json(row["json"]) + ev = db_to_json(json) retention_policy = ev["content"] self.db_pool.simple_insert_txn( txn=txn, table="room_retention", values={ - "room_id": row["room_id"], - "event_id": row["event_id"], + "room_id": room_id, + "event_id": event_id, "min_lifetime": retention_policy.get("min_lifetime"), "max_lifetime": retention_policy.get("max_lifetime"), }, @@ -1746,7 +1742,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): logger.info("Inserted %d rows into room_retention", len(rows)) self.db_pool.updates._background_update_progress_txn( - txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} + txn, "insert_room_retention", {"room_id": rows[-1][0]} ) if batch_size > len(rows): @@ -2153,7 +2149,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): raise StoreError(400, "No create event in state") # Before MSC2175, the room creator was a separate field. - if not room_version.msc2175_implicit_room_creator: + if not room_version.implicit_room_creator: room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) if not isinstance(room_creator, str): diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py deleted file mode 100644 index 131f357d04..0000000000 --- a/synapse/storage/databases/main/room_batch.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -from synapse.storage._base import SQLBaseStore - - -class RoomBatchStore(SQLBaseStore): - async def get_insertion_event_id_by_batch_id( - self, room_id: str, batch_id: str - ) -> Optional[str]: - """Retrieve a insertion event ID. - - Args: - batch_id: The batch ID of the insertion event to retrieve. - - Returns: - The event_id of an insertion event, or None if there is no known - insertion event for the given insertion event. - """ - return await self.db_pool.simple_select_one_onecol( - table="insertion_events", - keyvalues={"room_id": room_id, "next_batch_id": batch_id}, - retcol="event_id", - allow_none=True, - ) - - async def store_state_group_id_for_event_id( - self, event_id: str, state_group_id: int - ) -> None: - await self.db_pool.simple_upsert( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - values={"state_group": state_group_id, "event_id": event_id}, - ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e068f27a10..bbe08368db 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from itertools import chain from typing import ( TYPE_CHECKING, AbstractSet, @@ -57,15 +56,12 @@ from synapse.types import ( StrCollection, get_domain_from_id, ) -from synapse.util.async_helpers import Linearizer -from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.state import _StateCacheEntry logger = logging.getLogger(__name__) @@ -91,10 +87,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ): super().__init__(database, db_conn, hs) - # Used by `_get_joined_hosts` to ensure only one thing mutates the cache - # at a time. Keyed by room_id. - self._joined_host_linearizer = Linearizer("_JoinedHostsCache") - self._server_notices_mxid = hs.config.servernotices.server_notices_mxid if ( @@ -199,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def get_subset_users_in_room_with_profiles( self, room_id: str, user_ids: Collection[str] - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for a list of users in a given room. @@ -283,7 +275,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): _get_users_in_room_with_profiles, ) - @cached(max_entries=100000) + @cached(max_entries=100000) # type: ignore[synapse-@cached-mutable] async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. @@ -684,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _get_rooms_for_users( self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[str]]: + ) -> Mapping[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. Returns: @@ -889,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[str]]: + ) -> Mapping[str, Optional[str]]: """For given set of member event_ids check if they point to a join event. @@ -927,11 +919,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): raise Exception("Invalid host name") sql = """ - SELECT state_key FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = ? + SELECT state_key FROM current_state_events + WHERE membership = ? AND type = 'm.room.member' - AND c.room_id = ? + AND room_id = ? AND state_key LIKE ? LIMIT 1 """ @@ -993,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: """ Get current hosts in room based on current state. @@ -1022,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): # `get_users_in_room` rather than funky SQL. domains = await self.get_current_hosts_in_room(room_id) - return list(domains) + return tuple(domains) # For PostgreSQL we can use a regex to pull out the domains from the # joined users in `current_state_events` via regex. - def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: + def get_current_hosts_in_room_ordered_txn( + txn: LoggingTransaction, + ) -> Tuple[str, ...]: # Returns a list of servers currently joined in the room sorted by # longest in the room first (aka. with the lowest depth). The # heuristic of sorting by servers who have been in the room the @@ -1052,126 +1045,12 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return [d for d, in txn if d is not None] + return tuple(d for d, in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn ) - async def get_joined_hosts( - self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" - ) -> FrozenSet[str]: - state_group: Union[object, int] = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - assert state_group is not None - with Measure(self._clock, "get_joined_hosts"): - return await self._get_joined_hosts( - room_id, state_group, state, state_entry=state_entry - ) - - @cached(num_args=2, max_entries=10000, iterable=True) - async def _get_joined_hosts( - self, - room_id: str, - state_group: Union[object, int], - state: StateMap[str], - state_entry: "_StateCacheEntry", - ) -> FrozenSet[str]: - # We don't use `state_group`, it's there so that we can cache based on - # it. However, its important that its never None, since two - # current_state's with a state_group of None are likely to be different. - # - # The `state_group` must match the `state_entry.state_group` (if not None). - assert state_group is not None - assert state_entry.state_group is None or state_entry.state_group == state_group - - # We use a secondary cache of previous work to allow us to build up the - # joined hosts for the given state group based on previous state groups. - # - # We cache one object per room containing the results of the last state - # group we got joined hosts for. The idea is that generally - # `get_joined_hosts` is called with the "current" state group for the - # room, and so consecutive calls will be for consecutive state groups - # which point to the previous state group. - cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc] - - # If the state group in the cache matches, we already have the data we need. - if state_entry.state_group == cache.state_group: - return frozenset(cache.hosts_to_joined_users) - - # Since we'll mutate the cache we need to lock. - async with self._joined_host_linearizer.queue(room_id): - if state_entry.state_group == cache.state_group: - # Same state group, so nothing to do. We've already checked for - # this above, but the cache may have changed while waiting on - # the lock. - pass - elif state_entry.prev_group == cache.state_group: - # The cached work is for the previous state group, so we work out - # the delta. - assert state_entry.delta_ids is not None - for (typ, state_key), event_id in state_entry.delta_ids.items(): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = cache.hosts_to_joined_users.setdefault(host, set()) - - event = await self.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - cache.hosts_to_joined_users.pop(host, None) - else: - # The cache doesn't match the state group or prev state group, - # so we calculate the result from first principles. - # - # We need to fetch all hosts joined to the room according to `state` by - # inspecting all join memberships in `state`. However, if the `state` is - # relatively recent then many of its events are likely to be held in - # the current state of the room, which is easily available and likely - # cached. - # - # We therefore compute the set of `state` events not in the - # current state and only fetch those. - current_memberships = ( - await self._get_approximate_current_memberships_in_room(room_id) - ) - unknown_state_events = {} - joined_users_in_current_state = [] - - for (type, state_key), event_id in state.items(): - if event_id not in current_memberships: - unknown_state_events[type, state_key] = event_id - elif current_memberships[event_id] == Membership.JOIN: - joined_users_in_current_state.append(state_key) - - joined_user_ids = await self.get_joined_user_ids_from_state( - room_id, unknown_state_events - ) - - cache.hosts_to_joined_users = {} - for user_id in chain(joined_user_ids, joined_users_in_current_state): - host = intern_string(get_domain_from_id(user_id)) - cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - cache.state_group = state_entry.state_group - else: - cache.state_group = object() - - return frozenset(cache.hosts_to_joined_users) - async def _get_approximate_current_memberships_in_room( self, room_id: str ) -> Mapping[str, Optional[str]]: @@ -1192,7 +1071,8 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) return {row["event_id"]: row["membership"] for row in rows} - @cached(max_entries=10000) + # TODO This returns a mutable object, which is generally confusing when using a cache. + @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable] def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache() @@ -1314,7 +1194,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> Dict[str, Optional[EventIdMembership]]: + ) -> Mapping[str, Optional[EventIdMembership]]: """Get user_id and membership of a set of event IDs. Returns: @@ -1461,7 +1341,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events INNER JOIN event_json USING (event_id) - INNER JOIN room_memberships USING (event_id) WHERE ? <= stream_ordering AND stream_ordering < ? AND type = 'm.room.member' ORDER BY stream_ordering DESC @@ -1470,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() if not rows: return 0 - min_stream_id = rows[-1]["stream_ordering"] + min_stream_id = rows[-1][0] to_update = [] - for row in rows: - event_id = row["event_id"] - room_id = row["room_id"] + for _, event_id, room_id, json in rows: try: - event_json = db_to_json(row["json"]) + event_json = db_to_json(json) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index a7aae661d8..1d69c4a5f0 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.db_pool.cursor_to_dict(txn) + rows = txn.fetchall() if not rows: return 0 - min_stream_id = rows[-1]["stream_ordering"] + min_stream_id = rows[-1][0] event_search_rows = [] - for row in rows: + for ( + stream_ordering, + event_id, + room_id, + etype, + json, + origin_server_ts, + ) in rows: try: - event_id = row["event_id"] - room_id = row["room_id"] - etype = row["type"] - stream_ordering = row["stream_ordering"] - origin_server_ts = row["origin_server_ts"] try: - event_json = db_to_json(row["json"]) + event_json = db_to_json(json) content = event_json["content"] except Exception: continue diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ebb2ae964f..5eaaff5b68 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + Mapping, + Optional, + Set, + Tuple, +) import attr @@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) async def _get_state_group_for_events( self, event_ids: Collection[str] - ) -> Dict[str, int]: + ) -> Mapping[str, int]: """Returns mapping event_id -> state_group. Raises: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 97c4dc2603..9d403919e4 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -19,6 +19,7 @@ from itertools import chain from typing import ( TYPE_CHECKING, Any, + Counter, Dict, Iterable, List, @@ -28,8 +29,6 @@ from typing import ( cast, ) -from typing_extensions import Counter - from twisted.internet.defer import DeferredLock from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership @@ -108,6 +107,8 @@ class UserSortOrder(Enum): AVATAR_URL = "avatar_url" SHADOW_BANNED = "shadow_banned" CREATION_TS = "creation_ts" + LAST_SEEN_TS = "last_seen_ts" + LOCKED = "locked" class StatsStore(StateDeltasStore): @@ -697,7 +698,7 @@ class StatsStore(StateDeltasStore): txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: filters = [] - args = [self.hs.config.server.server_name] + args: list = [] if search_term: filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)") @@ -733,7 +734,7 @@ class StatsStore(StateDeltasStore): sql_base = """ FROM local_media_repository as lmr - LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ? + LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id {} GROUP BY lmr.user_id, displayname """.format( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 92cbe262a6..ea06e4eee0 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -266,7 +266,7 @@ def generate_next_token( # when we are going backwards so we subtract one from the # stream part. last_stream_ordering -= 1 - return RoomStreamToken(last_topo_ordering, last_stream_ordering) + return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering) def _make_generic_sql_bound( @@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if p > min_pos } - return RoomStreamToken(None, min_pos, immutabledict(positions)) + return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions)) async def get_room_events_stream_for_rooms( self, @@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ret.reverse() if rows: - key = RoomStreamToken(None, min(r.stream_ordering for r in rows)) + key = RoomStreamToken(stream=min(r.stream_ordering for r in rows)) else: # Assume we didn't get anything because there was nothing to # get. @@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): topo = await self.db_pool.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) - return RoomStreamToken(topo, stream_ordering) + return RoomStreamToken(topological=topo, stream=stream_ordering) @overload def get_stream_id_for_event_txn( @@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) + return RoomStreamToken( + topological=row["topological_ordering"], stream=row["stream_ordering"] + ) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): else: topo = None internal = event.internal_metadata - internal.before = RoomStreamToken(topo, stream - 1) - internal.after = RoomStreamToken(topo, stream) + internal.before = RoomStreamToken(topological=topo, stream=stream - 1) + internal.after = RoomStreamToken(topological=topo, stream=stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( @@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( - results["topological_ordering"] - 1, results["stream_ordering"] + topological=results["topological_ordering"] - 1, + stream=results["stream_ordering"], ) after_token = RoomStreamToken( - results["topological_ordering"], results["stream_ordering"] + topological=results["topological_ordering"], + stream=results["stream_ordering"], ) rows, start_token = self._paginate_room_events_txn( @@ -1401,7 +1405,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - args = [False, room_id] + args: List[Any] = [room_id] order, from_bound, to_bound = generate_pagination_bounds( direction, from_token, to_token @@ -1475,7 +1479,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event.topological_ordering, event.stream_ordering FROM events AS event %(join_clause)s - WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s + WHERE event.outlier = FALSE AND event.room_id = ? AND %(bounds)s ORDER BY event.topological_ordering %(order)s, event.stream_ordering %(order)s LIMIT ? """ % { diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index c149a9eacb..61403a98cf 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -23,7 +23,7 @@ from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.util.id_generators import AbstractStreamIdGenerator -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,7 +34,7 @@ class TagsWorkerStore(AccountDataWorkerStore): @cached() async def get_tags_for_user( self, user_id: str - ) -> Mapping[str, Mapping[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonMapping]]: """Get all the tags for a user. @@ -109,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Mapping[str, Mapping[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonMapping]]: """Get all the tags for the rooms where the tags have changed since the given version diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py new file mode 100644 index 0000000000..5555b53575 --- /dev/null +++ b/synapse/storage/databases/main/task_scheduler.py @@ -0,0 +1,230 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast + +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus +from synapse.util import json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str] + + +class TaskSchedulerWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + @staticmethod + def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask: + task_id, action, status, timestamp, resource_id, params, result, error = row + return ScheduledTask( + id=task_id, + action=action, + status=TaskStatus(status), + timestamp=timestamp, + resource_id=resource_id, + params=db_to_json(params) if params is not None else None, + result=db_to_json(result) if result is not None else None, + error=error, + ) + + async def get_scheduled_tasks( + self, + *, + actions: Optional[List[str]] = None, + resource_id: Optional[str] = None, + statuses: Optional[List[TaskStatus]] = None, + max_timestamp: Optional[int] = None, + limit: Optional[int] = None, + ) -> List[ScheduledTask]: + """Get a list of scheduled tasks from the DB. + + Args: + actions: Limit the returned tasks to those specific action names + resource_id: Limit the returned tasks to the specific resource id, if specified + statuses: Limit the returned tasks to the specific statuses + max_timestamp: Limit the returned tasks to the ones that have + a timestamp inferior to the specified one + limit: Only return `limit` number of rows if set. + + Returns: a list of `ScheduledTask`, ordered by increasing timestamps + """ + + def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]: + clauses: List[str] = [] + args: List[Any] = [] + if resource_id: + clauses.append("resource_id = ?") + args.append(resource_id) + if actions is not None: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "action", actions + ) + clauses.append(clause) + args.extend(temp_args) + if statuses is not None: + clause, temp_args = make_in_list_sql_clause( + txn.database_engine, "status", statuses + ) + clauses.append(clause) + args.extend(temp_args) + if max_timestamp is not None: + clauses.append("timestamp <= ?") + args.append(max_timestamp) + + sql = "SELECT * FROM scheduled_tasks" + if clauses: + sql = sql + " WHERE " + " AND ".join(clauses) + + sql = sql + " ORDER BY timestamp" + + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + + txn.execute(sql, args) + return cast(List[ScheduledTaskRow], txn.fetchall()) + + rows = await self.db_pool.runInteraction( + "get_scheduled_tasks", get_scheduled_tasks_txn + ) + return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows] + + async def insert_scheduled_task(self, task: ScheduledTask) -> None: + """Insert a specified `ScheduledTask` in the DB. + + Args: + task: the `ScheduledTask` to insert + """ + await self.db_pool.simple_insert( + "scheduled_tasks", + { + "id": task.id, + "action": task.action, + "status": task.status, + "timestamp": task.timestamp, + "resource_id": task.resource_id, + "params": None + if task.params is None + else json_encoder.encode(task.params), + "result": None + if task.result is None + else json_encoder.encode(task.result), + "error": task.error, + }, + desc="insert_scheduled_task", + ) + + async def update_scheduled_task( + self, + id: str, + timestamp: int, + *, + status: Optional[TaskStatus] = None, + result: Optional[JsonMapping] = None, + error: Optional[str] = None, + ) -> bool: + """Update a scheduled task in the DB with some new value(s). + + Args: + id: id of the `ScheduledTask` to update + timestamp: new timestamp of the task + status: new status of the task + result: new result of the task + error: new error of the task + + Returns: `False` if no matching row was found, `True` otherwise + """ + updatevalues: JsonDict = {"timestamp": timestamp} + if status is not None: + updatevalues["status"] = status + if result is not None: + updatevalues["result"] = json_encoder.encode(result) + if error is not None: + updatevalues["error"] = error + nb_rows = await self.db_pool.simple_update( + "scheduled_tasks", + {"id": id}, + updatevalues, + desc="update_scheduled_task", + ) + return nb_rows > 0 + + async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]: + """Get a specific `ScheduledTask` from its id. + + Args: + id: the id of the task to retrieve + + Returns: the task if available, `None` otherwise + """ + row = await self.db_pool.simple_select_one( + table="scheduled_tasks", + keyvalues={"id": id}, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + allow_none=True, + desc="get_scheduled_task", + ) + + return ( + TaskSchedulerWorkerStore._convert_row_to_task( + ( + row["id"], + row["action"], + row["status"], + row["timestamp"], + row["resource_id"], + row["params"], + row["result"], + row["error"], + ) + ) + if row + else None + ) + + async def delete_scheduled_task(self, id: str) -> None: + """Delete a specific task from its id. + + Args: + id: the id of the task to delete + """ + await self.db_pool.simple_delete( + "scheduled_tasks", + keyvalues={"id": id}, + desc="delete_scheduled_task", + ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index c3bd36efc9..f35757280d 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json @@ -28,8 +28,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.types import JsonDict -from synapse.util.caches.descriptors import cached +from synapse.types import JsonDict, StrCollection +from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer @@ -205,6 +205,26 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): else: return None + @cachedList( + cached_method_name="get_destination_retry_timings", list_name="destinations" + ) + async def get_destination_retry_timings_batch( + self, destinations: StrCollection + ) -> Mapping[str, Optional[DestinationRetryTimings]]: + rows = await self.db_pool.simple_select_many_batch( + table="destinations", + iterable=destinations, + column="destination", + retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + desc="get_destination_retry_timings_batch", + ) + + return { + row.pop("destination"): DestinationRetryTimings(**row) + for row in rows + if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"] + } + async def set_destination_retry_timings( self, destination: str, @@ -256,8 +276,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): retry_interval = EXCLUDED.retry_interval WHERE EXCLUDED.retry_interval = 0 + OR EXCLUDED.retry_last_ts = 0 OR destinations.retry_interval IS NULL OR destinations.retry_interval < EXCLUDED.retry_interval + OR destinations.retry_last_ts < EXCLUDED.retry_last_ts """ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) @@ -504,7 +526,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): start: int, limit: int, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[Tuple[str, int]], int]: """Function to retrieve a paginated list of destination's rooms. This will return a json list of rooms and the total number of rooms. @@ -515,12 +537,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): limit: number of rows to retrieve direction: sort ascending or descending by room_id Returns: - A tuple of a dict of rooms and a count of total rooms. + A tuple of a list of room tuples and a count of total rooms. + + Each room tuple is room_id, stream_ordering. """ def get_destination_rooms_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[Tuple[str, int]], int]: if direction == Direction.BACKWARDS: order = "DESC" else: @@ -534,14 +558,17 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, [destination]) count = cast(Tuple[int], txn.fetchone())[0] - rooms = self.db_pool.simple_select_list_paginate_txn( - txn=txn, - table="destination_rooms", - orderby="room_id", - start=start, - limit=limit, - retcols=("room_id", "stream_ordering"), - order_direction=order, + rooms = cast( + List[Tuple[str, int]], + self.db_pool.simple_select_list_paginate_txn( + txn=txn, + table="destination_rooms", + orderby="room_id", + start=start, + limit=limit, + retcols=("room_id", "stream_ordering"), + order_direction=order, + ), ) return rooms, count diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index b7d58978de..f0dc31fee6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -17,6 +17,7 @@ import re import unicodedata from typing import ( TYPE_CHECKING, + Collection, Iterable, List, Mapping, @@ -45,7 +46,7 @@ from synapse.util.stringutils import non_null_str_or_none if TYPE_CHECKING: from synapse.server import HomeServer -from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules +from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, UserTypes from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -61,7 +62,6 @@ from synapse.types import ( get_domain_from_id, get_localpart_from_id, ) -from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -356,13 +356,30 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): Add all local users to the user directory. """ - def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]: - sql = "SELECT user_id FROM %s LIMIT %s" % ( - TEMP_TABLE + "_users", - str(batch_size), - ) - txn.execute(sql) - user_result = cast(List[Tuple[str]], txn.fetchall()) + def _populate_user_directory_process_users_txn( + txn: LoggingTransaction, + ) -> Optional[int]: + if self.database_engine.supports_returning: + # Note: we use an ORDER BY in the SELECT to force usage of an + # index. Otherwise, postgres does a sequential scan that is + # surprisingly slow (I think due to the fact it will read/skip + # over lots of already deleted rows). + sql = f""" + DELETE FROM {TEMP_TABLE + "_users"} + WHERE user_id IN ( + SELECT user_id FROM {TEMP_TABLE + "_users"} ORDER BY user_id LIMIT ? + ) + RETURNING user_id + """ + txn.execute(sql, (batch_size,)) + user_result = cast(List[Tuple[str]], txn.fetchall()) + else: + sql = "SELECT user_id FROM %s ORDER BY user_id LIMIT %s" % ( + TEMP_TABLE + "_users", + str(batch_size), + ) + txn.execute(sql) + user_result = cast(List[Tuple[str]], txn.fetchall()) if not user_result: return None @@ -378,85 +395,80 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): assert count_result is not None progress["remaining"] = count_result[0] - return users_to_work_on - - users_to_work_on = await self.db_pool.runInteraction( - "populate_user_directory_temp_read", _get_next_batch - ) + if not users_to_work_on: + return None - # No more users -- complete the transaction. - if not users_to_work_on: - await self.db_pool.updates._end_background_update( - "populate_user_directory_process_users" + logger.debug( + "Processing the next %d users of %d remaining", + len(users_to_work_on), + progress["remaining"], ) - return 1 - logger.debug( - "Processing the next %d users of %d remaining" - % (len(users_to_work_on), progress["remaining"]) - ) - - # First filter down to users we want to insert into the user directory. - users_to_insert = [ - user_id - for user_id in users_to_work_on - if await self.should_include_local_user_in_dir(user_id) - ] + # First filter down to users we want to insert into the user directory. + users_to_insert = self._filter_local_users_for_dir_txn( + txn, users_to_work_on + ) - # Next fetch their profiles. Note that the `user_id` here is the - # *localpart*, and that not all users have profiles. - profile_rows = await self.db_pool.simple_select_many_batch( - table="profiles", - column="user_id", - iterable=[get_localpart_from_id(u) for u in users_to_insert], - retcols=( - "user_id", - "displayname", - "avatar_url", - ), - keyvalues={}, - desc="populate_user_directory_process_users_get_profiles", - ) - profiles = { - f"@{row['user_id']}:{self.server_name}": _UserDirProfile( - f"@{row['user_id']}:{self.server_name}", - row["displayname"], - row["avatar_url"], + # Next fetch their profiles. Note that not all users have profiles. + profile_rows = self.db_pool.simple_select_many_txn( + txn, + table="profiles", + column="full_user_id", + iterable=list(users_to_insert), + retcols=( + "full_user_id", + "displayname", + "avatar_url", + ), + keyvalues={}, ) - for row in profile_rows - } + profiles = { + row["full_user_id"]: _UserDirProfile( + row["full_user_id"], + row["displayname"], + row["avatar_url"], + ) + for row in profile_rows + } - profiles_to_insert = [ - profiles.get(user_id) or _UserDirProfile(user_id) - for user_id in users_to_insert - ] + profiles_to_insert = [ + profiles.get(user_id) or _UserDirProfile(user_id) + for user_id in users_to_insert + ] + + # Actually insert the users with their profiles into the directory. + self._update_profiles_in_user_dir_txn(txn, profiles_to_insert) + + # We've finished processing the users. Delete it from the table, if + # we haven't already. + if not self.database_engine.supports_returning: + self.db_pool.simple_delete_many_txn( + txn, + table=TEMP_TABLE + "_users", + column="user_id", + values=users_to_work_on, + keyvalues={}, + ) - # Actually insert the users with their profiles into the directory. - await self.db_pool.runInteraction( - "populate_user_directory_process_users_insertion", - self._update_profiles_in_user_dir_txn, - profiles_to_insert, - ) + # Update the remaining counter. + progress["remaining"] -= len(users_to_work_on) + self.db_pool.updates._background_update_progress_txn( + txn, "populate_user_directory_process_users", progress + ) + return len(users_to_work_on) - # We've finished processing the users. Delete it from the table. - await self.db_pool.simple_delete_many( - table=TEMP_TABLE + "_users", - column="user_id", - iterable=users_to_work_on, - keyvalues={}, - desc="populate_user_directory_process_users_delete", + processed_count = await self.db_pool.runInteraction( + "populate_user_directory_temp", _populate_user_directory_process_users_txn ) - # Update the remaining counter. - progress["remaining"] -= len(users_to_work_on) - await self.db_pool.runInteraction( - "populate_user_directory", - self.db_pool.updates._background_update_progress_txn, - "populate_user_directory_process_users", - progress, - ) + # No more users -- complete the transaction. + if not processed_count: + await self.db_pool.updates._end_background_update( + "populate_user_directory_process_users" + ) + return 1 - return len(users_to_work_on) + return processed_count async def should_include_local_user_in_dir(self, user: str) -> bool: """Certain classes of local user are omitted from the user directory. @@ -494,6 +506,30 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return True + def _filter_local_users_for_dir_txn( + self, txn: LoggingTransaction, users: Collection[str] + ) -> Collection[str]: + """A batched version of `should_include_local_user_in_dir`""" + users = [ + user + for user in users + if self.get_app_service_by_user_id(user) is None # type: ignore[attr-defined] + and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined] + ] + + rows = self.db_pool.simple_select_many_txn( + txn, + table="users", + column="name", + iterable=users, + keyvalues={ + "deactivated": 0, + }, + retcols=("name", "user_type"), + ) + + return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT] + async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: """Check if the room is either world_readable or publically joinable""" @@ -733,9 +769,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # This should be unreachable. raise Exception("Unrecognized database engine") - for p in profiles: - txn.call_after(self.get_user_in_directory.invalidate, (p.user_id,)) - async def add_users_who_share_private_room( self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]] ) -> None: @@ -793,14 +826,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.execute(f"{truncate} user_directory_search") txn.execute(f"{truncate} users_in_public_rooms") txn.execute(f"{truncate} users_who_share_private_rooms") - txn.call_after(self.get_user_in_directory.invalidate_all) await self.db_pool.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) - @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: + async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, @@ -862,7 +893,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) await self.db_pool.runInteraction( "remove_from_user_dir", _remove_from_user_dir_txn @@ -965,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) async def search_user_dir( - self, user_id: str, search_term: str, limit: int + self, + user_id: str, + search_term: str, + limit: int, + show_locked_users: bool = False, ) -> SearchResult: """Searches for users in directory @@ -999,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) """ + if not show_locked_users: + where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)" + # We allow manipulating the ranking algorithm by injecting statements # based on config options. additional_ordering_statements = [] @@ -1023,12 +1060,16 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # The array of numbers are the weights for the various part of the # search: (domain, _, display name, localpart) sql = """ + WITH matching_users AS ( + SELECT user_id, vector FROM user_directory_search WHERE vector @@ to_tsquery('simple', ?) + LIMIT 10000 + ) SELECT d.user_id AS user_id, display_name, avatar_url - FROM user_directory_search as t + FROM matching_users as t INNER JOIN user_directory AS d USING (user_id) + LEFT JOIN users AS u ON t.user_id = u.name WHERE %(where_clause)s - AND vector @@ to_tsquery('simple', ?) ORDER BY (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) @@ -1057,8 +1098,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "order_case_statements": " ".join(additional_ordering_statements), } args = ( - join_args - + (full_query, exact_query, prefix_query) + (full_query,) + + join_args + + (exact_query, prefix_query) + ordering_arguments + (limit + 1,) ) @@ -1081,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) + LEFT JOIN users AS u ON t.user_id = u.name WHERE %(where_clause)s AND value MATCH ? diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index f79006533f..06fcbe5e54 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable +from typing import Iterable, Mapping from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore @@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore): return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") - async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]: + async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]: """ Checks which users in a list have requested erasure diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 097dea5182..6ff533a129 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -40,6 +41,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): updates. """ + @trace + @tag_args def _count_state_group_hops_txn( self, txn: LoggingTransaction, state_group: int ) -> int: @@ -83,12 +86,26 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): return count + @trace + @tag_args def _get_state_groups_from_groups_txn( self, txn: LoggingTransaction, groups: List[int], state_filter: Optional[StateFilter] = None, ) -> Mapping[int, StateMap[str]]: + """ + Given a number of state groups, fetch the latest state for each group. + + Args: + txn: The transaction object. + groups: The given state groups that you want to fetch the latest state for. + state_filter: The state filter to apply the state we fetch state from the database. + + Returns: + Map from state_group to a StateMap at that point. + """ + state_filter = state_filter or StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} @@ -201,8 +218,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): if where_clause: where_clause = " AND (%s)" % (where_clause,) - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + # XXX: We could `WITH RECURSIVE` here since it's supported on SQLite 3.8.3 + # or higher and our minimum supported version is greater than that. + # + # We just haven't put in the time to refactor this. for group in groups: next_group: Optional[int] = group @@ -256,6 +275,16 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx" + CURRENT_STATE_EVENTS_STREAM_ORDERING_INDEX_UPDATE_NAME = ( + "current_state_events_stream_ordering_idx" + ) + ROOM_MEMBERSHIPS_STREAM_ORDERING_INDEX_UPDATE_NAME = ( + "room_memberships_stream_ordering_idx" + ) + LOCAL_CURRENT_MEMBERSHIP_STREAM_ORDERING_INDEX_UPDATE_NAME = ( + "local_current_membership_stream_ordering_idx" + ) + def __init__( self, database: DatabasePool, @@ -292,6 +321,27 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): replaces_index="state_group_edges_idx", ) + # These indices are needed to validate the foreign key constraint + # when events are deleted. + self.db_pool.updates.register_background_index_update( + self.CURRENT_STATE_EVENTS_STREAM_ORDERING_INDEX_UPDATE_NAME, + index_name="current_state_events_stream_ordering_idx", + table="current_state_events", + columns=["event_stream_ordering"], + ) + self.db_pool.updates.register_background_index_update( + self.ROOM_MEMBERSHIPS_STREAM_ORDERING_INDEX_UPDATE_NAME, + index_name="room_memberships_stream_ordering_idx", + table="room_memberships", + columns=["event_stream_ordering"], + ) + self.db_pool.updates.register_background_index_update( + self.LOCAL_CURRENT_MEMBERSHIP_STREAM_ORDERING_INDEX_UPDATE_NAME, + index_name="local_current_membership_stream_ordering_idx", + table="local_current_membership", + columns=["event_stream_ordering"], + ) + async def _background_deduplicate_state( self, progress: dict, batch_size: int ) -> int: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 29ff64e876..6984d11352 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -20,6 +20,7 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase +from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -159,6 +160,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) + @trace + @tag_args @cancellable async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter @@ -187,6 +190,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return results + @trace + @tag_args def _get_state_for_group_using_cache( self, cache: DictionaryCache[int, StateKey, str], @@ -239,6 +244,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + @trace + @tag_args @cancellable async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None @@ -305,6 +312,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state + @trace + @tag_args def _get_state_for_groups_using_cache( self, groups: Iterable[int], @@ -403,6 +412,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + @trace + @tag_args async def store_state_deltas_for_batched( self, events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], @@ -520,6 +531,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): prev_group, ) + @trace + @tag_args async def store_state_group( self, event_id: str, @@ -772,6 +785,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ((sg,) for sg in state_groups_to_delete), ) + @trace + @tag_args async def get_previous_state_groups( self, state_groups: Iterable[int] ) -> Dict[int, int]: diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 0363cdc038..b1a2418cbd 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM """Gets a string giving the server version. For example: '3.22.0'""" ... + @property + @abc.abstractmethod + def row_id_name(self) -> str: + """Gets the literal name representing a row id for this engine.""" + ... + @abc.abstractmethod def in_transaction(self, conn: ConnectionType) -> bool: """Whether the connection is currently in a transaction.""" @@ -145,5 +151,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM This is not provided by DBAPI2, and so needs engine-specific support. """ - with open(filepath, "rt") as f: + with open(filepath) as f: cls.executescript(cursor, f.read()) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index b350f57ccb..6309363217 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -45,6 +45,15 @@ class PostgresEngine( psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter) self.synchronous_commit: bool = database_config.get("synchronous_commit", True) + # Set the statement timeout to 1 hour by default. + # Any query taking more than 1 hour should probably be considered a bug; + # most of the time this is a sign that work needs to be split up or that + # some degenerate query plan has been created and the client has probably + # timed out/walked off anyway. + # This is in milliseconds. + self.statement_timeout: Optional[int] = database_config.get( + "statement_timeout", 60 * 60 * 1000 + ) self._version: Optional[int] = None # unknown as yet self.isolation_level_map: Mapping[int, int] = { @@ -157,6 +166,10 @@ class PostgresEngine( if not self.synchronous_commit: cursor.execute("SET synchronous_commit TO OFF") + # Abort really long-running statements and turn them into errors. + if self.statement_timeout is not None: + cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,)) + cursor.close() db_conn.commit() @@ -198,6 +211,10 @@ class PostgresEngine( else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) + @property + def row_id_name(self) -> str: + return "ctid" + def in_transaction(self, conn: psycopg2.extensions.connection) -> bool: return conn.status != psycopg2.extensions.STATUS_READY diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index ca8c59297c..802069e1e1 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): """Gets a string giving the server version. For example: '3.22.0'.""" return "%i.%i.%i" % sqlite3.sqlite_version_info + @property + def row_id_name(self) -> str: + return "rowid" + def in_transaction(self, conn: sqlite3.Connection) -> bool: return conn.in_transaction diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 71584f3f74..e74b2269d2 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -25,3 +25,10 @@ logger = logging.getLogger(__name__) class FetchKeyResult: verify_key: VerifyKey # the key itself valid_until_ts: int # how long we can use this key for + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FetchKeyResultForRemote: + key_json: bytes # the full key JSON + valid_until_ts: int # how long we can use this key for, in milliseconds. + added_ts: int # When we added this key, in milliseconds. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 38b7abd801..31501fd573 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -16,10 +16,18 @@ import logging import os import re from collections import Counter -from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple +from typing import ( + Collection, + Counter as CounterType, + Generator, + Iterable, + List, + Optional, + TextIO, + Tuple, +) import attr -from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 2500381b7b..cbfb32014c 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -45,6 +45,7 @@ class ProfileInfo: display_name: Optional[str] +# TODO This is used as a cached value and is mutable. @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class MemberSummary: # A truncated list of (user_id, event_id) tuples for users of a given diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index df2cc31ca6..de89de7d74 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 77 # remember to update the list below when updating +SCHEMA_VERSION = 82 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -103,16 +103,30 @@ Changes in SCHEMA_VERSION = 76: Changes in SCHEMA_VERSION = 77 - (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters + +Changes in SCHEMA_VERSION = 78 + - Validate check (full_user_id IS NOT NULL) on tables profiles and user_filters + +Changes in SCHEMA_VERSION = 79 + - Add tables to handle in DB read-write locks. + - Add some mitigations for a painful race between foreground and background updates, cf #15677. + +Changes in SCHEMA_VERSION = 80 + - The event_txn_id_device_id is always written to for new events. + - Add tables for the task scheduler. + +Changes in SCHEMA_VERSION = 81 + - The event_txn_id is no longer written to for new events. + +Changes in SCHEMA_VERSION = 82 + - The insertion_events, insertion_event_extremities, insertion_event_edges, and + batch_events tables are no longer purged in preparation for their removal. """ SCHEMA_COMPAT_VERSION = ( - # Queries against `event_stream_ordering` columns in membership tables must - # be disambiguated. - # - # insertions to the column `full_user_id` of tables profiles and user_filters can no - # longer be null - 76 + # The event_txn_id table and tables from MSC2716 no longer exist. + 82 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/48/group_unique_indexes.py b/synapse/storage/schema/main/delta/48/group_unique_indexes.py index ad2da4c8af..622686d28f 100644 --- a/synapse/storage/schema/main/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/main/delta/48/group_unique_indexes.py @@ -14,7 +14,7 @@ from synapse.storage.database import LoggingTransaction -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.prepare_database import get_statements FIX_INDEXES = """ @@ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: - rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + rowid = database_engine.row_id_name # remove duplicates from group_users & group_invites tables cur.execute( diff --git a/synapse/storage/schema/main/delta/77/05thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/77/05thread_notifications_backfill.sql new file mode 100644 index 0000000000..a5da7a17a0 --- /dev/null +++ b/synapse/storage/schema/main/delta/77/05thread_notifications_backfill.sql @@ -0,0 +1,48 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Force the background updates from 06thread_notifications.sql to run in the +-- foreground as code will now require those to be "done". + +DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; + +-- Overwrite any null thread_id values. +UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; + +-- Empirically we can end up with entries in the push summary table with both a +-- `NULL` and `main` thread ID, which causes the insert below to fail. We fudge +-- this by deleting any `NULL` rows that have a corresponding `main`. +DELETE FROM event_push_summary AS a WHERE thread_id IS NULL AND EXISTS ( + SELECT 1 FROM event_push_summary AS b + WHERE b.thread_id = 'main' AND a.user_id = b.user_id AND a.room_id = b.room_id +); +-- Copy the NULL threads to have a 'main' thread ID. +-- +-- Note: Some people seem to have duplicate rows with a `NULL` thread ID, in +-- which case we just fudge it with using MAX of the values. The counts *may* be +-- wrong for such rooms, but a) its an edge case, and b) they'll be fixed when +-- the user reads the room. +INSERT INTO event_push_summary (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) + SELECT user_id, room_id, MAX(notif_count), MAX(stream_ordering), MAX(unread_count), MAX(last_receipt_stream_ordering), 'main' + FROM event_push_summary + WHERE thread_id IS NULL + GROUP BY user_id, room_id, thread_id; + +DELETE FROM event_push_summary AS a WHERE thread_id IS NULL; + +-- Drop the background updates to calculate the indexes used to find null thread_ids. +DELETE FROM background_updates WHERE update_name = 'event_push_actions_thread_id_null'; +DELETE FROM background_updates WHERE update_name = 'event_push_summary_thread_id_null'; diff --git a/synapse/storage/schema/main/delta/77/06thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null.sql.sqlite new file mode 100644 index 0000000000..d19b9648b5 --- /dev/null +++ b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null.sql.sqlite @@ -0,0 +1,102 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- The thread_id columns can now be made non-nullable. +-- +-- SQLite doesn't support modifying columns to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE event_push_actions_staging_new ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL, + unread SMALLINT, + thread_id TEXT, + inserted_ts BIGINT, + CONSTRAINT event_push_actions_staging_thread_id CHECK (thread_id is NOT NULL) +); + +CREATE TABLE event_push_actions_new ( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + topological_ordering BIGINT, + stream_ordering BIGINT, + notif SMALLINT, + highlight SMALLINT, + unread SMALLINT, + thread_id TEXT, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag), + CONSTRAINT event_push_actions_thread_id CHECK (thread_id is NOT NULL) +); + +CREATE TABLE event_push_summary_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + unread_count BIGINT, + last_receipt_stream_ordering BIGINT, + thread_id TEXT, + CONSTRAINT event_push_summary_thread_id CHECK (thread_id is NOT NULL) +); + +-- Copy the data. +INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) + SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts + FROM event_push_actions_staging; + +INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) + SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id + FROM event_push_actions; + +INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) + SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id + FROM event_push_summary; + +-- Drop the old tables. +DROP TABLE event_push_actions_staging; +DROP TABLE event_push_actions; +DROP TABLE event_push_summary; + +-- Rename the tables. +ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; +ALTER TABLE event_push_actions_new RENAME TO event_push_actions; +ALTER TABLE event_push_summary_new RENAME TO event_push_summary; + +-- Recreate the indexes. +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); + +CREATE INDEX event_push_actions_highlights_index ON event_push_actions (user_id, room_id, topological_ordering, stream_ordering); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions (user_id, stream_ordering); + +CREATE UNIQUE INDEX event_push_summary_unique_index2 ON event_push_summary (user_id, room_id, thread_id) ; + +-- Recreate some indexes in the background, by re-running the background updates +-- from 72/02event_push_actions_index.sql and 72/06thread_notifications.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7706, 'event_push_summary_unique_index2', '{}') + ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}'; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7706, 'event_push_actions_stream_highlight_index', '{}') + ON CONFLICT (update_name) DO UPDATE SET progress_json = '{}'; diff --git a/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_actions.sql.postgres b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_actions.sql.postgres new file mode 100644 index 0000000000..381184b5e2 --- /dev/null +++ b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_actions.sql.postgres @@ -0,0 +1,27 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The thread_id columns can now be made non-nullable, this is done by using a +-- constraint (and not altering the column) to avoid taking out a full table lock. +-- +-- We initially add an invalid constraint which guards against new data (this +-- doesn't lock the table). +ALTER TABLE event_push_actions + ADD CONSTRAINT event_push_actions_thread_id CHECK (thread_id IS NOT NULL) NOT VALID; + +-- We then validate the constraint which doesn't need to worry about new data. It +-- only needs a SHARE UPDATE EXCLUSIVE lock but can still take a while to complete. +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7706, 'event_push_actions_thread_id', '{}', 'event_push_actions_staging_thread_id'); diff --git a/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_actions_staging.sql.postgres b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_actions_staging.sql.postgres new file mode 100644 index 0000000000..395f9c7260 --- /dev/null +++ b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_actions_staging.sql.postgres @@ -0,0 +1,27 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The thread_id columns can now be made non-nullable, this is done by using a +-- constraint (and not altering the column) to avoid taking out a full table lock. +-- +-- We initially add an invalid constraint which guards against new data (this +-- doesn't lock the table). +ALTER TABLE event_push_actions_staging + ADD CONSTRAINT event_push_actions_staging_thread_id CHECK (thread_id IS NOT NULL) NOT VALID; + +-- We then validate the constraint which doesn't need to worry about new data. It +-- only needs a SHARE UPDATE EXCLUSIVE lock but can still take a while to complete. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7706, 'event_push_actions_staging_thread_id', '{}'); diff --git a/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_summary.sql.postgres b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_summary.sql.postgres new file mode 100644 index 0000000000..140ceff1fa --- /dev/null +++ b/synapse/storage/schema/main/delta/77/06thread_notifications_not_null_event_push_summary.sql.postgres @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The thread_id columns can now be made non-nullable, this is done by using a +-- constraint (and not altering the column) to avoid taking out a full table lock. +-- +-- We initially add an invalid constraint which guards against new data (this +-- doesn't lock the table). +ALTER TABLE event_push_summary + ADD CONSTRAINT event_push_summary_thread_id CHECK (thread_id IS NOT NULL) NOT VALID; + +-- We then validate the constraint which doesn't need to worry about new data. It +-- only needs a SHARE UPDATE EXCLUSIVE lock but can still take a while to complete. +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7706, 'event_push_summary_thread_id', '{}', 'event_push_actions_thread_id'), + -- Also clean-up the old indexes. + (7706, 'event_push_drop_null_thread_id_indexes', '{}', 'event_push_summary_thread_id'); diff --git a/synapse/storage/schema/main/delta/77/14bg_indices_event_stream_ordering.sql b/synapse/storage/schema/main/delta/77/14bg_indices_event_stream_ordering.sql new file mode 100644 index 0000000000..ec8cd522ec --- /dev/null +++ b/synapse/storage/schema/main/delta/77/14bg_indices_event_stream_ordering.sql @@ -0,0 +1,20 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES + (7714, 'current_state_events_stream_ordering_idx', '{}'), + (7714, 'local_current_membership_stream_ordering_idx', '{}'), + (7714, 'room_memberships_stream_ordering_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/78/01_validate_and_update_profiles.py b/synapse/storage/schema/main/delta/78/01_validate_and_update_profiles.py new file mode 100644 index 0000000000..8398d8f548 --- /dev/null +++ b/synapse/storage/schema/main/delta/78/01_validate_and_update_profiles.py @@ -0,0 +1,92 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine + + +def run_upgrade( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: + """ + Part 3 of a multi-step migration to drop the column `user_id` and replace it with + `full_user_id`. See the database schema docs for more information on the full + migration steps. + """ + hostname = config.server.server_name + + if isinstance(database_engine, PostgresEngine): + # check if the constraint can be validated + check_sql = """ + SELECT user_id from profiles WHERE full_user_id IS NULL + """ + cur.execute(check_sql) + res = cur.fetchall() + + if res: + # there are rows the background job missed, finish them here before we validate the constraint + process_rows_sql = """ + UPDATE profiles + SET full_user_id = '@' || user_id || ? + WHERE user_id IN ( + SELECT user_id FROM profiles WHERE full_user_id IS NULL + ) + """ + cur.execute(process_rows_sql, (f":{hostname}",)) + + # Now we can validate + validate_sql = """ + ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null + """ + cur.execute(validate_sql) + + else: + # in SQLite we need to rewrite the table to add the constraint. + # First drop any temporary table that might be here from a previous failed migration. + cur.execute("DROP TABLE IF EXISTS temp_profiles") + + create_sql = """ + CREATE TABLE temp_profiles ( + full_user_id text NOT NULL, + user_id text, + displayname text, + avatar_url text, + UNIQUE (full_user_id), + UNIQUE (user_id) + ) + """ + cur.execute(create_sql) + + copy_sql = """ + INSERT INTO temp_profiles ( + user_id, + displayname, + avatar_url, + full_user_id) + SELECT user_id, displayname, avatar_url, '@' || user_id || ':' || ? FROM profiles + """ + cur.execute(copy_sql, (f"{hostname}",)) + + drop_sql = """ + DROP TABLE profiles + """ + cur.execute(drop_sql) + + rename_sql = """ + ALTER TABLE temp_profiles RENAME to profiles + """ + cur.execute(rename_sql) diff --git a/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py b/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py new file mode 100644 index 0000000000..e148ed26f2 --- /dev/null +++ b/synapse/storage/schema/main/delta/78/02_validate_and_update_user_filters.py @@ -0,0 +1,93 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine + + +def run_upgrade( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: + """ + Part 3 of a multi-step migration to drop the column `user_id` and replace it with + `full_user_id`. See the database schema docs for more information on the full + migration steps. + """ + hostname = config.server.server_name + + if isinstance(database_engine, PostgresEngine): + # check if the constraint can be validated + check_sql = """ + SELECT user_id from user_filters WHERE full_user_id IS NULL + """ + cur.execute(check_sql) + res = cur.fetchall() + + if res: + # there are rows the background job missed, finish them here before we validate constraint + process_rows_sql = """ + UPDATE user_filters + SET full_user_id = '@' || user_id || ? + WHERE user_id IN ( + SELECT user_id FROM user_filters WHERE full_user_id IS NULL + ) + """ + cur.execute(process_rows_sql, (f":{hostname}",)) + + # Now we can validate + validate_sql = """ + ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null + """ + cur.execute(validate_sql) + + else: + cur.execute("DROP TABLE IF EXISTS temp_user_filters") + create_sql = """ + CREATE TABLE temp_user_filters ( + full_user_id text NOT NULL, + user_id text NOT NULL, + filter_id bigint NOT NULL, + filter_json bytea NOT NULL + ) + """ + cur.execute(create_sql) + + index_sql = """ + CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON + temp_user_filters (user_id, filter_id) + """ + cur.execute(index_sql) + + copy_sql = """ + INSERT INTO temp_user_filters ( + user_id, + filter_id, + filter_json, + full_user_id) + SELECT user_id, filter_id, filter_json, '@' || user_id || ':' || ? FROM user_filters + """ + cur.execute(copy_sql, (f"{hostname}",)) + + drop_sql = """ + DROP TABLE user_filters + """ + cur.execute(drop_sql) + + rename_sql = """ + ALTER TABLE temp_user_filters RENAME to user_filters + """ + cur.execute(rename_sql) diff --git a/synapse/storage/schema/main/delta/78/03_remove_unused_indexes_user_filters.py b/synapse/storage/schema/main/delta/78/03_remove_unused_indexes_user_filters.py new file mode 100644 index 0000000000..f5ba1c3fd4 --- /dev/null +++ b/synapse/storage/schema/main/delta/78/03_remove_unused_indexes_user_filters.py @@ -0,0 +1,65 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine + + +def run_update( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, + config: HomeServerConfig, +) -> None: + """ + Fix to drop unused indexes caused by incorrectly adding UNIQUE constraint to + columns `user_id` and `full_user_id` of table `user_filters` in previous migration. + """ + + if isinstance(database_engine, Sqlite3Engine): + cur.execute("DROP TABLE IF EXISTS temp_user_filters") + create_sql = """ + CREATE TABLE temp_user_filters ( + full_user_id text NOT NULL, + user_id text NOT NULL, + filter_id bigint NOT NULL, + filter_json bytea NOT NULL + ) + """ + cur.execute(create_sql) + + copy_sql = """ + INSERT INTO temp_user_filters ( + user_id, + filter_id, + filter_json, + full_user_id) + SELECT user_id, filter_id, filter_json, full_user_id FROM user_filters + """ + cur.execute(copy_sql) + + drop_sql = """ + DROP TABLE user_filters + """ + cur.execute(drop_sql) + + rename_sql = """ + ALTER TABLE temp_user_filters RENAME to user_filters + """ + cur.execute(rename_sql) + + index_sql = """ + CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON + user_filters (user_id, filter_id) + """ + cur.execute(index_sql) diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py new file mode 100644 index 0000000000..bf8c57dbe8 --- /dev/null +++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py @@ -0,0 +1,57 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This migration adds foreign key constraint to `event_forward_extremities` table. +""" +from synapse.storage.background_updates import ( + ForeignKeyConstraint, + run_validate_constraint_and_delete_rows_schema_delta, +) +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine + +FORWARD_EXTREMITIES_TABLE_SCHEMA = """ + CREATE TABLE event_forward_extremities2( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (event_id, room_id), + CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id) DEFERRABLE INITIALLY DEFERRED + ) +""" + + +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: + # We mark this as a deferred constraint, as the previous version of Synapse + # inserted the event into the forward extremities *before* the events table. + # By marking as deferred we ensure that downgrading to the previous version + # will continue to work. + run_validate_constraint_and_delete_rows_schema_delta( + cur, + ordering=7803, + update_name="event_forward_extremities_event_id_foreign_key_constraint_update", + table="event_forward_extremities", + constraint_name="event_forward_extremities_event_id", + constraint=ForeignKeyConstraint( + "events", [("event_id", "event_id")], deferred=True + ), + sqlite_table_name="event_forward_extremities2", + sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA, + ) + + # We can't add a similar constraint to `event_backward_extremities` as the + # events in there don't exist in the `events` table and `event_edges` + # doesn't have a unique constraint on `prev_event_id` (so we can't make a + # foreign key point to it). diff --git a/synapse/storage/schema/main/delta/78/04_add_full_user_id_index_user_filters.py b/synapse/storage/schema/main/delta/78/04_add_full_user_id_index_user_filters.py new file mode 100644 index 0000000000..97fecc2bd9 --- /dev/null +++ b/synapse/storage/schema/main/delta/78/04_add_full_user_id_index_user_filters.py @@ -0,0 +1,25 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine + + +def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: + if isinstance(database_engine, Sqlite3Engine): + idx_sql = """ + CREATE UNIQUE INDEX IF NOT EXISTS user_filters_full_user_id_unique ON + user_filters (full_user_id, filter_id) + """ + cur.execute(idx_sql) diff --git a/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.postgres b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.postgres new file mode 100644 index 0000000000..7df07ab0da --- /dev/null +++ b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.postgres @@ -0,0 +1,102 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We implement read/write style locks by using two tables with mutual foreign +-- key constraints. Note that this implementation is vulnerable to starving +-- writers if read locks repeatedly get acquired. +-- +-- The first table (`worker_read_write_locks_mode`) indicates that a given lock +-- has either been acquired in read mode *or* write mode, but not both. This is +-- enforced by the unique constraint. Each instance of a lock being acquired is +-- associated with a random `token`. +-- +-- The second table (`worker_read_write_locks`) tracks who has currently +-- acquired a given lock. For a given lock_name/lock_key, there can be multiple +-- read locks at a time but only one write lock (no mixing read and write locks +-- at the same time). +-- +-- The foreign key from the second to first table enforces that for any given +-- lock the second table cannot have a mix of rows with read or write. +-- +-- The foreign key from the first to second table enforces that we don't have a +-- row for a lock in the first table if not in the second table. +-- +-- +-- Furthermore, we add some triggers to automatically keep the first table up to +-- date when inserting/deleting from the second table. This reduces the number +-- of round trips needed to acquire and release locks, as those operations +-- simply become an INSERT or DELETE. These triggers are added in a separate +-- delta due to database specific syntax. + + +-- A table to track whether a lock is currently acquired, and if so whether its +-- in read or write mode. +CREATE TABLE IF NOT EXISTS worker_read_write_locks_mode ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- Whether this lock is in read (false) or write (true) mode + write_lock BOOLEAN NOT NULL, + -- A token that has currently acquired the lock. We need this so that we can + -- add a foreign constraint from this table to `worker_read_write_locks`. + token TEXT NOT NULL +); + +-- Ensure that we can only have one row per lock +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_key ON worker_read_write_locks_mode (lock_name, lock_key); +-- We need this (redundant) constraint so that we can have a foreign key +-- constraint against this table. +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_type ON worker_read_write_locks_mode (lock_name, lock_key, write_lock); + + +-- A table to track who has currently acquired a given lock. +CREATE TABLE IF NOT EXISTS worker_read_write_locks ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- We write the instance name to ease manual debugging, we don't ever read + -- from it. + -- Note: instance names aren't guarenteed to be unique. + instance_name TEXT NOT NULL, + -- Whether the process has taken out a "read" or a "write" lock. + write_lock BOOLEAN NOT NULL, + -- A random string generated each time an instance takes out a lock. Used by + -- the instance to tell whether the lock is still held by it (e.g. in the + -- case where the process stalls for a long time the lock may time out and + -- be taken out by another instance, at which point the original instance + -- can tell it no longer holds the lock as the tokens no longer match). + token TEXT NOT NULL, + last_renewed_ts BIGINT NOT NULL, + + -- This constraint ensures that a given lock has only been acquired in read + -- xor write mode, but not both. + FOREIGN KEY (lock_name, lock_key, write_lock) REFERENCES worker_read_write_locks_mode (lock_name, lock_key, write_lock) +); + +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_key ON worker_read_write_locks (lock_name, lock_key, token); +-- Ensures that only one instance can acquire a lock in write mode at a time. +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_write ON worker_read_write_locks (lock_name, lock_key) WHERE write_lock; + + +-- Add a foreign key constraint to ensure that if a lock is in +-- `worker_read_write_locks_mode` then there must be a corresponding row in +-- `worker_read_write_locks` (i.e. we don't accidentally end up with a row in +-- `worker_read_write_locks_mode` when the lock is not currently acquired). +-- +-- We only add to PostgreSQL as SQLite does not support adding constraints +-- after table creation, and so doesn't support "circular" foreign key +-- constraints. +ALTER TABLE worker_read_write_locks_mode DROP CONSTRAINT IF EXISTS worker_read_write_locks_mode_foreign; +ALTER TABLE worker_read_write_locks_mode ADD CONSTRAINT worker_read_write_locks_mode_foreign + FOREIGN KEY (lock_name, lock_key, token) REFERENCES worker_read_write_locks(lock_name, lock_key, token) DEFERRABLE INITIALLY DEFERRED; diff --git a/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.sqlite b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.sqlite new file mode 100644 index 0000000000..95f9dbf120 --- /dev/null +++ b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.sqlite @@ -0,0 +1,72 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- c.f. the postgres version for context. The tables and constraints are the +-- same, however they need to be defined slightly differently to work around how +-- each database handles circular foreign key references. + + + +-- A table to track whether a lock is currently acquired, and if so whether its +-- in read or write mode. +CREATE TABLE IF NOT EXISTS worker_read_write_locks_mode ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- Whether this lock is in read (false) or write (true) mode + write_lock BOOLEAN NOT NULL, + -- A token that has currently acquired the lock. We need this so that we can + -- add a foreign constraint from this table to `worker_read_write_locks`. + token TEXT NOT NULL, + -- Add a foreign key constraint to ensure that if a lock is in + -- `worker_read_write_locks_mode` then there must be a corresponding row in + -- `worker_read_write_locks` (i.e. we don't accidentally end up with a row in + -- `worker_read_write_locks_mode` when the lock is not currently acquired). + FOREIGN KEY (lock_name, lock_key, token) REFERENCES worker_read_write_locks(lock_name, lock_key, token) DEFERRABLE INITIALLY DEFERRED +); + +-- Ensure that we can only have one row per lock +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_key ON worker_read_write_locks_mode (lock_name, lock_key); +-- We need this (redundant) constraint so that we can have a foreign key +-- constraint against this table. +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_type ON worker_read_write_locks_mode (lock_name, lock_key, write_lock); + + +-- A table to track who has currently acquired a given lock. +CREATE TABLE IF NOT EXISTS worker_read_write_locks ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- We write the instance name to ease manual debugging, we don't ever read + -- from it. + -- Note: instance names aren't guarenteed to be unique. + instance_name TEXT NOT NULL, + -- Whether the process has taken out a "read" or a "write" lock. + write_lock BOOLEAN NOT NULL, + -- A random string generated each time an instance takes out a lock. Used by + -- the instance to tell whether the lock is still held by it (e.g. in the + -- case where the process stalls for a long time the lock may time out and + -- be taken out by another instance, at which point the original instance + -- can tell it no longer holds the lock as the tokens no longer match). + token TEXT NOT NULL, + last_renewed_ts BIGINT NOT NULL, + + -- This constraint ensures that a given lock has only been acquired in read + -- xor write mode, but not both. + FOREIGN KEY (lock_name, lock_key, write_lock) REFERENCES worker_read_write_locks_mode (lock_name, lock_key, write_lock) +); + +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_key ON worker_read_write_locks (lock_name, lock_key, token); +-- Ensures that only one instance can acquire a lock in write mode at a time. +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_write ON worker_read_write_locks (lock_name, lock_key) WHERE write_lock; diff --git a/synapse/storage/schema/main/delta/79/04_mitigate_stream_ordering_update_race.py b/synapse/storage/schema/main/delta/79/04_mitigate_stream_ordering_update_race.py new file mode 100644 index 0000000000..ae63585847 --- /dev/null +++ b/synapse/storage/schema/main/delta/79/04_mitigate_stream_ordering_update_race.py @@ -0,0 +1,70 @@ +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine + + +def run_create( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, +) -> None: + """ + An attempt to mitigate a painful race between foreground and background updates + touching the `stream_ordering` column of the events table. More info can be found + at https://github.com/matrix-org/synapse/issues/15677. + """ + + # technically the bg update we're concerned with below should only have been added in + # postgres but it doesn't hurt to be extra careful + if isinstance(database_engine, PostgresEngine): + select_sql = """ + SELECT 1 FROM background_updates + WHERE update_name = 'replace_stream_ordering_column' + """ + cur.execute(select_sql) + res = cur.fetchone() + + # if the background update `replace_stream_ordering_column` is still pending, we need + # to drop the indexes added in 7403, and re-add them to the column `stream_ordering2` + # with the idea that they will be preserved when the column is renamed `stream_ordering` + # after the background update has finished + if res: + drop_cse_sql = """ + ALTER TABLE current_state_events DROP CONSTRAINT IF EXISTS event_stream_ordering_fkey + """ + cur.execute(drop_cse_sql) + + drop_lcm_sql = """ + ALTER TABLE local_current_membership DROP CONSTRAINT IF EXISTS event_stream_ordering_fkey + """ + cur.execute(drop_lcm_sql) + + drop_rm_sql = """ + ALTER TABLE room_memberships DROP CONSTRAINT IF EXISTS event_stream_ordering_fkey + """ + cur.execute(drop_rm_sql) + + add_cse_sql = """ + ALTER TABLE current_state_events ADD CONSTRAINT event_stream_ordering_fkey + FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering2) NOT VALID; + """ + cur.execute(add_cse_sql) + + add_lcm_sql = """ + ALTER TABLE local_current_membership ADD CONSTRAINT event_stream_ordering_fkey + FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering2) NOT VALID; + """ + cur.execute(add_lcm_sql) + + add_rm_sql = """ + ALTER TABLE room_memberships ADD CONSTRAINT event_stream_ordering_fkey + FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering2) NOT VALID; + """ + cur.execute(add_rm_sql) diff --git a/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.postgres b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.postgres new file mode 100644 index 0000000000..ea3496ef2d --- /dev/null +++ b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.postgres @@ -0,0 +1,69 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Fix up the triggers that were in `78/04_read_write_locks_triggers.sql` + +-- Add a trigger to UPSERT into `worker_read_write_locks_mode` whenever we try +-- and acquire a lock, i.e. insert into `worker_read_write_locks`, +CREATE OR REPLACE FUNCTION upsert_read_write_lock_parent() RETURNS trigger AS $$ +BEGIN + INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) + VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) + ON CONFLICT (lock_name, lock_key) + DO UPDATE SET write_lock = NEW.write_lock, token = NEW.token; + RETURN NEW; +END +$$ +LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS upsert_read_write_lock_parent_trigger ON worker_read_write_locks; +CREATE TRIGGER upsert_read_write_lock_parent_trigger BEFORE INSERT ON worker_read_write_locks + FOR EACH ROW + EXECUTE PROCEDURE upsert_read_write_lock_parent(); + + +-- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock +-- is released (i.e. a row deleted from `worker_read_write_locks`). Either we +-- update the `worker_read_write_locks_mode.token` to match another instance +-- that has currently acquired the lock, or we delete the row if nobody has +-- currently acquired a lock. +CREATE OR REPLACE FUNCTION delete_read_write_lock_parent() RETURNS trigger AS $$ +DECLARE + new_token TEXT; +BEGIN + SELECT token INTO new_token FROM worker_read_write_locks + WHERE + lock_name = OLD.lock_name + AND lock_key = OLD.lock_key + LIMIT 1 FOR UPDATE; + + IF NOT FOUND THEN + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key AND token = OLD.token; + ELSE + UPDATE worker_read_write_locks_mode + SET token = new_token + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; + END IF; + + RETURN NEW; +END +$$ +LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS delete_read_write_lock_parent_trigger ON worker_read_write_locks; +CREATE TRIGGER delete_read_write_lock_parent_trigger AFTER DELETE ON worker_read_write_locks + FOR EACH ROW + EXECUTE PROCEDURE delete_read_write_lock_parent(); diff --git a/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.sqlite b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.sqlite new file mode 100644 index 0000000000..acb1a77c80 --- /dev/null +++ b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.sqlite @@ -0,0 +1,65 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Fix up the triggers that were in `78/04_read_write_locks_triggers.sql` + +-- Add a trigger to UPSERT into `worker_read_write_locks_mode` whenever we try +-- and acquire a lock, i.e. insert into `worker_read_write_locks`, +DROP TRIGGER IF EXISTS upsert_read_write_lock_parent_trigger; +CREATE TRIGGER IF NOT EXISTS upsert_read_write_lock_parent_trigger +BEFORE INSERT ON worker_read_write_locks +FOR EACH ROW +BEGIN + -- First ensure that `worker_read_write_locks_mode` doesn't have stale + -- entries in it, as on SQLite we don't have the foreign key constraint to + -- enforce this. + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = NEW.lock_name AND lock_key = NEW.lock_key + AND NOT EXISTS ( + SELECT 1 FROM worker_read_write_locks + WHERE lock_name = NEW.lock_name AND lock_key = NEW.lock_key + ); + + INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) + VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) + ON CONFLICT (lock_name, lock_key) + DO UPDATE SET write_lock = NEW.write_lock, token = NEW.token; +END; + +-- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock +-- is released (i.e. a row deleted from `worker_read_write_locks`). Either we +-- update the `worker_read_write_locks_mode.token` to match another instance +-- that has currently acquired the lock, or we delete the row if nobody has +-- currently acquired a lock. +DROP TRIGGER IF EXISTS delete_read_write_lock_parent_trigger; +CREATE TRIGGER IF NOT EXISTS delete_read_write_lock_parent_trigger +AFTER DELETE ON worker_read_write_locks +FOR EACH ROW +BEGIN + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key + AND token = OLD.token + AND NOT EXISTS ( + SELECT 1 FROM worker_read_write_locks + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key + ); + + UPDATE worker_read_write_locks_mode + SET token = ( + SELECT token FROM worker_read_write_locks + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key + ) + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; +END; diff --git a/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql new file mode 100644 index 0000000000..21c7971441 --- /dev/null +++ b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql @@ -0,0 +1,16 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL; diff --git a/synapse/storage/schema/main/delta/80/02_read_write_locks_unlogged.sql.postgres b/synapse/storage/schema/main/delta/80/02_read_write_locks_unlogged.sql.postgres new file mode 100644 index 0000000000..5b5dbf2687 --- /dev/null +++ b/synapse/storage/schema/main/delta/80/02_read_write_locks_unlogged.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Mark the worker_read_write_locks* tables as UNLOGGED, to increase +-- performance. This means that we don't replicate the tables, and they get +-- truncated on a crash. This is acceptable as a) in those cases it's likely +-- that Synapse needs to be stopped/restarted anyway, and b) the locks are +-- considered best-effort anyway. + +-- We need to remove and recreate the circular foreign key references, as +-- UNLOGGED tables can't reference normal tables. +ALTER TABLE worker_read_write_locks_mode DROP CONSTRAINT IF EXISTS worker_read_write_locks_mode_foreign; + +ALTER TABLE worker_read_write_locks SET UNLOGGED; +ALTER TABLE worker_read_write_locks_mode SET UNLOGGED; + +ALTER TABLE worker_read_write_locks_mode ADD CONSTRAINT worker_read_write_locks_mode_foreign + FOREIGN KEY (lock_name, lock_key, token) REFERENCES worker_read_write_locks(lock_name, lock_key, token) DEFERRABLE INITIALLY DEFERRED; diff --git a/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql b/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql new file mode 100644 index 0000000000..286d109ed7 --- /dev/null +++ b/synapse/storage/schema/main/delta/80/02_scheduled_tasks.sql @@ -0,0 +1,28 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- cf ScheduledTask docstring for the meaning of the fields. +CREATE TABLE IF NOT EXISTS scheduled_tasks( + id TEXT PRIMARY KEY, + action TEXT NOT NULL, + status TEXT NOT NULL, + timestamp BIGINT NOT NULL, + resource_id TEXT, + params TEXT, + result TEXT, + error TEXT +); + +CREATE INDEX IF NOT EXISTS scheduled_tasks_status ON scheduled_tasks(status); diff --git a/synapse/storage/schema/main/delta/80/03_read_write_locks_triggers.sql.postgres b/synapse/storage/schema/main/delta/80/03_read_write_locks_triggers.sql.postgres new file mode 100644 index 0000000000..31de5bfa18 --- /dev/null +++ b/synapse/storage/schema/main/delta/80/03_read_write_locks_triggers.sql.postgres @@ -0,0 +1,37 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Fix up the triggers that were in `78/04_read_write_locks_triggers.sql` + +-- Reduce the number of writes we do on this table. +-- +-- Note: that we still want to lock the row here (i.e. still do a `DO UPDATE +-- SET`) so that we serialize updates. +CREATE OR REPLACE FUNCTION upsert_read_write_lock_parent() RETURNS trigger AS $$ +BEGIN + INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) + VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) + ON CONFLICT (lock_name, lock_key) + DO UPDATE SET write_lock = NEW.write_lock + WHERE OLD.write_lock != NEW.write_lock; + RETURN NEW; +END +$$ +LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS upsert_read_write_lock_parent_trigger ON worker_read_write_locks; +CREATE TRIGGER upsert_read_write_lock_parent_trigger BEFORE INSERT ON worker_read_write_locks + FOR EACH ROW + EXECUTE PROCEDURE upsert_read_write_lock_parent(); diff --git a/synapse/storage/schema/main/delta/80/04_read_write_locks_deadlock.sql.postgres b/synapse/storage/schema/main/delta/80/04_read_write_locks_deadlock.sql.postgres new file mode 100644 index 0000000000..0eb459c0b9 --- /dev/null +++ b/synapse/storage/schema/main/delta/80/04_read_write_locks_deadlock.sql.postgres @@ -0,0 +1,71 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Remove a previous attempt to avoid deadlocks +DROP TRIGGER IF EXISTS delete_read_write_lock_parent_before_trigger ON worker_read_write_locks; +DROP FUNCTION IF EXISTS delete_read_write_lock_parent_before; + + +-- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock +-- is released (i.e. a row deleted from `worker_read_write_locks`). Either we +-- update the `worker_read_write_locks_mode.token` to match another instance +-- that has currently acquired the lock, or we delete the row if nobody has +-- currently acquired a lock. +CREATE OR REPLACE FUNCTION delete_read_write_lock_parent() RETURNS trigger AS $$ +DECLARE + new_token TEXT; + mode_row_token TEXT; +BEGIN + -- Only update the token in `_mode` if its our token. This prevents + -- deadlocks. + -- + -- We shove the token into `mode_row_token`, as otherwise postgres complains + -- we're not using the returned data. + SELECT token INTO mode_row_token FROM worker_read_write_locks_mode + WHERE + lock_name = OLD.lock_name + AND lock_key = OLD.lock_key + AND token = OLD.token + FOR UPDATE; + + IF NOT FOUND THEN + RETURN NEW; + END IF; + + SELECT token INTO new_token FROM worker_read_write_locks + WHERE + lock_name = OLD.lock_name + AND lock_key = OLD.lock_key + LIMIT 1 FOR UPDATE SKIP LOCKED; + + IF NOT FOUND THEN + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key AND token = OLD.token; + ELSE + UPDATE worker_read_write_locks_mode + SET token = new_token + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; + END IF; + + RETURN NEW; +END +$$ +LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS delete_read_write_lock_parent_trigger ON worker_read_write_locks; +CREATE TRIGGER delete_read_write_lock_parent_trigger AFTER DELETE ON worker_read_write_locks + FOR EACH ROW + EXECUTE PROCEDURE delete_read_write_lock_parent(); diff --git a/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql b/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql new file mode 100644 index 0000000000..6b90275139 --- /dev/null +++ b/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX IF NOT EXISTS scheduled_tasks_timestamp ON scheduled_tasks(timestamp); diff --git a/synapse/storage/schema/main/delta/82/03_drop_old_tables.sql b/synapse/storage/schema/main/delta/82/03_drop_old_tables.sql new file mode 100644 index 0000000000..149020bbd7 --- /dev/null +++ b/synapse/storage/schema/main/delta/82/03_drop_old_tables.sql @@ -0,0 +1,24 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Drop the old event transaction ID table, the event_txn_id_device_id table +-- should be used instead. +DROP TABLE IF EXISTS event_txn_id; + +-- Drop tables related to MSC2716 since the implementation is being removed +DROP TABLE insertion_events; +DROP TABLE insertion_event_edges; +DROP TABLE insertion_event_extremities; +DROP TABLE batch_events; diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 34ac807530..afaeef9a5a 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -53,22 +53,10 @@ class Cursor(Protocol): @property def description( self, - ) -> Optional[ - Sequence[ - # Note that this is an approximate typing based on sqlite3 and other - # drivers, and may not be entirely accurate. - # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description - Tuple[ - str, - Optional[Any], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - ] - ] - ]: + ) -> Optional[Sequence[Any]]: + # At the time of writing, Synapse only assumes that `column[0]: str` for each + # `column in description`. Since this is hard to express in the type system, and + # as this is rarely used in Synapse, we deem `column: Any` good enough. ... @property diff --git a/synapse/streams/events.py b/synapse/streams/events.py index d7084d2358..609a0978a9 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Iterator, Tuple +from typing import TYPE_CHECKING, Sequence, Tuple import attr @@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.logging.opentracing import trace from synapse.streams import EventSource -from synapse.types import StreamToken +from synapse.types import StreamKeyType, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer @@ -37,9 +37,14 @@ class _EventSourcesInner: receipt: ReceiptEventSource account_data: AccountDataEventSource - def get_sources(self) -> Iterator[Tuple[str, EventSource]]: - for attribute in attr.fields(_EventSourcesInner): - yield attribute.name, getattr(self, attribute.name) + def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]: + return [ + (StreamKeyType.ROOM, self.room), + (StreamKeyType.PRESENCE, self.presence), + (StreamKeyType.TYPING, self.typing), + (StreamKeyType.RECEIPT, self.receipt), + (StreamKeyType.ACCOUNT_DATA, self.account_data), + ] class EventSources: diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index a4afe6e8c2..77d9679582 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -15,6 +15,7 @@ import abc import re import string +from enum import Enum from typing import ( TYPE_CHECKING, AbstractSet, @@ -22,6 +23,7 @@ from typing import ( ClassVar, Dict, List, + Literal, Mapping, Match, MutableMapping, @@ -32,13 +34,14 @@ from typing import ( Type, TypeVar, Union, + overload, ) import attr from immutabledict import immutabledict from signedjson.key import decode_verify_key_bytes from signedjson.types import VerifyKey -from typing_extensions import Final, TypedDict +from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -58,6 +61,8 @@ from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: + from typing_extensions import Self + from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore @@ -117,11 +122,12 @@ class Requester: Attributes: user: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar + access_token_id: *ID* of the access token used for this request, or + None for appservices, guests, and tokens generated by the admin API is_guest: True if the user making this request is a guest user shadow_banned: True if the user making this request has been shadow-banned. - device_id: device_id which was set at authentication time + device_id: device_id which was set at authentication time, or + None for appservices, guests, and tokens generated by the admin API app_service: the AS requesting on behalf of the user authenticated_entity: The entity that authenticated when making the request. This is different to the user_id when an admin user or the server is @@ -131,6 +137,7 @@ class Requester: user: "UserID" access_token_id: Optional[int] is_guest: bool + scope: Set[str] shadow_banned: bool device_id: Optional[str] app_service: Optional["ApplicationService"] @@ -147,6 +154,7 @@ class Requester: "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "scope": list(self.scope), "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, @@ -175,6 +183,7 @@ class Requester: user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + scope=set(input.get("scope", [])), shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, @@ -186,6 +195,7 @@ def create_requester( user_id: Union[str, "UserID"], access_token_id: Optional[int] = None, is_guest: bool = False, + scope: StrCollection = (), shadow_banned: bool = False, device_id: Optional[str] = None, app_service: Optional["ApplicationService"] = None, @@ -199,6 +209,7 @@ def create_requester( access_token_id: *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest: True if the user making this request is a guest user + scope: the scope of the access token used for this request, if any shadow_banned: True if the user making this request is shadow-banned. device_id: device_id which was set at authentication time app_service: the AS requesting on behalf of the user @@ -215,10 +226,13 @@ def create_requester( if authenticated_entity is None: authenticated_entity = user_id.to_string() + scope = set(scope) + return Requester( user_id, access_token_id, is_guest, + scope, shadow_banned, device_id, app_service, @@ -340,22 +354,15 @@ class EventID(DomainSpecificString): SIGIL = "$" -mxid_localpart_allowed_characters = set( - "_-./=" + string.ascii_lowercase + string.digits +MXID_LOCALPART_ALLOWED_CHARACTERS = set( + "_-./=+" + string.ascii_lowercase + string.digits ) -# MSC4007 adds the + to the allowed characters. -# -# TODO If this was accepted, update the SSO code to support this, see the callers -# of map_username_to_mxid_localpart. -extended_mxid_localpart_allowed_characters = mxid_localpart_allowed_characters | {"+"} # Guest user IDs are purely numeric. GUEST_USER_ID_PATTERN = re.compile(r"^\d+$") -def contains_invalid_mxid_characters( - localpart: str, use_extended_character_set: bool -) -> bool: +def contains_invalid_mxid_characters(localpart: str) -> bool: """Check for characters not allowed in an mxid or groupid localpart Args: @@ -366,12 +373,7 @@ def contains_invalid_mxid_characters( Returns: True if there are any naughty characters """ - allowed_characters = ( - extended_mxid_localpart_allowed_characters - if use_extended_character_set - else mxid_localpart_allowed_characters - ) - return any(c not in allowed_characters for c in localpart) + return any(c not in MXID_LOCALPART_ALLOWED_CHARACTERS for c in localpart) UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") @@ -388,7 +390,7 @@ UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") # bytes rather than strings # NON_MXID_CHARACTER_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( + ("[^%s]" % (re.escape("".join(MXID_LOCALPART_ALLOWED_CHARACTERS - {"="})),)).encode( "ascii" ) ) @@ -437,7 +439,78 @@ def map_username_to_mxid_localpart( @attr.s(frozen=True, slots=True, order=False) -class RoomStreamToken: +class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): + """An abstract stream token class for streams that supports multiple + writers. + + This works by keeping track of the stream position of each writer, + represented by a default `stream` attribute and a map of instance name to + stream position of any writers that are ahead of the default stream + position. + """ + + stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True) + + instance_map: "immutabledict[str, int]" = attr.ib( + factory=immutabledict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(immutabledict), + ), + kw_only=True, + ) + + @classmethod + @abc.abstractmethod + async def parse(cls, store: "DataStore", string: str) -> "Self": + """Parse the string representation of the token.""" + ... + + @abc.abstractmethod + async def to_string(self, store: "DataStore") -> str: + """Serialize the token into its string representation.""" + ... + + def copy_and_advance(self, other: "Self") -> "Self": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + max_stream = max(self.stream, other.stream) + + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return attr.evolve( + self, stream=max_stream, instance_map=immutabledict(instance_map) + ) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token.""" + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + +@attr.s(frozen=True, slots=True, order=False) +class RoomStreamToken(AbstractMultiWriterStreamToken): """Tokens are positions between events. The token "s1" comes after event 1. s0 s1 @@ -514,16 +587,8 @@ class RoomStreamToken: topological: Optional[int] = attr.ib( validator=attr.validators.optional(attr.validators.instance_of(int)), - ) - stream: int = attr.ib(validator=attr.validators.instance_of(int)) - - instance_map: "immutabledict[str, int]" = attr.ib( - factory=immutabledict, - validator=attr.validators.deep_mapping( - key_validator=attr.validators.instance_of(str), - value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(immutabledict), - ), + kw_only=True, + default=None, ) def __attrs_post_init__(self) -> None: @@ -583,17 +648,7 @@ class RoomStreamToken: if self.topological or other.topological: raise Exception("Can't advance topological tokens") - max_stream = max(self.stream, other.stream) - - instance_map = { - instance: max( - self.instance_map.get(instance, self.stream), - other.instance_map.get(instance, other.stream), - ) - for instance in set(self.instance_map).union(other.instance_map) - } - - return RoomStreamToken(None, max_stream, immutabledict(instance_map)) + return super().copy_and_advance(other) def as_historical_tuple(self) -> Tuple[int, int]: """Returns a tuple of `(topological, stream)` for historical tokens. @@ -619,16 +674,6 @@ class RoomStreamToken: # at `self.stream`. return self.instance_map.get(instance_name, self.stream) - def get_max_stream_pos(self) -> int: - """Get the maximum stream position referenced in this token. - - The corresponding "min" position is, by definition just `self.stream`. - - This is used to handle tokens that have non-empty `instance_map`, and so - reference stream positions after the `self.stream` position. - """ - return max(self.instance_map.values(), default=self.stream) - async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) @@ -650,20 +695,20 @@ class RoomStreamToken: return "s%d" % (self.stream,) -class StreamKeyType: +class StreamKeyType(Enum): """Known stream types. A stream is a list of entities ordered by an incrementing "stream token". """ - ROOM: Final = "room_key" - PRESENCE: Final = "presence_key" - TYPING: Final = "typing_key" - RECEIPT: Final = "receipt_key" - ACCOUNT_DATA: Final = "account_data_key" - PUSH_RULES: Final = "push_rules_key" - TO_DEVICE: Final = "to_device_key" - DEVICE_LIST: Final = "device_list_key" + ROOM = "room_key" + PRESENCE = "presence_key" + TYPING = "typing_key" + RECEIPT = "receipt_key" + ACCOUNT_DATA = "account_data_key" + PUSH_RULES = "push_rules_key" + TO_DEVICE = "to_device_key" + DEVICE_LIST = "device_list_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" @@ -785,7 +830,7 @@ class StreamToken: def room_stream_id(self) -> int: return self.room_key.stream - def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": + def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. @@ -798,35 +843,68 @@ class StreamToken: return new_token new_token = self.copy_and_replace(key, new_value) - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) + new_id = new_token.get_field(key) + old_id = self.get_field(key) if old_id < new_id: return new_token else: return self - def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": - return attr.evolve(self, **{key: new_value}) + def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken": + return attr.evolve(self, **{key.value: new_value}) + @overload + def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: + ... -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0) + @overload + def get_field( + self, + key: Literal[ + StreamKeyType.ACCOUNT_DATA, + StreamKeyType.DEVICE_LIST, + StreamKeyType.PRESENCE, + StreamKeyType.PUSH_RULES, + StreamKeyType.RECEIPT, + StreamKeyType.TO_DEVICE, + StreamKeyType.TYPING, + StreamKeyType.UN_PARTIAL_STATED_ROOMS, + ], + ) -> int: + ... + @overload + def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: + ... -@attr.s(slots=True, frozen=True, auto_attribs=True) -class PersistedEventPosition: - """Position of a newly persisted event with instance that persisted it. + def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: + """Returns the stream ID for the given key.""" + return getattr(self, key.value) + + +StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0) - This can be used to test whether the event is persisted before or after a - RoomStreamToken. - """ + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PersistedPosition: + """Position of a newly persisted row with instance that persisted it.""" instance_name: str stream: int - def persisted_after(self, token: RoomStreamToken) -> bool: + def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool: return token.get_stream_pos_for_instance(self.instance_name) < self.stream + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PersistedEventPosition(PersistedPosition): + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + def to_room_stream_token(self) -> RoomStreamToken: """Converts the position to a room stream token such that events persisted in the same room after this position will be after the @@ -837,7 +915,7 @@ class PersistedEventPosition: """ # Doing the naive thing satisfies the desired properties described in # the docstring. - return RoomStreamToken(None, self.stream) + return RoomStreamToken(stream=self.stream) @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -934,31 +1012,37 @@ def get_verify_key_from_cross_signing_key( @attr.s(auto_attribs=True, frozen=True, slots=True) class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. + """Holds information about a user. Result of get_user_by_id. Attributes: user_id: ID of the user. appservice_id: Application service ID that created this user. consent_server_notice_sent: Version of policy documents the user has been sent. consent_version: Version of policy documents the user has consented to. + consent_ts: Time the user consented creation_ts: Creation timestamp of the user. is_admin: True if the user is an admin. is_deactivated: True if the user has been deactivated. is_guest: True if the user is a guest user. is_shadow_banned: True if the user has been shadow-banned. user_type: User type (None for normal user, 'support' and 'bot' other options). + approved: If the user has been "approved" to register on the server. + locked: Whether the user's account has been locked """ user_id: UserID appservice_id: Optional[int] consent_server_notice_sent: Optional[str] consent_version: Optional[str] + consent_ts: Optional[int] user_type: Optional[str] creation_ts: int is_admin: bool is_deactivated: bool is_guest: bool is_shadow_banned: bool + approved: bool + locked: bool @attr.s(auto_attribs=True, frozen=True, slots=True) @@ -985,3 +1069,41 @@ class UserProfile(TypedDict): class RetentionPolicy: min_lifetime: Optional[int] = None max_lifetime: Optional[int] = None + + +class TaskStatus(str, Enum): + """Status of a scheduled task""" + + # Task is scheduled but not active + SCHEDULED = "scheduled" + # Task is active and probably running, and if not + # will be run on next scheduler loop run + ACTIVE = "active" + # Task has completed successfully + COMPLETE = "complete" + # Task is over and either returned a failed status, or had an exception + FAILED = "failed" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ScheduledTask: + """Description of a scheduled task""" + + # Id used to identify the task + id: str + # Name of the action to be run by this task + action: str + # Current status of this task + status: TaskStatus + # If the status is SCHEDULED then this represents when it should be launched, + # otherwise it represents the last time this task got a change of state. + # In milliseconds since epoch in system time timezone, usually UTC. + timestamp: int + # Optionally bind a task to some resource id for easy retrieval + resource_id: Optional[str] + # Optional parameters that will be passed to the function ran by the task + params: Optional[JsonMapping] + # Optional result that can be updated by the running task + result: Optional[JsonMapping] + # Optional error that should be assigned a value when the status is FAILED + error: Optional[str] diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 6f40914639..ba7d4f1246 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -76,7 +76,7 @@ def unwrapFirstError(failure: Failure) -> Failure: # the subFailure's value, which will do a better job of preserving stacktraces. # (actually, you probably want to use yieldable_gather_results anyway) failure.trap(defer.FirstError) - return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations + return failure.value.subFailure P = ParamSpec("P") @@ -116,6 +116,11 @@ class Clock: Waits `msec` initially before calling `f` for the first time. + If the function given to `looping_call` returns an awaitable/deferred, the next + call isn't scheduled until after the returned awaitable has finished. We get + this functionality thanks to this function being a thin wrapper around + `twisted.internet.task.LoopingCall`. + Note that the function will be called with no logcontext, so if it is anything other than trivial, you probably want to wrap it in run_as_background_process. @@ -178,7 +183,7 @@ def log_failure( """ logger.error( - msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) # type: ignore[arg-type] + msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) ) if not consumeErrors: diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 01e3cd46f6..0cbeb0c365 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -19,15 +19,18 @@ import collections import inspect import itertools import logging +import typing from contextlib import asynccontextmanager from typing import ( Any, + AsyncContextManager, AsyncIterator, Awaitable, Callable, Collection, Coroutine, Dict, + Generator, Generic, Hashable, Iterable, @@ -42,7 +45,7 @@ from typing import ( ) import attr -from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec +from typing_extensions import Concatenate, Literal, ParamSpec from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -138,7 +141,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): for observer in observers: # This is a little bit of magic to correctly propagate stack # traces when we `await` on one of the observer deferreds. - f.value.__failure__ = f # type: ignore[union-attr] + f.value.__failure__ = f try: observer.errback(f) except Exception as e: @@ -397,7 +400,7 @@ class _LinearizerEntry: # The number of things executing. count: int # Deferreds for the things blocked from executing. - deferreds: collections.OrderedDict + deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]] class Linearizer: @@ -716,30 +719,25 @@ def timeout_deferred( return new_d -# This class can't be generic because it uses slots with attrs. -# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True, auto_attribs=True) -class DoneAwaitable: # should be: Generic[R] +class DoneAwaitable(Awaitable[R]): """Simple awaitable that returns the provided value.""" - value: Any # should be: R + value: R - def __await__(self) -> Any: - return self - - def __iter__(self) -> "DoneAwaitable": - return self - - def __next__(self) -> None: - raise StopIteration(self.value) + def __await__(self) -> Generator[Any, None, R]: + yield None + return self.value def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable.""" if inspect.isawaitable(value): - assert isinstance(value, Awaitable) return value + # For some reason mypy doesn't deduce that value is not Awaitable here, even though + # inspect.isawaitable returns a TypeGuard. + assert not isinstance(value, Awaitable) return DoneAwaitable(value) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index bf7bd351e0..029eedcc6f 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -470,7 +470,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]): def deferred(self, key: KT) -> "defer.Deferred[VT]": if not self._deferred: self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - return self._deferred.observe().addCallback(lambda res: res.get(key)) + return self._deferred.observe().addCallback(lambda res: res[key]) def add_invalidation_callback( self, key: KT, callback: Optional[Callable[[], None]] diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 81df71a0c5..ce736fdf75 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -36,6 +36,8 @@ from typing import ( ) from weakref import WeakValueDictionary +import attr + from twisted.internet import defer from twisted.python.failure import Failure @@ -220,7 +222,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.iterable = iterable self.prune_unread_entries = prune_unread_entries - def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: + def __get__( + self, obj: Optional[Any], owner: Optional[Type] + ) -> Callable[..., "defer.Deferred[Any]"]: cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.name, max_entries=self.max_entries, @@ -232,7 +236,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): get_cache_key = self.cache_key_builder @functools.wraps(self.orig) - def _wrapped(*args: Any, **kwargs: Any) -> Any: + def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]": # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -464,6 +468,35 @@ class _CacheContext: ) +@attr.s(auto_attribs=True, slots=True, frozen=True) +class _CachedFunctionDescriptor: + """Helper for `@cached`, we name it so that we can hook into it with mypy + plugin.""" + + max_entries: int + num_args: Optional[int] + uncached_args: Optional[Collection[str]] + tree: bool + cache_context: bool + iterable: bool + prune_unread_entries: bool + name: Optional[str] + + def __call__(self, orig: F) -> CachedFunction[F]: + d = DeferredCacheDescriptor( + orig, + max_entries=self.max_entries, + num_args=self.num_args, + uncached_args=self.uncached_args, + tree=self.tree, + cache_context=self.cache_context, + iterable=self.iterable, + prune_unread_entries=self.prune_unread_entries, + name=self.name, + ) + return cast(CachedFunction[F], d) + + def cached( *, max_entries: int = 1000, @@ -474,9 +507,8 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> Callable[[F], CachedFunction[F]]: - func = lambda orig: DeferredCacheDescriptor( - orig, +) -> _CachedFunctionDescriptor: + return _CachedFunctionDescriptor( max_entries=max_entries, num_args=num_args, uncached_args=uncached_args, @@ -487,7 +519,26 @@ def cached( name=name, ) - return cast(Callable[[F], CachedFunction[F]], func) + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class _CachedListFunctionDescriptor: + """Helper for `@cachedList`, we name it so that we can hook into it with mypy + plugin.""" + + cached_method_name: str + list_name: str + num_args: Optional[int] = None + name: Optional[str] = None + + def __call__(self, orig: F) -> CachedFunction[F]: + d = DeferredCacheListDescriptor( + orig, + cached_method_name=self.cached_method_name, + list_name=self.list_name, + num_args=self.num_args, + name=self.name, + ) + return cast(CachedFunction[F], d) def cachedList( @@ -496,7 +547,7 @@ def cachedList( list_name: str, num_args: Optional[int] = None, name: Optional[str] = None, -) -> Callable[[F], CachedFunction[F]]: +) -> _CachedListFunctionDescriptor: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments @@ -525,16 +576,13 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: DeferredCacheListDescriptor( - orig, + return _CachedListFunctionDescriptor( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, name=name, ) - return cast(Callable[[F], CachedFunction[F]], func) - def _get_cache_key_builder( param_names: Sequence[str], diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 5eaf70c7ab..2fbc7b1e6c 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -14,7 +14,7 @@ import enum import logging import threading -from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union +from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union import attr from typing_extensions import Literal @@ -33,10 +33,8 @@ DKT = TypeVar("DKT") DV = TypeVar("DV") -# This class can't be generic because it uses slots with attrs. -# See: https://github.com/python-attrs/attrs/issues/313 @attr.s(slots=True, frozen=True, auto_attribs=True) -class DictionaryEntry: # should be: Generic[DKT, DV]. +class DictionaryEntry(Generic[DKT, DV]): """Returned when getting an entry from the cache If `full` is true then `known_absent` will be the empty set. @@ -50,8 +48,8 @@ class DictionaryEntry: # should be: Generic[DKT, DV]. """ full: bool - known_absent: Set[Any] # should be: Set[DKT] - value: Dict[Any, Any] # should be: Dict[DKT, DV] + known_absent: Set[DKT] + value: Dict[DKT, DV] def __len__(self) -> int: return len(self.value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 01ad02af67..e73cf66080 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -14,7 +14,7 @@ import logging from collections import OrderedDict -from typing import Any, Generic, Optional, TypeVar, Union, overload +from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload import attr from typing_extensions import Literal @@ -73,7 +73,7 @@ class ExpiringCache(Generic[KT, VT]): self._expiry_ms = expiry_ms self._reset_expiry_on_get = reset_expiry_on_get - self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict() + self._cache: OrderedDict[KT, _CacheEntry[VT]] = OrderedDict() self.iterable = iterable @@ -84,9 +84,7 @@ class ExpiringCache(Generic[KT, VT]): return def f() -> "defer.Deferred[None]": - return run_as_background_process( - "prune_cache_%s" % self._cache_name, self._prune_cache - ) + return run_as_background_process("prune_cache", self._prune_cache) self._clock.looping_call(f, self._expiry_ms / 2) @@ -100,7 +98,10 @@ class ExpiringCache(Generic[KT, VT]): while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: - self.metrics.inc_evictions(EvictionReason.size, len(value.value)) + # type-ignore, here and below: if self.iterable is true, then the value + # type VT should be Sized (i.e. have a __len__ method). We don't enforce + # this via the type system at present. + self.metrics.inc_evictions(EvictionReason.size, len(value.value)) # type: ignore[arg-type] else: self.metrics.inc_evictions(EvictionReason.size) @@ -134,7 +135,7 @@ class ExpiringCache(Generic[KT, VT]): return default if self.iterable: - self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) # type: ignore[arg-type] else: self.metrics.inc_evictions(EvictionReason.invalidation) @@ -182,7 +183,7 @@ class ExpiringCache(Generic[KT, VT]): for k in keys_to_delete: value = self._cache.pop(k) if self.iterable: - self.metrics.inc_evictions(EvictionReason.time, len(value.value)) + self.metrics.inc_evictions(EvictionReason.time, len(value.value)) # type: ignore[arg-type] else: self.metrics.inc_evictions(EvictionReason.time) @@ -195,7 +196,8 @@ class ExpiringCache(Generic[KT, VT]): def __len__(self) -> int: if self.iterable: - return sum(len(entry.value) for entry in self._cache.values()) + g: Iterable[int] = (len(entry.value) for entry in self._cache.values()) # type: ignore[arg-type] + return sum(g) else: return len(self._cache) @@ -218,6 +220,6 @@ class ExpiringCache(Generic[KT, VT]): @attr.s(slots=True, auto_attribs=True) -class _CacheEntry: +class _CacheEntry(Generic[VT]): time: int - value: Any + value: VT diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 452d5d04c1..be6554319a 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -93,10 +93,8 @@ VT = TypeVar("VT") # a general type var, distinct from either KT or VT T = TypeVar("T") -P = TypeVar("P") - -class _TimedListNode(ListNode[P]): +class _TimedListNode(ListNode[T]): """A `ListNode` that tracks last access time.""" __slots__ = ["last_access_ts_secs"] @@ -821,7 +819,7 @@ class AsyncLruCache(Generic[KT, VT]): utilize external cache systems that require await behaviour to be created. """ - def __init__(self, *args, **kwargs): # type: ignore + def __init__(self, *args: Any, **kwargs: Any): self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs) async def get( @@ -844,7 +842,13 @@ class AsyncLruCache(Generic[KT, VT]): return self._lru_cache.get(key, update_metrics=update_metrics) async def set(self, key: KT, value: VT) -> None: - self._lru_cache.set(key, value) + # This will add the entries in the correct order, local first external second + self.set_local(key, value) + await self.set_external(key, value) + + async def set_external(self, key: KT, value: VT) -> None: + # This method should add an entry to any configured external cache, in this case noop. + pass def set_local(self, key: KT, value: VT) -> None: self._lru_cache.set(key, value) @@ -864,5 +868,5 @@ class AsyncLruCache(Generic[KT, VT]): async def contains(self, key: KT) -> bool: return self._lru_cache.contains(key) - async def clear(self) -> None: + def clear(self) -> None: self._lru_cache.clear() diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 340e5e9145..0cb46700a9 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -36,7 +36,7 @@ from synapse.logging.opentracing import ( ) from synapse.util import Clock from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred -from synapse.util.caches import register_cache +from synapse.util.caches import EvictionReason, register_cache logger = logging.getLogger(__name__) @@ -167,7 +167,7 @@ class ResponseCache(Generic[KV]): # the should_cache bit, we leave it in the cache for now and schedule # its removal later. if self.timeout_sec and context.should_cache: - self.clock.call_later(self.timeout_sec, self.unset, key) + self.clock.call_later(self.timeout_sec, self._entry_timeout, key) else: # otherwise, remove the result immediately. self.unset(key) @@ -185,6 +185,12 @@ class ResponseCache(Generic[KV]): Args: key: key used to remove the cached value """ + self._metrics.inc_evictions(EvictionReason.invalidation) + self._result_cache.pop(key, None) + + def _entry_timeout(self, key: KV) -> None: + """For the call_later to remove from the cache""" + self._metrics.inc_evictions(EvictionReason.time) self._result_cache.pop(key, None) async def wrap( diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index f6b3ee31e4..48a6e4a906 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): # map from key to _CacheEntry - self._data: Dict[KT, _CacheEntry] = {} + self._data: Dict[KT, _CacheEntry[KT, VT]] = {} # the _CacheEntries, sorted by expiry time - self._expiry_list: SortedList[_CacheEntry] = SortedList() + self._expiry_list: SortedList[_CacheEntry[KT, VT]] = SortedList() self._timer = timer @@ -160,11 +160,11 @@ class TTLCache(Generic[KT, VT]): @attr.s(frozen=True, slots=True, auto_attribs=True) -class _CacheEntry: # Should be Generic[KT, VT]. See python-attrs/attrs#313 +class _CacheEntry(Generic[KT, VT]): """TTLCache entry""" # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time: float ttl: float - key: Any # should be KT - value: Any # should be VT + key: KT + value: VT diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 1c0fde4966..f7cead9e12 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -21,16 +21,13 @@ require. But this is probably just symptomatic of Python's package management. """ import logging +from importlib import metadata from typing import Iterable, NamedTuple, Optional from packaging.requirements import Requirement DISTRIBUTION_NAME = "matrix-synapse" -try: - from importlib import metadata -except ImportError: - import importlib_metadata as metadata # type: ignore[no-redef] __all__ = ["check_requirements"] @@ -54,9 +51,9 @@ class DependencyException(Exception): DEV_EXTRAS = {"lint", "mypy", "test", "dev"} -RUNTIME_EXTRAS = ( - set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS -) +ALL_EXTRAS = metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra") +assert ALL_EXTRAS is not None +RUNTIME_EXTRAS = set(ALL_EXTRAS) - DEV_EXTRAS VERSION = metadata.version(DISTRIBUTION_NAME) diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py index 214eb17fbc..fecf829ade 100644 --- a/synapse/util/gai_resolver.py +++ b/synapse/util/gai_resolver.py @@ -136,7 +136,7 @@ class GAIResolver: # The types on IHostnameResolver is incorrect in Twisted, see # https://twistedmatrix.com/trac/ticket/10276 - def resolveHostName( # type: ignore[override] + def resolveHostName( self, resolutionReceiver: IResolutionReceiver, hostName: str, diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 4938ddf703..a0efb96d3b 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -15,11 +15,13 @@ import heapq from itertools import islice from typing import ( + Callable, Collection, Dict, Generator, Iterable, Iterator, + List, Mapping, Set, Sized, @@ -71,6 +73,31 @@ def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]: return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) +def partition( + iterable: Iterable[T], predicate: Callable[[T], bool] +) -> Tuple[List[T], List[T]]: + """ + Separate a given iterable into two lists based on the result of a predicate function. + + Args: + iterable: the iterable to partition (separate) + predicate: a function that takes an item from the iterable and returns a boolean + + Returns: + A tuple of two lists, the first containing all items for which the predicate + returned True, the second containing all items for which the predicate returned + False + """ + true_results = [] + false_results = [] + for item in iterable: + if predicate(item): + true_results.append(item) + else: + false_results.append(item) + return true_results, false_results + + def sorted_topologically( nodes: Iterable[T], graph: Mapping[T, Collection[T]], diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index 644c341e8c..db6c40a3e1 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -218,7 +218,7 @@ class MacaroonGenerator: # to avoid validating those as guest tokens, we explicitely verify if # the macaroon includes the "guest = true" caveat. is_guest = any( - (caveat.caveat_id == "guest = true" for caveat in macaroon.caveats) + caveat.caveat_id == "guest = true" for caveat in macaroon.caveats ) if not is_guest: diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 48b8195ca1..8cb766860e 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory: SynapseManhole, dict(globals, __name__="__console__") ) - factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) + # type-ignore: This is an error in Twisted's annotations. See + # https://github.com/twisted/twisted/issues/11812 and /11813 . + factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type] # conch has the wrong type on these dicts (says bytes to bytes, # should be bytes to Keys judging by how it's used). diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 2ad55ac13e..f693ba2a8c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -20,6 +20,7 @@ import typing from typing import ( Any, Callable, + ContextManager, DefaultDict, Dict, Iterator, @@ -33,7 +34,6 @@ from typing import ( from weakref import WeakSet from prometheus_client.core import Counter -from typing_extensions import ContextManager from twisted.internet import defer @@ -291,7 +291,8 @@ class _PerHostRatelimiter: if self.metrics_name: rate_limit_reject_counter.labels(self.metrics_name).inc() raise LimitExceededError( - retry_after_ms=int(self.window_size / self.sleep_limit) + limiter_name="rc_federation", + retry_after_ms=int(self.window_size / self.sleep_limit), ) self.request_times.append(time_now) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index dcc037b982..0e1f907667 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Optional, Type from synapse.api.errors import CodeMessageException from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import DataStore +from synapse.types import StrCollection from synapse.util import Clock if TYPE_CHECKING: @@ -27,15 +28,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# the initial backoff, after the first transaction fails -MIN_RETRY_INTERVAL = 10 * 60 * 1000 - -# how much we multiply the backoff by after each subsequent fail -RETRY_MULTIPLIER = 5 - -# a cap on the backoff. (Essentially none) -MAX_RETRY_INTERVAL = 2**62 - class NotRetryingDestination(Exception): def __init__(self, retry_last_ts: int, retry_interval: int, destination: str): @@ -125,6 +117,30 @@ async def get_retry_limiter( ) +async def filter_destinations_by_retry_limiter( + destinations: StrCollection, + clock: Clock, + store: DataStore, + retry_due_within_ms: int = 0, +) -> StrCollection: + """Filter down the list of destinations to only those that will are either + alive or due for a retry (within `retry_due_within_ms`) + """ + if not destinations: + return destinations + + retry_timings = await store.get_destination_retry_timings_batch(destinations) + + now = int(clock.time_msec()) + + return [ + destination + for destination, timings in retry_timings.items() + if timings is None + or timings.retry_last_ts + timings.retry_interval <= now + retry_due_within_ms + ] + + class RetryDestinationLimiter: def __init__( self, @@ -137,6 +153,7 @@ class RetryDestinationLimiter: backoff_on_failure: bool = True, notifier: Optional["Notifier"] = None, replication_client: Optional["ReplicationCommandHandler"] = None, + backoff_on_all_error_codes: bool = False, ): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -156,6 +173,9 @@ class RetryDestinationLimiter: backoff_on_failure: set to False if we should not increase the retry interval on a failure. + + backoff_on_all_error_codes: Whether we should back off on any + error code. """ self.clock = clock self.store = store @@ -165,10 +185,21 @@ class RetryDestinationLimiter: self.retry_interval = retry_interval self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure + self.backoff_on_all_error_codes = backoff_on_all_error_codes self.notifier = notifier self.replication_client = replication_client + self.destination_min_retry_interval_ms = ( + self.store.hs.config.federation.destination_min_retry_interval_ms + ) + self.destination_retry_multiplier = ( + self.store.hs.config.federation.destination_retry_multiplier + ) + self.destination_max_retry_interval_ms = ( + self.store.hs.config.federation.destination_max_retry_interval_ms + ) + def __enter__(self) -> None: pass @@ -178,6 +209,7 @@ class RetryDestinationLimiter: exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + success = exc_type is None valid_err_code = False if exc_type is None: valid_err_code = True @@ -194,7 +226,9 @@ class RetryDestinationLimiter: # won't accept our requests for at least a while. # 429 is us being aggressively rate limited, so lets rate limit # ourselves. - if exc_val.code == 404 and self.backoff_on_404: + if self.backoff_on_all_error_codes: + valid_err_code = False + elif exc_val.code == 404 and self.backoff_on_404: valid_err_code = False elif exc_val.code in (401, 429): valid_err_code = False @@ -203,7 +237,7 @@ class RetryDestinationLimiter: else: valid_err_code = False - if valid_err_code: + if success: # We connected successfully. if not self.retry_interval: return @@ -214,19 +248,27 @@ class RetryDestinationLimiter: self.failure_ts = None retry_last_ts = 0 self.retry_interval = 0 + elif valid_err_code: + # We got a potentially valid error code back. We don't reset the + # timers though, as the other side might actually be down anyway + # (e.g. some deprovisioned servers will always return a 404 or 403, + # and we don't want to keep resetting the retry timers for them). + return elif not self.backoff_on_failure: return else: # We couldn't connect. if self.retry_interval: self.retry_interval = int( - self.retry_interval * RETRY_MULTIPLIER * random.uniform(0.8, 1.4) + self.retry_interval + * self.destination_retry_multiplier + * random.uniform(0.8, 1.4) ) - if self.retry_interval >= MAX_RETRY_INTERVAL: - self.retry_interval = MAX_RETRY_INTERVAL + if self.retry_interval >= self.destination_max_retry_interval_ms: + self.retry_interval = self.destination_max_retry_interval_ms else: - self.retry_interval = MIN_RETRY_INTERVAL + self.retry_interval = self.destination_min_retry_interval_ms logger.info( "Connection to %s was unsuccessful (%s(%s)); backoff now %i", diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py new file mode 100644 index 0000000000..caf13b3474 --- /dev/null +++ b/synapse/util/task_scheduler.py @@ -0,0 +1,398 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple + +from twisted.python.failure import Failure + +from synapse.logging.context import nested_logging_context +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) +from synapse.types import JsonMapping, ScheduledTask, TaskStatus +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class TaskScheduler: + """ + This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background` + to launch a background task, or Twisted `deferLater` if we want to do so later on. + + The problem with that is that the tasks will just stop and never be resumed if synapse + is stopped for whatever reason. + + How this works: + - A function mapped to a named action should first be registered with `register_action`. + This function will be called when trying to resuming tasks after a synapse shutdown, + so this registration should happen when synapse is initialised, NOT right before scheduling + a task. + - A task can then be launched using this named action with `schedule_task`. A `params` dict + can be passed, and it will be available to the registered function when launched. This task + can be launch either now-ish, or later on by giving a `timestamp` parameter. + + The function may call `update_task` at any time to update the `result` of the task, + and this can be used to resume the task at a specific point and/or to convey a result to + the code launching the task. + You can also specify the `result` (and/or an `error`) when returning from the function. + + The reconciliation loop runs every minute, so this is not a precise scheduler. + There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already + full. In this regard, please take great care that scheduled tasks can actually finished. + For now there is no mechanism to stop a running task if it is stuck. + + Tasks will be run on the worker specified with `run_background_tasks_on` config, + or the main one by default. + """ + + # Precision of the scheduler, evaluation of tasks to run will only happen + # every `SCHEDULE_INTERVAL_MS` ms + SCHEDULE_INTERVAL_MS = 1 * 60 * 1000 # 1mn + # How often to clean up old tasks. + CLEANUP_INTERVAL_MS = 30 * 60 * 1000 + # Time before a complete or failed task is deleted from the DB + KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week + # Maximum number of tasks that can run at the same time + MAX_CONCURRENT_RUNNING_TASKS = 10 + # Time from the last task update after which we will log a warning + LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs + + def __init__(self, hs: "HomeServer"): + self._hs = hs + self._store = hs.get_datastores().main + self._clock = hs.get_clock() + self._running_tasks: Set[str] = set() + # A map between action names and their registered function + self._actions: Dict[ + str, + Callable[ + [ScheduledTask], + Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], + ], + ] = {} + self._run_background_tasks = hs.config.worker.run_background_tasks + + # Flag to make sure we only try and launch new tasks once at a time. + self._launching_new_tasks = False + + if self._run_background_tasks: + self._clock.looping_call( + self._launch_scheduled_tasks, + TaskScheduler.SCHEDULE_INTERVAL_MS, + ) + self._clock.looping_call( + self._clean_scheduled_tasks, + TaskScheduler.SCHEDULE_INTERVAL_MS, + ) + + LaterGauge( + "synapse_scheduler_running_tasks", + "The number of concurrent running tasks handled by the TaskScheduler", + labels=None, + caller=lambda: len(self._running_tasks), + ) + + def register_action( + self, + function: Callable[ + [ScheduledTask], + Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], + ], + action_name: str, + ) -> None: + """Register a function to be executed when an action is scheduled with + the specified action name. + + Actions need to be registered as early as possible so that a resumed action + can find its matching function. It's usually better to NOT do that right before + calling `schedule_task` but rather in an `__init__` method. + + Args: + function: The function to be executed for this action. The parameter + passed to the function when launched is the `ScheduledTask` being run. + The function should return a tuple of new `status`, `result` + and `error` as specified in `ScheduledTask`. + action_name: The name of the action to be associated with the function + """ + self._actions[action_name] = function + + async def schedule_task( + self, + action: str, + *, + resource_id: Optional[str] = None, + timestamp: Optional[int] = None, + params: Optional[JsonMapping] = None, + ) -> str: + """Schedule a new potentially resumable task. A function matching the specified + `action` should have be registered with `register_action` before the task is run. + + Args: + action: the name of a previously registered action + resource_id: a task can be associated with a resource id to facilitate + getting all tasks associated with a specific resource + timestamp: if `None`, the task will be launched as soon as possible, otherwise it + will be launch as soon as possible after the `timestamp` value. + Note that this scheduler is not meant to be precise, and the scheduling + could be delayed if too many tasks are already running + params: a set of parameters that can be easily accessed from inside the + executed function + + Returns: + The id of the scheduled task + """ + status = TaskStatus.SCHEDULED + if timestamp is None or timestamp < self._clock.time_msec(): + timestamp = self._clock.time_msec() + status = TaskStatus.ACTIVE + + task = ScheduledTask( + random_string(16), + action, + status, + timestamp, + resource_id, + params, + result=None, + error=None, + ) + await self._store.insert_scheduled_task(task) + + if status == TaskStatus.ACTIVE: + if self._run_background_tasks: + await self._launch_task(task) + else: + self._hs.get_replication_command_handler().send_new_active_task(task.id) + + return task.id + + async def update_task( + self, + id: str, + *, + timestamp: Optional[int] = None, + status: Optional[TaskStatus] = None, + result: Optional[JsonMapping] = None, + error: Optional[str] = None, + ) -> bool: + """Update some task associated values. This is exposed publically so it can + be used inside task functions, mainly to update the result and be able to + resume a task at a specific step after a restart of synapse. + + It can also be used to stage a task, by setting the `status` to `SCHEDULED` with + a new timestamp. + + The `status` can only be set to `ACTIVE` or `SCHEDULED`, `COMPLETE` and `FAILED` + are terminal status and can only be set by returning it in the function. + + Args: + id: the id of the task to update + timestamp: useful to schedule a new stage of the task at a later date + status: the new `TaskStatus` of the task + result: the new result of the task + error: the new error of the task + """ + if status == TaskStatus.COMPLETE or status == TaskStatus.FAILED: + raise Exception( + "update_task can't be called with a FAILED or COMPLETE status" + ) + + if timestamp is None: + timestamp = self._clock.time_msec() + return await self._store.update_scheduled_task( + id, + timestamp, + status=status, + result=result, + error=error, + ) + + async def get_task(self, id: str) -> Optional[ScheduledTask]: + """Get a specific task description by id. + + Args: + id: the id of the task to retrieve + + Returns: + The task information or `None` if it doesn't exist or it has + already been removed because it's too old. + """ + return await self._store.get_scheduled_task(id) + + async def get_tasks( + self, + *, + actions: Optional[List[str]] = None, + resource_id: Optional[str] = None, + statuses: Optional[List[TaskStatus]] = None, + max_timestamp: Optional[int] = None, + limit: Optional[int] = None, + ) -> List[ScheduledTask]: + """Get a list of tasks. Returns all the tasks if no args is provided. + + If an arg is `None` all tasks matching the other args will be selected. + If an arg is an empty list, the corresponding value of the task needs + to be `None` to be selected. + + Args: + actions: Limit the returned tasks to those specific action names + resource_id: Limit the returned tasks to the specific resource id, if specified + statuses: Limit the returned tasks to the specific statuses + max_timestamp: Limit the returned tasks to the ones that have + a timestamp inferior to the specified one + limit: Only return `limit` number of rows if set. + + Returns + A list of `ScheduledTask`, ordered by increasing timestamps + """ + return await self._store.get_scheduled_tasks( + actions=actions, + resource_id=resource_id, + statuses=statuses, + max_timestamp=max_timestamp, + limit=limit, + ) + + async def delete_task(self, id: str) -> None: + """Delete a task. Running tasks can't be deleted. + + Can only be called from the worker handling the task scheduling. + + Args: + id: id of the task to delete + """ + task = await self.get_task(id) + if task is None: + raise Exception(f"Task {id} does not exist") + if task.status == TaskStatus.ACTIVE: + raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") + await self._store.delete_scheduled_task(id) + + def launch_task_by_id(self, id: str) -> None: + """Try launching the task with the given ID.""" + # Don't bother trying to launch new tasks if we're already at capacity. + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + + run_as_background_process("launch_task_by_id", self._launch_task_by_id, id) + + async def _launch_task_by_id(self, id: str) -> None: + """Helper async function for `launch_task_by_id`.""" + task = await self.get_task(id) + if task: + await self._launch_task(task) + + @wrap_as_background_process("launch_scheduled_tasks") + async def _launch_scheduled_tasks(self) -> None: + """Retrieve and launch scheduled tasks that should be running at that time.""" + # Don't bother trying to launch new tasks if we're already at capacity. + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + + if self._launching_new_tasks: + return + + self._launching_new_tasks = True + + try: + for task in await self.get_tasks( + statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS + ): + await self._launch_task(task) + for task in await self.get_tasks( + statuses=[TaskStatus.SCHEDULED], + max_timestamp=self._clock.time_msec(), + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + await self._launch_task(task) + + finally: + self._launching_new_tasks = False + + @wrap_as_background_process("clean_scheduled_tasks") + async def _clean_scheduled_tasks(self) -> None: + """Clean old complete or failed jobs to avoid clutter the DB.""" + now = self._clock.time_msec() + for task in await self._store.get_scheduled_tasks( + statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE], + max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS, + ): + # FAILED and COMPLETE tasks should never be running + assert task.id not in self._running_tasks + await self._store.delete_scheduled_task(task.id) + + async def _launch_task(self, task: ScheduledTask) -> None: + """Launch a scheduled task now. + + Args: + task: the task to launch + """ + assert self._run_background_tasks + + if task.action not in self._actions: + raise Exception( + f"No function associated with action {task.action} of the scheduled task {task.id}" + ) + function = self._actions[task.action] + + async def wrapper() -> None: + with nested_logging_context(task.id): + try: + (status, result, error) = await function(task) + except Exception: + f = Failure() + logger.error( + f"scheduled task {task.id} failed", + exc_info=(f.type, f.value, f.getTracebackObject()), + ) + status = TaskStatus.FAILED + result = None + error = f.getErrorMessage() + + await self._store.update_scheduled_task( + task.id, + self._clock.time_msec(), + status=status, + result=result, + error=error, + ) + self._running_tasks.remove(task.id) + + # Try launch a new task since we've finished with this one. + self._clock.call_later(1, self._launch_scheduled_tasks) + + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + + if ( + self._clock.time_msec() + > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS + ): + logger.warn( + f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" + ) + + if task.id in self._running_tasks: + return + + self._running_tasks.add(task.id) + await self.update_task(task.id, status=TaskStatus.ACTIVE) + run_as_background_process(f"task-{task.action}", wrapper) diff --git a/synapse/visibility.py b/synapse/visibility.py index 468e22f8f6..f15fdd8314 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -17,6 +17,7 @@ from enum import Enum, auto from typing import ( Collection, Dict, + Final, FrozenSet, List, Mapping, @@ -27,7 +28,6 @@ from typing import ( ) import attr -from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase @@ -36,12 +36,12 @@ from synapse.events.utils import prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.types import RetentionPolicy, StateMap, get_domain_from_id +from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util import Clock logger = logging.getLogger(__name__) - +filtered_event_logger = logging.getLogger("synapse.visibility.filtered_event_debug") VISIBILITY_PRIORITY = ( HistoryVisibility.WORLD_READABLE, @@ -97,8 +97,8 @@ async def filter_events_for_client( events_before_filtering = events events = [e for e in events if not e.internal_metadata.is_soft_failed()] if len(events_before_filtering) != len(events): - if logger.isEnabledFor(logging.DEBUG): - logger.debug( + if filtered_event_logger.isEnabledFor(logging.DEBUG): + filtered_event_logger.debug( "filter_events_for_client: Filtered out soft-failed events: Before=%s, After=%s", [event.event_id for event in events_before_filtering], [event.event_id for event in events], @@ -150,12 +150,12 @@ async def filter_events_for_client( async def filter_event_for_clients_with_state( store: DataStore, - user_ids: Collection[str], + user_ids: StrCollection, event: EventBase, context: EventContext, is_peeking: bool = False, filter_send_to_client: bool = True, -) -> Collection[str]: +) -> StrCollection: """ Checks to see if an event is visible to the users in the list at the time of the event. @@ -319,7 +319,7 @@ def _check_client_allowed_to_see_event( _check_filter_send_to_client(event, clock, retention_policy, sender_ignored) == _CheckFilter.DENIED ): - logger.debug( + filtered_event_logger.debug( "_check_client_allowed_to_see_event(event=%s): Filtered out event because `_check_filter_send_to_client` returned `_CheckFilter.DENIED`", event.event_id, ) @@ -341,7 +341,7 @@ def _check_client_allowed_to_see_event( ) return event - logger.debug( + filtered_event_logger.debug( "_check_client_allowed_to_see_event(event=%s): Filtered out event because it's an outlier", event.event_id, ) @@ -367,7 +367,7 @@ def _check_client_allowed_to_see_event( membership_result = _check_membership(user_id, event, visibility, state, is_peeking) if not membership_result.allowed: - logger.debug( + filtered_event_logger.debug( "_check_client_allowed_to_see_event(event=%s): Filtered out event because the user can't see the event because of their membership, membership_result.allowed=%s membership_result.joined=%s", event.event_id, membership_result.allowed, @@ -378,7 +378,7 @@ def _check_client_allowed_to_see_event( # If the sender has been erased and the user was not joined at the time, we # must only return the redacted form. if sender_erased and not membership_result.joined: - logger.debug( + filtered_event_logger.debug( "_check_client_allowed_to_see_event(event=%s): Returning pruned event because `sender_erased` and the user was not joined at the time", event.event_id, ) |