diff options
Diffstat (limited to 'synapse/handlers')
27 files changed, 1499 insertions, 1018 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d4fe7df533..cf9f19608a 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -70,6 +70,7 @@ class AdminHandler: "appservice_id", "consent_server_notice_sent", "consent_version", + "consent_ts", "user_type", "is_guest", } diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 814553e098..203b62e015 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -104,14 +104,15 @@ class ApplicationServicesHandler: with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: - limit = 100 upper_bound = -1 while upper_bound < self.current_max: + last_token = await self.store.get_appservice_last_pos() ( upper_bound, events, - ) = await self.store.get_new_events_for_appservice( - self.current_max, limit + event_to_received_ts, + ) = await self.store.get_all_new_events_stream( + last_token, self.current_max, limit=100, get_prev_content=True ) events_by_room: Dict[str, List[EventBase]] = {} @@ -150,7 +151,7 @@ class ApplicationServicesHandler: ) now = self.clock.time_msec() - ts = await self.store.get_received_ts(event.event_id) + ts = event_to_received_ts[event.event_id] assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( @@ -187,7 +188,7 @@ class ApplicationServicesHandler: if events: now = self.clock.time_msec() - ts = await self.store.get_received_ts(events[-1].event_id) + ts = event_to_received_ts[events[-1].event_id] assert ts is not None synapse.metrics.event_processing_lag.labels( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3d83236b0c..0327fc57a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -280,7 +280,7 @@ class AuthHandler: that it isn't stolen by re-authenticating them. Args: - requester: The user, as given by the access token + requester: The user making the request, according to the access token. 
request: The request sent by the client. @@ -565,7 +565,7 @@ class AuthHandler: except LoginError as e: # this step failed. Merge the error dict into the response # so that the client can have another go. - errordict = e.error_dict() + errordict = e.error_dict(self.hs.config) creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: @@ -1435,20 +1435,25 @@ class AuthHandler: access_token: access token to be deleted """ - user_info = await self.auth.get_user_by_access_token(access_token) + token = await self.store.get_user_by_access_token(access_token) + if not token: + # At this point, the token should already have been fetched once by + # the caller, so this should not happen, unless of a race condition + # between two delete requests + raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token") await self.store.delete_access_token(access_token) # see if any modules want to know about this await self.password_auth_provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, + user_id=token.user_id, + device_id=token.device_id, access_token=access_token, ) # delete pushers associated with this access token - if user_info.token_id is not None: + if token.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - user_info.user_id, (user_info.token_id,) + token.user_id, (token.token_id,) ) async def delete_access_tokens_for_user( diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c05a170c55..901e2310b7 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -45,13 +45,13 @@ from synapse.types import ( JsonDict, StreamKeyType, StreamToken, - UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.cancellation import cancellable from synapse.util.metrics 
import measure_func from synapse.util.retryutils import NotRetryingDestination @@ -74,6 +74,7 @@ class DeviceWorkerHandler: self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname + self._msc3852_enabled = hs.config.experimental.msc3852_enabled @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: @@ -118,11 +119,12 @@ class DeviceWorkerHandler: ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) - set_tag("device", device) - set_tag("ips", ips) + set_tag("device", str(device)) + set_tag("ips", str(ips)) return device + @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken ) -> Collection[str]: @@ -162,6 +164,7 @@ class DeviceWorkerHandler: @trace @measure_func("device.get_user_ids_changed") + @cancellable async def get_user_ids_changed( self, user_id: str, from_token: StreamToken ) -> JsonDict: @@ -170,7 +173,7 @@ class DeviceWorkerHandler: """ set_tag("user_id", user_id) - set_tag("from_token", from_token) + set_tag("from_token", str(from_token)) now_room_key = self.store.get_room_max_token() room_ids = await self.store.get_rooms_for_user(user_id) @@ -309,6 +312,7 @@ class DeviceHandler(DeviceWorkerHandler): super().__init__(hs) self.federation_sender = hs.get_federation_sender() + self._storage_controllers = hs.get_storage_controllers() self.device_list_updater = DeviceListUpdater(hs, self) @@ -319,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler): self.device_list_updater.incoming_device_list_update, ) - hs.get_distributor().observe("user_left_room", self.user_left_room) - # Whether `_handle_new_device_update_async` is currently processing. 
self._handle_new_device_update_is_processing = False @@ -564,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler): StreamKeyType.DEVICE_LIST, position, users=[from_user_id] ) - async def user_left_room(self, user: UserID, room_id: str) -> None: - user_id = user.to_string() - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - # We no longer share rooms with this user, so we'll no longer - # receive device updates. Mark this in DB. - await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) - async def store_dehydrated_device( self, user_id: str, @@ -693,8 +687,11 @@ class DeviceHandler(DeviceWorkerHandler): # Ignore any users that aren't ours if self.hs.is_mine_id(user_id): - joined_user_ids = await self.store.get_users_in_room(room_id) - hosts = {get_domain_from_id(u) for u in joined_user_ids} + hosts = set( + await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) + ) hosts.discard(self.server_name) # Check if we've already sent this update to some hosts @@ -747,7 +744,13 @@ def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) + device.update( + { + "last_seen_user_agent": ip.get("user_agent"), + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + } + ) class DeviceListUpdater: @@ -795,7 +798,7 @@ class DeviceListUpdater: """ set_tag("origin", origin) - set_tag("edu_content", edu_content) + set_tag("edu_content", str(edu_content)) user_id = edu_content.pop("user_id") device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 09a7a4b238..7127d5aefc 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -30,7 
+30,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, RoomAlias if TYPE_CHECKING: from synapse.server import HomeServer @@ -83,8 +83,9 @@ class DirectoryHandler: # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = await self.store.get_users_in_room(room_id) - servers = {get_domain_from_id(u) for u in users} + servers = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) if not servers: raise SynapseError(400, "Failed to get server list") @@ -133,7 +134,7 @@ class DirectoryHandler: else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) if (self.require_membership and check_membership) and not is_admin: rooms_for_user = await self.store.get_rooms_for_user(user_id) @@ -197,7 +198,7 @@ class DirectoryHandler: user_id = requester.user.to_string() try: - can_delete = await self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, requester) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -287,8 +288,9 @@ class DirectoryHandler: Codes.NOT_FOUND, ) - users = await self.store.get_users_in_room(room_id) - extra_servers = {get_domain_from_id(u) for u in users} + extra_servers = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. 
@@ -400,7 +402,9 @@ class DirectoryHandler: # either no interested services, or no service with an exclusive lock return True - async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool: + async def _user_can_delete_alias( + self, alias: RoomAlias, requester: Requester + ) -> bool: """Determine whether a user can delete an alias. One of the following must be true: @@ -413,7 +417,7 @@ class DirectoryHandler: """ creator = await self.store.get_room_alias_creator(alias.to_string()) - if creator == user_id: + if creator == requester.user.to_string(): return True # Resolve the alias to the corresponding room. @@ -422,9 +426,7 @@ class DirectoryHandler: if not room_id: return False - return await self.auth.check_can_change_room_list( - room_id, UserID.from_string(user_id) - ) + return await self.auth.check_can_change_room_list(room_id, requester) async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str @@ -463,7 +465,7 @@ class DirectoryHandler: raise SynapseError(400, "Unknown room") can_change_room_list = await self.auth.check_can_change_room_list( - room_id, requester.user + room_id, requester ) if not can_change_room_list: raise AuthError( @@ -528,10 +530,8 @@ class DirectoryHandler: Get a list of the aliases that currently point to this room on this server """ # allow access to server admins and current members of the room - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) if not is_admin: - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string() - ) + await self.auth.check_user_in_room_or_world_readable(room_id, requester) return await self.store.get_aliases_for_room(room_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 52bb5c9c55..8eed63ccf3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -15,7 +15,7 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -37,7 +37,8 @@ from synapse.types import ( get_verify_key_from_cross_signing_key, ) from synapse.util import json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, delay_cancellation +from synapse.util.cancellation import cancellable from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -91,8 +92,13 @@ class E2eKeysHandler: ) @trace + @cancellable async def query_devices( - self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str + self, + query_body: JsonDict, + timeout: int, + from_user_id: str, + from_device_id: Optional[str], ) -> JsonDict: """Handle a device key query from a client @@ -120,9 +126,7 @@ class E2eKeysHandler: the number of in-flight queries at a time. """ async with self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query: Dict[str, Iterable[str]] = query_body.get( - "device_keys", {} - ) + device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {}) # separate users by domain. # make a map from domain to user_id to device_ids @@ -136,8 +140,8 @@ class E2eKeysHandler: else: remote_queries[user_id] = device_ids - set_tag("local_key_query", local_query) - set_tag("remote_key_query", remote_queries) + set_tag("local_key_query", str(local_query)) + set_tag("remote_key_query", str(remote_queries)) # First get local devices. # A map of destination -> failure response. @@ -171,6 +175,32 @@ class E2eKeysHandler: user_ids_not_in_cache, remote_results, ) = await self.store.get_user_devices_from_cache(query_list) + + # Check that the homeserver still shares a room with all cached users. 
+ # Note that this check may be slightly racy when a remote user leaves a + # room after we have fetched their cached device list. In the worst case + # we will do extra federation queries for devices that we had cached. + cached_users = set(remote_results.keys()) + valid_cached_users = ( + await self.store.get_users_server_still_shares_room_with( + remote_results.keys() + ) + ) + invalid_cached_users = cached_users - valid_cached_users + if invalid_cached_users: + # Fix up results. If we get here, there is either a bug in device + # list tracking, or we hit the race mentioned above. + user_ids_not_in_cache.update(invalid_cached_users) + for invalid_user_id in invalid_cached_users: + remote_results.pop(invalid_user_id) + # This log message may be removed if it turns out it's almost + # entirely triggered by races. + logger.error( + "Devices for %s were cached, but the server no longer shares " + "any rooms with them. The cached device lists are stale.", + invalid_cached_users, + ) + for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): @@ -206,22 +236,26 @@ class E2eKeysHandler: r[user_id] = remote_queries[user_id] # Now fetch any devices that we don't have in our cache + # TODO It might make sense to propagate cancellations into the + # deferreds which are querying remote homeservers. 
await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self._query_devices_for_destination, - results, - cross_signing_keys, - failures, - destination, - queries, - timeout, - ) - for destination, queries in remote_queries_not_in_cache.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + delay_cancellation( + defer.gatherResults( + [ + run_in_background( + self._query_devices_for_destination, + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, + ) + for destination, queries in remote_queries_not_in_cache.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) ) ret = {"device_keys": results, "failures": failures} @@ -341,10 +375,11 @@ class E2eKeysHandler: failure = _exception_to_failure(e) failures[destination] = failure set_tag("error", True) - set_tag("reason", failure) + set_tag("reason", str(failure)) return + @cancellable async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: @@ -391,8 +426,9 @@ class E2eKeysHandler: } @trace + @cancellable async def query_local_devices( - self, query: Dict[str, Optional[List[str]]] + self, query: Mapping[str, Optional[List[str]]] ) -> Dict[str, Dict[str, dict]]: """Get E2E device keys for local users @@ -403,7 +439,7 @@ class E2eKeysHandler: Returns: A map from user_id -> device_id -> device details """ - set_tag("local_query", query) + set_tag("local_query", str(query)) local_query: List[Tuple[str, Optional[str]]] = [] result_dict: Dict[str, Dict[str, dict]] = {} @@ -461,7 +497,7 @@ class E2eKeysHandler: @trace async def claim_one_time_keys( - self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] ) -> JsonDict: local_query: List[Tuple[str, str, str]] = [] remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} @@ -475,8 +511,8 @@ class E2eKeysHandler: domain = 
get_domain_from_id(user_id) remote_queries.setdefault(domain, {})[user_id] = one_time_keys - set_tag("local_key_query", local_query) - set_tag("remote_key_query", remote_queries) + set_tag("local_key_query", str(local_query)) + set_tag("remote_key_query", str(remote_queries)) results = await self.store.claim_e2e_one_time_keys(local_query) @@ -506,7 +542,7 @@ class E2eKeysHandler: failure = _exception_to_failure(e) failures[destination] = failure set_tag("error", True) - set_tag("reason", failure) + set_tag("reason", str(failure)) await make_deferred_yieldable( defer.gatherResults( @@ -609,7 +645,7 @@ class E2eKeysHandler: result = await self.store.count_e2e_one_time_keys(user_id, device_id) - set_tag("one_time_key_counts", result) + set_tag("one_time_key_counts", str(result)) return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 446f509bdc..28dc08c22a 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, cast from typing_extensions import Literal @@ -97,7 +97,7 @@ class E2eRoomKeysHandler: user_id, version, room_id, session_id ) - log_kv(results) + log_kv(cast(JsonDict, results)) return results @trace diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index a2dd9c7efa..c3ddc5d182 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -129,12 +129,9 @@ class EventAuthHandler: else: users = {} - # Find the user with the highest power level. - users_in_room = await self._store.get_users_in_room(room_id) - # Only interested in local users. 
- local_users_in_room = [ - u for u in users_in_room if get_domain_from_id(u) == self._server_name - ] + # Find the user with the highest power level (only interested in local + # users). + local_users_in_room = await self._store.get_local_users_in_room(room_id) chosen_user = max( local_users_in_room, key=lambda user: users.get(user, users_default_level), diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ac13340d3a..949b69cb41 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -151,7 +151,7 @@ class EventHandler: """Retrieve a single specified event. Args: - user: The user requesting the event + user: The local user requesting the event room_id: The expected room id. We'll return None if the event's room does not match. event_id: The event ID to obtain. @@ -173,8 +173,11 @@ class EventHandler: if not event: return None - users = await self.store.get_users_in_room(event.room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=event.room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room filtered = await filter_events_for_client( self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3b5eaf5156..dd4b9f66d1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -32,6 +32,7 @@ from typing import ( ) import attr +from prometheus_client import Histogram from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -59,6 +60,7 @@ from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context +from 
synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import NOT_SPAM from synapse.replication.http.federation import ( @@ -68,7 +70,7 @@ from synapse.replication.http.federation import ( from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server @@ -78,36 +80,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - -def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: - """Get joined domains from state - - Args: - state: State map from type/state key to event. - - Returns: - Returns a list of servers with the lowest depth of their joins. - Sorted by lowest depth first. 
- """ - joined_users = [ - (state_key, int(event.depth)) - for (e_type, state_key), event in state.items() - if e_type == EventTypes.Member and event.membership == Membership.JOIN - ] - - joined_domains: Dict[str, int] = {} - for u, d in joined_users: - try: - dom = get_domain_from_id(u) - old_d = joined_domains.get(dom) - if old_d: - joined_domains[dom] = min(d, old_d) - else: - joined_domains[dom] = d - except Exception: - pass - - return sorted(joined_domains.items(), key=lambda d: d[1]) +# Added to debug performance and track progress on optimizations +backfill_processing_before_timer = Histogram( + "synapse_federation_backfill_processing_before_time_seconds", + "sec", + [], + buckets=( + 0.1, + 0.5, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 60.0, + 80.0, + "+Inf", + ), +) class _BackfillPointType(Enum): @@ -137,6 +131,7 @@ class FederationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.clock = hs.get_clock() self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -180,6 +175,7 @@ class FederationHandler: "resume_sync_partial_state_room", self._resume_sync_partial_state_room ) + @trace async def maybe_backfill( self, room_id: str, current_depth: int, limit: int ) -> bool: @@ -195,12 +191,39 @@ class FederationHandler: return. This is used as part of the heuristic to decide if we should back paginate. 
""" + # Starting the processing time here so we can include the room backfill + # linearizer lock queue in the timing + processing_start_time = self.clock.time_msec() + async with self._room_backfill.queue(room_id): - return await self._maybe_backfill_inner(room_id, current_depth, limit) + return await self._maybe_backfill_inner( + room_id, + current_depth, + limit, + processing_start_time=processing_start_time, + ) async def _maybe_backfill_inner( - self, room_id: str, current_depth: int, limit: int + self, + room_id: str, + current_depth: int, + limit: int, + *, + processing_start_time: int, ) -> bool: + """ + Checks whether the `current_depth` is at or approaching any backfill + points in the room and if so, will backfill. We only care about + checking backfill points that happened before the `current_depth` + (meaning less than or equal to the `current_depth`). + + Args: + room_id: The room to backfill in. + current_depth: The depth to check at for any upcoming backfill points. + limit: The max number of events to request from the remote federated server. + processing_start_time: The time when `maybe_backfill` started + processing. Only used for timing. + """ backwards_extremities = [ _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room( @@ -368,23 +391,29 @@ class FederationHandler: logger.debug( "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request ) + set_tag( + SynapseTags.RESULT_PREFIX + "extremities_to_request", + str(extremities_to_request), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "extremities_to_request.length", + str(len(extremities_to_request)), + ) # Now we need to decide which hosts to hit first. - - # First we try hosts that are already in the room + # First we try hosts that are already in the room. # TODO: HEURISTIC ALERT. 
+ likely_domains = ( + await self._storage_controllers.state.get_current_hosts_in_room(room_id) + ) - curr_state = await self._storage_controllers.state.get_current_state(room_id) - - curr_domains = get_domains_from_state(curr_state) - - likely_domains = [ - domain for domain, depth in curr_domains if domain != self.server_name - ] - - async def try_backfill(domains: List[str]) -> bool: + async def try_backfill(domains: Collection[str]) -> bool: # TODO: Should we try multiple of these at a time? for dom in domains: + # We don't want to ask our own server for information we don't have + if dom == self.server_name: + continue + try: await self._federation_event_handler.backfill( dom, room_id, limit=100, extremities=extremities_to_request @@ -423,6 +452,11 @@ class FederationHandler: return False + processing_end_time = self.clock.time_msec() + backfill_processing_before_timer.observe( + (processing_end_time - processing_start_time) / 1000 + ) + success = await try_backfill(likely_domains) if success: return True @@ -546,9 +580,9 @@ class FederationHandler: ) if ret.partial_state: - # TODO(faster_joins): roll this back if we don't manage to start the - # background resync (eg process_remote_join fails) - # https://github.com/matrix-org/synapse/issues/12998 + # Mark the room as having partial state. + # The background process is responsible for unmarking this flag, + # even if the join fails. await self.store.store_partial_state_room(room_id, ret.servers_in_room) try: @@ -574,17 +608,21 @@ class FederationHandler: room_id, ) raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0) - - if ret.partial_state: - # Kick off the process of asynchronously fetching the state for this - # room. 
- run_as_background_process( - desc="sync_partial_state_room", - func=self._sync_partial_state_room, - initial_destination=origin, - other_destinations=ret.servers_in_room, - room_id=room_id, - ) + finally: + # Always kick off the background process that asynchronously fetches + # state for the room. + # If the join failed, the background process is responsible for + # cleaning up — including unmarking the room as a partial state room. + if ret.partial_state: + # Kick off the process of asynchronously fetching the state for this + # room. + run_as_background_process( + desc="sync_partial_state_room", + func=self._sync_partial_state_room, + initial_destination=origin, + other_destinations=ret.servers_in_room, + room_id=room_id, + ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. @@ -748,6 +786,23 @@ class FederationHandler: # (and return a 404 otherwise) room_version = await self.store.get_room_version(room_id) + if await self.store.is_partial_state_room(room_id): + # If our server is still only partially joined, we can't give a complete + # response to /make_join, so return a 404 as we would if we weren't in the + # room at all. + # The main reason we can't respond properly is that we need to know about + # the auth events for the join event that we would return. + # We also should not bother entertaining the /make_join since we cannot + # handle the /send_join. 
+ logger.info( + "Rejecting /make_join to %s because it's a partial state room", room_id + ) + raise SynapseError( + 404, + "Unable to handle /make_join right now; this server is not fully joined.", + errcode=Codes.NOT_FOUND, + ) + # now check that we are *still* in the room is_in_room = await self._event_auth_handler.check_host_in_room( room_id, self.server_name @@ -1058,6 +1113,8 @@ class FederationHandler: return event + @trace + @tag_args async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) @@ -1539,15 +1596,16 @@ class FederationHandler: # Make an infinite iterator of destinations to try. Once we find a working # destination, we'll stick with it until it flakes. + destinations: Collection[str] if initial_destination is not None: # Move `initial_destination` to the front of the list. destinations = list(other_destinations) if initial_destination in destinations: destinations.remove(initial_destination) destinations = [initial_destination] + destinations - destination_iter = itertools.cycle(destinations) else: - destination_iter = itertools.cycle(other_destinations) + destinations = other_destinations + destination_iter = itertools.cycle(destinations) # `destination` is the current remote homeserver we're pulling from. destination = next(destination_iter) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9ce5bea0ed..100785ebba 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import collections import itertools import logging from http import HTTPStatus @@ -28,7 +29,7 @@ from typing import ( Tuple, ) -from prometheus_client import Counter +from prometheus_client import Counter, Histogram from synapse import event_auth from synapse.api.constants import ( @@ -58,6 +59,13 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError from synapse.logging.context import nested_logging_context +from synapse.logging.opentracing import ( + SynapseTags, + set_tag, + start_active_span, + tag_args, + trace, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( @@ -90,6 +98,36 @@ soft_failed_event_counter = Counter( "Events received over federation that we marked as soft_failed", ) +# Added to debug performance and track progress on optimizations +backfill_processing_after_timer = Histogram( + "synapse_federation_backfill_processing_after_time_seconds", + "sec", + [], + buckets=( + 0.1, + 0.25, + 0.5, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 15.0, + 20.0, + 25.0, + 30.0, + 40.0, + 50.0, + 60.0, + 80.0, + 100.0, + 120.0, + 150.0, + 180.0, + "+Inf", + ), +) + class FederationEventHandler: """Handles events that originated from federation. @@ -277,7 +315,8 @@ class FederationEventHandler: ) try: - await self._process_received_pdu(origin, pdu, state_ids=None) + context = await self._state_handler.compute_event_context(pdu) + await self._process_received_pdu(origin, pdu, context) except PartialStateConflictError: # The room was un-partial stated while we were processing the PDU. # Try once more, with full state this time. 
@@ -285,7 +324,8 @@ class FederationEventHandler: "Room %s was un-partial stated while processing the PDU, trying again.", room_id, ) - await self._process_received_pdu(origin, pdu, state_ids=None) + context = await self._state_handler.compute_event_context(pdu) + await self._process_received_pdu(origin, pdu, context) async def on_send_membership_event( self, origin: str, event: EventBase @@ -315,6 +355,7 @@ class FederationEventHandler: The event and context of the event after inserting it into the room graph. Raises: + RuntimeError if any prev_events are missing SynapseError if the event is not accepted into the room PartialStateConflictError if the room was un-partial stated in between computing the state at the event and persisting it. The caller should @@ -347,7 +388,7 @@ class FederationEventHandler: event.internal_metadata.send_on_behalf_of = origin context = await self._state_handler.compute_event_context(event) - context = await self._check_event_auth(origin, event, context) + await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError( 403, f"{event.membership} event was rejected", Codes.FORBIDDEN @@ -375,7 +416,7 @@ class FederationEventHandler: # need to. 
await self._event_creation_handler.cache_joined_hosts_for_event(event, context) - await self._check_for_soft_fail(event, None, origin=origin) + await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context) return event, context @@ -405,6 +446,7 @@ class FederationEventHandler: prev_member_event, ) + @trace async def process_remote_join( self, origin: str, @@ -485,7 +527,7 @@ class FederationEventHandler: partial_state=partial_state, ) - context = await self._check_event_auth(origin, event, context) + await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError(400, "Join event was rejected") @@ -533,32 +575,36 @@ class FederationEventHandler: # # This is the same operation as we do when we receive a regular event # over federation. - state_ids = await self._resolve_state_at_missing_prevs(destination, event) - - # build a new state group for it if need be - context = await self._state_handler.compute_event_context( - event, - state_ids_before_event=state_ids, + context = await self._compute_event_context_with_maybe_missing_prevs( + destination, event ) if context.partial_state: # this can happen if some or all of the event's prev_events still have - # partial state - ie, an event has an earlier stream_ordering than one - # or more of its prev_events, so we de-partial-state it before its - # prev_events. + # partial state. We were careful to only pick events from the db without + # partial-state prev events, so that implies that a prev event has + # been persisted (with partial state) since we did the query. 
# - # TODO(faster_joins): we probably need to be more intelligent, and - # exclude partial-state prev_events from consideration - # https://github.com/matrix-org/synapse/issues/13001 + # So, let's just ignore `event` for now; when we re-run the db query + # we should instead get its partial-state prev event, which we will + # de-partial-state, and then come back to event. logger.warning( - "%s still has partial state: can't de-partial-state it yet", + "%s still has prev_events with partial state: can't de-partial-state it yet", event.event_id, ) return + + # since the state at this event has changed, we should now re-evaluate + # whether it should have been rejected. We must already have all of the + # auth events (from last time we went round this path), so there is no + # need to pass the origin. + await self._check_event_auth(None, event, context) + await self._store.update_state_for_partial_state_event(event, context) self._state_storage_controller.notify_event_un_partial_stated( event.event_id ) + @trace async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> None: @@ -588,21 +634,23 @@ class FederationEventHandler: if not events: return - # if there are any events in the wrong room, the remote server is buggy and - # should not be trusted. - for ev in events: - if ev.room_id != room_id: - raise InvalidResponseError( - f"Remote server {dest} returned event {ev.event_id} which is in " - f"room {ev.room_id}, when we were backfilling in {room_id}" - ) + with backfill_processing_after_timer.time(): + # if there are any events in the wrong room, the remote server is buggy and + # should not be trusted. 
+ for ev in events: + if ev.room_id != room_id: + raise InvalidResponseError( + f"Remote server {dest} returned event {ev.event_id} which is in " + f"room {ev.room_id}, when we were backfilling in {room_id}" + ) - await self._process_pulled_events( - dest, - events, - backfilled=True, - ) + await self._process_pulled_events( + dest, + events, + backfilled=True, + ) + @trace async def _get_missing_events_for_pdu( self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int ) -> None: @@ -703,8 +751,9 @@ class FederationEventHandler: logger.info("Got %d prev_events", len(missing_events)) await self._process_pulled_events(origin, missing_events, backfilled=False) + @trace async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase], backfilled: bool + self, origin: str, events: Collection[EventBase], backfilled: bool ) -> None: """Process a batch of events we have pulled from a remote server @@ -719,6 +768,15 @@ class FederationEventHandler: backfilled: True if this is part of a historical batch of events (inhibits notification to clients, and validation of device keys.) """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids", + str([event.event_id for event in events]), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(events)), + ) + set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled)) logger.debug( "processing pulled backfilled=%s events=%s", backfilled, @@ -741,6 +799,8 @@ class FederationEventHandler: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) + @trace + @tag_args async def _process_pulled_event( self, origin: str, event: EventBase, backfilled: bool ) -> None: @@ -765,10 +825,24 @@ class FederationEventHandler: """ logger.info("Processing pulled event %s", event) - # these should not be outliers. 
- assert ( - not event.internal_metadata.is_outlier() - ), "pulled event unexpectedly flagged as outlier" + # This function should not be used to persist outliers (use something + # else) because this does a bunch of operations that aren't necessary + # (extra work; in particular, it makes sure we have all the prev_events + # and resolves the state across those prev events). If you happen to run + # into a situation where the event you're trying to process/backfill is + # marked as an `outlier`, then you should update that spot to return an + # `EventBase` copy that doesn't have `outlier` flag set. + # + # `EventBase` is used to represent both an event we have not yet + # persisted, and one that we have persisted and now keep in the cache. + # In an ideal world this method would only be called with the first type + # of event, but it turns out that's not actually the case and for + # example, you could get an event from cache that is marked as an + # `outlier` (fix up that spot though). + assert not event.internal_metadata.is_outlier(), ( + "Outlier event passed to _process_pulled_event. " + "To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = true`." 
+ ) event_id = event.event_id @@ -778,7 +852,7 @@ class FederationEventHandler: if existing: if not existing.internal_metadata.is_outlier(): logger.info( - "Ignoring received event %s which we have already seen", + "_process_pulled_event: Ignoring received event %s which we have already seen", event_id, ) return @@ -788,32 +862,66 @@ class FederationEventHandler: self._sanity_check_event(event) except SynapseError as err: logger.warning("Event %s failed sanity check: %s", event_id, err) + await self._store.record_event_failed_pull_attempt( + event.room_id, event_id, str(err) + ) return try: - state_ids = await self._resolve_state_at_missing_prevs(origin, event) - # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does - # not return partial state - # https://github.com/matrix-org/synapse/issues/13002 + try: + context = await self._compute_event_context_with_maybe_missing_prevs( + origin, event + ) + await self._process_received_pdu( + origin, + event, + context, + backfilled=backfilled, + ) + except PartialStateConflictError: + # The room was un-partial stated while we were processing the event. + # Try once more, with full state this time. + context = await self._compute_event_context_with_maybe_missing_prevs( + origin, event + ) - await self._process_received_pdu( - origin, event, state_ids=state_ids, backfilled=backfilled - ) + # We ought to have full state now, barring some unlikely race where we left and + # rejoined the room in the background. 
+ if context.partial_state: + raise AssertionError( + f"Event {event.event_id} still has a partial resolved state " + f"after room {event.room_id} was un-partial stated" + ) + + await self._process_received_pdu( + origin, + event, + context, + backfilled=backfilled, + ) except FederationError as e: + await self._store.record_event_failed_pull_attempt( + event.room_id, event_id, str(e) + ) + if e.code == 403: logger.warning("Pulled event %s failed history check.", event_id) else: raise - async def _resolve_state_at_missing_prevs( + @trace + async def _compute_event_context_with_maybe_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[StateMap[str]]: - """Calculate the state at an event with missing prev_events. + ) -> EventContext: + """Build an EventContext structure for a non-outlier event whose prev_events may + be missing. - This is used when we have pulled a batch of events from a remote server, and - still don't have all the prev_events. + This is used when we have pulled a batch of events from a remote server, and may + not have all the prev_events. - If we already have all the prev_events for `event`, this method does nothing. + To build an EventContext, we need to calculate the state before the event. If we + already have all the prev_events for `event`, we can simply use the state after + the prev_events to calculate the state before `event`. Otherwise, the missing prevs become new backwards extremities, and we fall back to asking the remote server for the state after each missing `prev_event`, @@ -834,8 +942,7 @@ class FederationEventHandler: event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns - the event ids of the state at `event`. + The event context. 
Raises: FederationError if we fail to get the state from the remote server after any @@ -849,7 +956,7 @@ class FederationEventHandler: missing_prevs = prevs - seen if not missing_prevs: - return None + return await self._state_handler.compute_event_context(event) logger.info( "Event %s is missing prev_events %s: calculating state for a " @@ -861,9 +968,15 @@ class FederationEventHandler: # resolve them to find the correct state at the current event. try: + # Determine whether we may be about to retrieve partial state + # Events may be un-partial stated right after we compute the partial state + # flag, but that's okay, as long as the flag errs on the conservative side. + partial_state_flags = await self._store.get_partial_state_events(seen) + partial_state = any(partial_state_flags.values()) + # Get the state of the events we know about ours = await self._state_storage_controller.get_state_groups_ids( - room_id, seen + room_id, seen, await_full_state=False ) # state_maps is a list of mappings from (type, state_key) to event_id @@ -909,8 +1022,12 @@ class FederationEventHandler: "We can't get valid state history.", affected=event_id, ) - return state_map + return await self._state_handler.compute_event_context( + event, state_ids_before_event=state_map, partial_state=partial_state + ) + @trace + @tag_args async def _get_state_ids_after_missing_prev_event( self, destination: str, @@ -931,6 +1048,14 @@ class FederationEventHandler: InvalidResponseError: if the remote homeserver's response contains fields of the wrong type. """ + + # It would be better if we could query the difference from our known + # state to the given `event_id` so the sending server doesn't have to + # send as much and we don't have to process as many events. For example + # in a room like #matrix:matrix.org, we get 200k events (77k state_events, 122k + # auth_events) from this call. 
+ # + # Tracked by https://github.com/matrix-org/synapse/issues/13618 ( state_event_ids, auth_event_ids, @@ -955,11 +1080,11 @@ class FederationEventHandler: ) have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - have_events + missing_desired_event_ids = desired_events - have_events logger.debug( "_get_state_ids_after_missing_prev_event(event_id=%s): We are missing %i events (got %i)", event_id, - len(missing_desired_events), + len(missing_desired_event_ids), len(have_events), ) @@ -971,17 +1096,34 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - have_events - missing_auth_events.difference_update( - await self._store.have_seen_events(room_id, missing_auth_events) + missing_auth_event_ids = set(auth_event_ids) - have_events + missing_auth_event_ids.difference_update( + await self._store.have_seen_events(room_id, missing_auth_event_ids) ) logger.debug( "_get_state_ids_after_missing_prev_event(event_id=%s): We are also missing %i auth events", event_id, - len(missing_auth_events), + len(missing_auth_event_ids), ) - missing_events = missing_desired_events | missing_auth_events + missing_event_ids = missing_desired_event_ids | missing_auth_event_ids + + set_tag( + SynapseTags.RESULT_PREFIX + "missing_auth_event_ids", + str(missing_auth_event_ids), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length", + str(len(missing_auth_event_ids)), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "missing_desired_event_ids", + str(missing_desired_event_ids), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length", + str(len(missing_desired_event_ids)), + ) # Making an individual request for each of 1000s of events has a lot of # overhead. 
On the other hand, we don't really want to fetch all of the events @@ -992,7 +1134,7 @@ class FederationEventHandler: # # TODO: might it be better to have an API which lets us do an aggregate event # request - if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids): + if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids): logger.debug( "_get_state_ids_after_missing_prev_event(event_id=%s): Requesting complete state from remote", event_id, @@ -1002,10 +1144,10 @@ class FederationEventHandler: logger.debug( "_get_state_ids_after_missing_prev_event(event_id=%s): Fetching %i events from remote", event_id, - len(missing_events), + len(missing_event_ids), ) await self._get_events_and_persist( - destination=destination, room_id=room_id, event_ids=missing_events + destination=destination, room_id=room_id, event_ids=missing_event_ids ) # We now need to fill out the state map, which involves fetching the @@ -1063,12 +1205,24 @@ class FederationEventHandler: # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. + failed_to_fetch = desired_events - event_metadata.keys() + # `event_id` could be missing from `event_metadata` because it's not necessarily + # a state event. We've already checked that we've fetched it above. 
+ failed_to_fetch.discard(event_id) if failed_to_fetch: logger.warning( "_get_state_ids_after_missing_prev_event(event_id=%s): Failed to fetch missing state events %s", event_id, failed_to_fetch, ) + set_tag( + SynapseTags.RESULT_PREFIX + "failed_to_fetch", + str(failed_to_fetch), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "failed_to_fetch.length", + str(len(failed_to_fetch)), + ) if remote_event.is_state() and remote_event.rejected_reason is None: state_map[ @@ -1077,6 +1231,8 @@ class FederationEventHandler: return state_map + @trace + @tag_args async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str ) -> None: @@ -1098,11 +1254,12 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=(event_id,) ) + @trace async def _process_received_pdu( self, origin: str, event: EventBase, - state_ids: Optional[StateMap[str]], + context: EventContext, backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1124,30 +1281,20 @@ class FederationEventHandler: event: event to be persisted - state_ids: Normally None, but if we are handling a gap in the graph - (ie, we are missing one or more prev_events), the resolved state at the - event. Must not be partial state. + context: The `EventContext` to persist the event with. backfilled: True if this is part of a historical batch of events (inhibits notification to clients, and validation of device keys.) PartialStateConflictError: if the room was un-partial stated in between - computing the state at the event and persisting it. The caller should retry - exactly once in this case. Will never be raised if `state_ids` is provided. + computing the state at the event and persisting it. The caller should + recompute `context` and retry exactly once when this happens. 
""" logger.debug("Processing event: %s", event) assert not event.internal_metadata.outlier - context = await self._state_handler.compute_event_context( - event, - state_ids_before_event=state_ids, - ) try: - context = await self._check_event_auth( - origin, - event, - context, - ) + await self._check_event_auth(origin, event, context) except AuthError as e: # This happens only if we couldn't find the auth events. We'll already have # logged a warning, so now we just convert to a FederationError. @@ -1157,7 +1304,7 @@ class FederationEventHandler: # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - await self._check_for_soft_fail(event, state_ids, origin=origin) + await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context, backfilled) @@ -1258,6 +1405,7 @@ class FederationEventHandler: except Exception: logger.exception("Failed to resync device for %s", sender) + @trace async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None: """Handles backfilling the insertion event when we receive a marker event that points to one. @@ -1289,7 +1437,7 @@ class FederationEventHandler: logger.debug("_handle_marker_event: received %s", marker_event) insertion_event_id = marker_event.content.get( - EventContentFields.MSC2716_MARKER_INSERTION + EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE ) if insertion_event_id is None: @@ -1342,6 +1490,55 @@ class FederationEventHandler: marker_event, ) + async def backfill_event_id( + self, destination: str, room_id: str, event_id: str + ) -> EventBase: + """Backfill a single event and persist it as a non-outlier which means + we also pull in all of the state and auth events necessary for it. + + Args: + destination: The homeserver to pull the given event_id from. + room_id: The room where the event is from. 
+ event_id: The event ID to backfill. + + Raises: + FederationError if we are unable to find the event from the destination + """ + logger.info( + "backfill_event_id: event_id=%s from destination=%s", event_id, destination + ) + + room_version = await self._store.get_room_version(room_id) + + event_from_response = await self._federation_client.get_pdu( + [destination], + event_id, + room_version, + ) + + if not event_from_response: + raise FederationError( + "ERROR", + 404, + "Unable to find event_id=%s from destination=%s to backfill." + % (event_id, destination), + affected=event_id, + ) + + # Persist the event we just fetched, including pulling all of the state + # and auth events to de-outlier it. This also sets up the necessary + # `state_groups` for the event. + await self._process_pulled_events( + destination, + [event_from_response], + # Prevent notifications going to clients + backfilled=True, + ) + + return event_from_response + + @trace + @tag_args async def _get_events_and_persist( self, destination: str, room_id: str, event_ids: Collection[str] ) -> None: @@ -1387,6 +1584,7 @@ class FederationEventHandler: logger.info("Fetched %i events of %i requested", len(events), len(event_ids)) await self._auth_and_persist_outliers(room_id, events) + @trace async def _auth_and_persist_outliers( self, room_id: str, events: Iterable[EventBase] ) -> None: @@ -1405,6 +1603,16 @@ class FederationEventHandler: """ event_map = {event.event_id: event for event in events} + event_ids = event_map.keys() + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids", + str(event_ids), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + # filter out any events we have already seen. This might happen because # the events were eagerly pushed to us (eg, during a room join), or because # another thread has raced against us since we decided to request the event. 
@@ -1521,24 +1729,21 @@ class FederationEventHandler: backfilled=True, ) + @trace async def _check_event_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - ) -> EventContext: + self, origin: Optional[str], event: EventBase, context: EventContext + ) -> None: """ Checks whether an event should be rejected (for failing auth checks). Args: - origin: The host the event originates from. + origin: The host the event originates from. This is used to fetch + any missing auth events. It can be set to None, but only if we are + sure that we already have all the auth events. event: The event itself. context: The event context. - Returns: - The updated context object. - Raises: AuthError if we were unable to find copies of the event's auth events. (Most other failures just cause us to set `context.rejected`.) @@ -1553,7 +1758,7 @@ class FederationEventHandler: logger.warning("While validating received event %r: %s", event, e) # TODO: use a different rejected reason here? context.rejected = RejectedReason.AUTH_ERROR - return context + return # next, check that we have all of the event's auth events. # @@ -1563,8 +1768,19 @@ class FederationEventHandler: claimed_auth_events = await self._load_or_fetch_auth_events_for_event( origin, event ) + set_tag( + SynapseTags.RESULT_PREFIX + "claimed_auth_events", + str([ev.event_id for ev in claimed_auth_events]), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "claimed_auth_events.length", + str(len(claimed_auth_events)), + ) # ... and check that the event passes auth at those auth events. + # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu: + # 4. Passes authorization rules based on the event’s auth events, + # otherwise it is rejected. 
try: await check_state_independent_auth_rules(self._store, event) check_state_dependent_auth_rules(event, claimed_auth_events) @@ -1573,55 +1789,91 @@ class FederationEventHandler: "While checking auth of %r against auth_events: %s", event, e ) context.rejected = RejectedReason.AUTH_ERROR - return context + return + + # now check the auth rules pass against the room state before the event + # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu: + # 5. Passes authorization rules based on the state before the event, + # otherwise it is rejected. + # + # ... however, if we only have partial state for the room, then there is a good + # chance that we'll be missing some of the state needed to auth the new event. + # So, we state-resolve the auth events that we are given against the state that + # we know about, which ensures things like bans are applied. (Note that we'll + # already have checked we have all the auth events, in + # _load_or_fetch_auth_events_for_event above) + if context.partial_state: + room_version = await self._store.get_room_version_id(event.room_id) + + local_state_id_map = await context.get_prev_state_ids() + claimed_auth_events_id_map = { + (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events + } + + state_for_auth_id_map = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + [local_state_id_map, claimed_auth_events_id_map], + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) + ) + else: + event_types = event_auth.auth_types_for_event(event.room_version, event) + state_for_auth_id_map = await context.get_prev_state_ids( + StateFilter.from_types(event_types) + ) - # now check auth against what we think the auth events *should* be. 
- event_types = event_auth.auth_types_for_event(event.room_version, event) - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types(event_types) + calculated_auth_event_ids = self._event_auth_handler.compute_auth_events( + event, state_for_auth_id_map, for_verification=True ) - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True + # if those are the same, we're done here. + if collections.Counter(event.auth_event_ids()) == collections.Counter( + calculated_auth_event_ids + ): + return + + # otherwise, re-run the auth checks based on what we calculated. + calculated_auth_events = await self._store.get_events_as_list( + calculated_auth_event_ids ) - auth_events_x = await self._store.get_events(auth_events_ids) + + # log the differences + + claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events} calculated_auth_event_map = { - (e.type, e.state_key): e for e in auth_events_x.values() + (e.type, e.state_key): e for e in calculated_auth_events } + logger.info( + "event's auth_events are different to our calculated auth_events. " + "Claimed but not calculated: %s. Calculated but not claimed: %s", + [ + ev + for k, ev in claimed_auth_event_map.items() + if k not in calculated_auth_event_map + or calculated_auth_event_map[k].event_id != ev.event_id + ], + [ + ev + for k, ev in calculated_auth_event_map.items() + if k not in claimed_auth_event_map + or claimed_auth_event_map[k].event_id != ev.event_id + ], + ) try: - updated_auth_events = await self._update_auth_events_for_auth( + check_state_dependent_auth_rules(event, calculated_auth_events) + except AuthError as e: + logger.warning( + "While checking auth of %r against room state before the event: %s", event, - calculated_auth_event_map=calculated_auth_event_map, - ) - except Exception: - # We don't really mind if the above fails, so lets not fail - # processing if it does. 
However, it really shouldn't fail so - # let's still log as an exception since we'll still want to fix - # any bugs. - logger.exception( - "Failed to double check auth events for %s with remote. " - "Ignoring failure and continuing processing of event.", - event.event_id, - ) - updated_auth_events = None - - if updated_auth_events: - context = await self._update_context_for_auth_events( - event, context, updated_auth_events + e, ) - auth_events_for_auth = updated_auth_events - else: - auth_events_for_auth = calculated_auth_event_map - - try: - check_state_dependent_auth_rules(event, auth_events_for_auth.values()) - except AuthError as e: - logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR - return context - + @trace async def _maybe_kick_guest_users(self, event: EventBase) -> None: if event.type != EventTypes.GuestAccess: return @@ -1639,17 +1891,27 @@ class FederationEventHandler: async def _check_for_soft_fail( self, event: EventBase, - state_ids: Optional[StateMap[str]], + context: EventContext, origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as such. + Does nothing for events in rooms with partial state, since we may not have an + accurate membership event for the sender in the current state. + Args: event - state_ids: The state at the event if we don't have all the event's prev events + context: The `EventContext` which we are about to persist the event with. origin: The host the event originates from. """ + if await self._store.is_partial_state_room(event.room_id): + # We might not know the sender's membership in the current state, so don't + # soft fail anything. Even if we do have a membership for the sender in the + # current state, it may have been derived from state resolution between + # partial and full state and may not be accurate. 
+ return + extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) @@ -1666,11 +1928,15 @@ class FederationEventHandler: auth_types = auth_types_for_event(room_version_obj, event) # Calculate the "current state". - if state_ids is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, + seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids) + has_missing_prevs = bool(prev_event_ids - seen_event_ids) + if has_missing_prevs: + # We don't have all the prev_events of this event, which means we have a + # gap in the graph, and the new event is going to become a new backwards + # extremity. + # + # In this case we want to be a little careful as we might have been + # down for a while and have an incorrect view of the current state, # however we still want to do checks as gaps are easy to # maliciously manufacture. # @@ -1683,6 +1949,7 @@ class FederationEventHandler: event.room_id, extrem_ids ) state_sets: List[StateMap[str]] = list(state_sets_d.values()) + state_ids = await context.get_prev_state_ids() state_sets.append(state_ids) current_state_ids = ( await self._state_resolution_handler.resolve_events_with_store( @@ -1731,95 +1998,8 @@ class FederationEventHandler: soft_failed_event_counter.inc() event.internal_metadata.soft_failed = True - async def _update_auth_events_for_auth( - self, - event: EventBase, - calculated_auth_event_map: StateMap[EventBase], - ) -> Optional[StateMap[EventBase]]: - """Helper for _check_event_auth. See there for docs. - - Checks whether a given event has the expected auth events. If it - doesn't then we talk to the remote server to compare state to see if - we can come to a consensus (e.g. 
if one server missed some valid - state). - - This attempts to resolve any potential divergence of state between - servers, but is not essential and so failures should not block further - processing of the event. - - Args: - event: - - calculated_auth_event_map: - Our calculated auth_events based on the state of the room - at the event's position in the DAG. - - Returns: - updated auth event map, or None if no changes are needed. - - """ - assert not event.internal_metadata.outlier - - # check for events which are in the event's claimed auth_events, but not - # in our calculated event map. - event_auth_events = set(event.auth_event_ids()) - different_auth = event_auth_events.difference( - e.event_id for e in calculated_auth_event_map.values() - ) - - if not different_auth: - return None - - logger.info( - "auth_events refers to events which are not in our calculated auth " - "chain: %s", - different_auth, - ) - - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = await self._store.get_events_as_list(different_auth) - - # double-check they're all in the same room - we should already have checked - # this but it doesn't hurt to check again. - for d in different_events: - assert ( - d.room_id == event.room_id - ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room" - - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. 
- - local_state = calculated_auth_event_map.values() - remote_auth_events = dict(calculated_auth_event_map) - remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) - remote_state = remote_auth_events.values() - - room_version = await self._store.get_room_version_id(event.room_id) - new_state = await self._state_handler.resolve_events( - room_version, (local_state, remote_state), event - ) - different_state = { - (d.type, d.state_key): d - for d in new_state.values() - if calculated_auth_event_map.get((d.type, d.state_key)) != d - } - if not different_state: - logger.info("State res returned no new state") - return None - - logger.info( - "After state res: updating auth_events with new state %s", - different_state.values(), - ) - - # take a copy of calculated_auth_event_map before we modify it. - auth_events = dict(calculated_auth_event_map) - auth_events.update(different_state) - return auth_events - async def _load_or_fetch_auth_events_for_event( - self, destination: str, event: EventBase + self, destination: Optional[str], event: EventBase ) -> Collection[EventBase]: """Fetch this event's auth_events, from database or remote @@ -1835,12 +2015,19 @@ class FederationEventHandler: Args: destination: where to send the /event_auth request. Typically the server that sent us `event` in the first place. + + If this is None, no attempt is made to load any missing auth events: + rather, an AssertionError is raised if there are any missing events. + event: the event whose auth_events we want Returns: all of the events listed in `event.auth_events_ids`, after deduplication Raises: + AssertionError if some auth events were missing and no `destination` was + supplied. + AuthError if we were unable to fetch the auth_events for any reason. 
""" event_auth_event_ids = set(event.auth_event_ids()) @@ -1852,6 +2039,13 @@ class FederationEventHandler: ) if not missing_auth_event_ids: return event_auth_events.values() + if destination is None: + # this shouldn't happen: destination must be set unless we know we have already + # persisted the auth events. + raise AssertionError( + "_load_or_fetch_auth_events_for_event() called with no destination for " + "an event with missing auth_events" + ) logger.info( "Event %s refers to unknown auth events %s: fetching auth chain", @@ -1887,6 +2081,8 @@ class FederationEventHandler: # instead we raise an AuthError, which will make the caller ignore it. raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found") + @trace + @tag_args async def _get_remote_auth_chain_for_event( self, destination: str, room_id: str, event_id: str ) -> None: @@ -1915,61 +2111,7 @@ class FederationEventHandler: await self._auth_and_persist_outliers(room_id, remote_auth_events) - async def _update_context_for_auth_events( - self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] - ) -> EventContext: - """Update the state_ids in an event context after auth event resolution, - storing the changes as a new state group. - - Args: - event: The event we're handling the context for - - context: initial event context - - auth_events: Events to update in the event context. - - Returns: - new event context - """ - # exclude the state key of the new event from the current_state in the context. 
- if event.is_state(): - event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) - else: - event_key = None - state_updates = { - k: a.event_id for k, a in auth_events.items() if k != event_key - } - - current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) # type: ignore - - current_state_ids.update(state_updates) - - prev_state_ids = await context.get_prev_state_ids() - prev_state_ids = dict(prev_state_ids) - - prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) - - # create a new state group as a delta from the existing one. - prev_group = context.state_group - state_group = await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=state_updates, - current_state_ids=current_state_ids, - ) - - return EventContext.with_state( - storage=self._storage_controllers, - state_group=state_group, - state_group_before_event=context.state_group_before_event, - state_delta_due_to_event=state_updates, - prev_group=prev_group, - delta_ids=state_updates, - partial_state=context.partial_state, - ) - + @trace async def _run_push_actions_and_persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False ) -> None: @@ -2078,8 +2220,17 @@ class FederationEventHandler: self._message_handler.maybe_schedule_expiry(event) if not backfilled: # Never notify for backfilled events - for event in events: - await self._notify_persisted_event(event, max_stream_token) + with start_active_span("notify_persisted_events"): + set_tag( + SynapseTags.RESULT_PREFIX + "event_ids", + str([ev.event_id for ev in events]), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "event_ids.length", + str(len(events)), + ) + for event in events: + await self._notify_persisted_event(event, max_stream_token) return max_stream_token.stream @@ -2120,6 +2271,10 @@ class FederationEventHandler: event, event_pos, max_stream_token, extra_users=extra_users 
) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + # TODO retrieve the previous state, and exclude join -> join transitions + self._notifier.notify_user_joined_room(event.event_id, event.room_id) + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9bca2bc4b2..93d09e9939 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -26,7 +26,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.http.site import SynapseRequest @@ -163,8 +162,7 @@ class IdentityHandler: sid: str, mxid: str, id_server: str, - id_access_token: Optional[str] = None, - use_v2: bool = True, + id_access_token: str, ) -> JsonDict: """Bind a 3PID to an identity server @@ -174,8 +172,7 @@ class IdentityHandler: mxid: The MXID to bind the 3PID to id_server: The domain of the identity server to query id_access_token: The access token to authenticate to the identity - server with, if necessary. Required if use_v2 is true - use_v2: Whether to use v2 Identity Service API endpoints. 
Defaults to True + server with Raises: SynapseError: On any of the following conditions @@ -187,24 +184,15 @@ class IdentityHandler: """ logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server) - # If an id_access_token is not supplied, force usage of v1 - if id_access_token is None: - use_v2 = False - if not valid_id_server_location(id_server): raise SynapseError( 400, "id_server must be a valid hostname with optional port and path components", ) - # Decide which API endpoint URLs to use - headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} - if use_v2: - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) - headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore - else: - bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) + bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) + headers = {"Authorization": create_id_access_token_header(id_access_token)} try: # Use the blacklisting http client as this call is only to identity servers @@ -223,21 +211,14 @@ class IdentityHandler: return data except HttpResponseException as e: - if e.code != 404 or not use_v2: - logger.error("3PID bind failed with Matrix error: %r", e) - raise e.to_synapse_error() + logger.error("3PID bind failed with Matrix error: %r", e) + raise e.to_synapse_error() except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: data = json_decoder.decode(e.msg) # XXX WAT? 
return data - logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) - res = await self.bind_threepid( - client_secret, sid, mxid, id_server, id_access_token, use_v2=False - ) - return res - async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool: """Attempt to remove a 3PID from an identity server, or if one is not provided, all identity servers we're aware the binding is present on @@ -300,8 +281,8 @@ class IdentityHandler: "id_server must be a valid hostname with optional port and path components", ) - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = b"/_matrix/identity/api/v1/3pid/unbind" + url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,) + url_bytes = b"/_matrix/identity/v2/3pid/unbind" content = { "mxid": mxid, @@ -434,48 +415,6 @@ class IdentityHandler: return session_id - async def requestEmailToken( - self, - id_server: str, - email: str, - client_secret: str, - send_attempt: int, - next_link: Optional[str] = None, - ) -> JsonDict: - """ - Request an external server send an email on our behalf for the purposes of threepid - validation. 
- - Args: - id_server: The identity server to proxy to - email: The email to send the message to - client_secret: The unique client_secret sends by the user - send_attempt: Which attempt this is - next_link: A link to redirect the user to once they submit the token - - Returns: - The json response body from the server - """ - params = { - "email": email, - "client_secret": client_secret, - "send_attempt": send_attempt, - } - if next_link: - params["next_link"] = next_link - - try: - data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", - params, - ) - return data - except HttpResponseException as e: - logger.info("Proxied requestToken failed: %r", e) - raise e.to_synapse_error() - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - async def requestMsisdnToken( self, id_server: str, @@ -549,18 +488,7 @@ class IdentityHandler: validation_session = None # Try to validate as email - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Remote emails will only be used if a valid identity server is provided. 
- assert ( - self.hs.config.registration.account_threepid_delegate_email is not None - ) - - # Ask our delegated email identity server - validation_session = await self.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_email, - threepid_creds, - ) - elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.can_verify_email: # Get a validated session matching these details validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True @@ -610,11 +538,7 @@ class IdentityHandler: raise SynapseError(400, "Error contacting the identity server") async def lookup_3pid( - self, - id_server: str, - medium: str, - address: str, - id_access_token: Optional[str] = None, + self, id_server: str, medium: str, address: str, id_access_token: str ) -> Optional[str]: """Looks up a 3pid in the passed identity server. @@ -629,60 +553,15 @@ class IdentityHandler: Returns: the matrix ID of the 3pid, or None if it is not recognized. """ - if id_access_token is not None: - try: - results = await self._lookup_3pid_v2( - id_server, id_access_token, medium, address - ) - return results - - except Exception as e: - # Catch HttpResponseExcept for a non-200 response code - # Check if this identity server does not know about v2 lookups - if isinstance(e, HttpResponseException) and e.code == 404: - # This is an old identity server that does not yet support v2 lookups - logger.warning( - "Attempted v2 lookup on v1 identity server %s. Falling " - "back to v1", - id_server, - ) - else: - logger.warning("Error when looking up hashing details: %s", e) - return None - - return await self._lookup_3pid_v1(id_server, medium, address) - - async def _lookup_3pid_v1( - self, id_server: str, medium: str, address: str - ) -> Optional[str]: - """Looks up a 3pid in the passed identity server using v1 lookup. 
- - Args: - id_server: The server name (including port, if required) - of the identity server to use. - medium: The type of the third party identifier (e.g. "email"). - address: The third party identifier (e.g. "foo@example.com"). - Returns: - the matrix ID of the 3pid, or None if it is not recognized. - """ try: - data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), - {"medium": medium, "address": address}, + results = await self._lookup_3pid_v2( + id_server, id_access_token, medium, address ) - - if "mxid" in data: - # note: we used to verify the identity server's signature here, but no longer - # require or validate it. See the following for context: - # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950 - return data["mxid"] - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except OSError as e: - logger.warning("Error from v1 identity server lookup: %s" % (e,)) - - return None + return results + except Exception as e: + logger.warning("Error when looking up hashing details: %s", e) + return None async def _lookup_3pid_v2( self, id_server: str, id_access_token: str, medium: str, address: str @@ -811,7 +690,7 @@ class IdentityHandler: room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, - id_access_token: Optional[str] = None, + id_access_token: str, ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]: """ Asks an identity server for a third party invite. @@ -832,7 +711,7 @@ class IdentityHandler: inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. 
- id_access_token (str|None): The access token to authenticate to the identity + id_access_token (str): The access token to authenticate to the identity server with Returns: @@ -864,71 +743,24 @@ class IdentityHandler: invite_config["org.matrix.web_client_location"] = self._web_client_location # Add the identity service access token to the JSON body and use the v2 - # Identity Service endpoints if id_access_token is present + # Identity Service endpoints data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) - if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, - ) + key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_scheme, + id_server, + ) - # Attempt a v2 lookup - url = base_url + "/v2/store-invite" - try: - data = await self.blacklisting_http_client.post_json_get_json( - url, - invite_config, - {"Authorization": create_id_access_token_header(id_access_token)}, - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except HttpResponseException as e: - if e.code != 404: - logger.info("Failed to POST %s with JSON: %s", url, e) - raise e - - if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, + url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server) + try: + data = await self.blacklisting_http_client.post_json_get_json( + url, + invite_config, + {"Authorization": create_id_access_token_header(id_access_token)}, ) - url = base_url + "/api/v1/store-invite" - - try: - data = await self.blacklisting_http_client.post_json_get_json( - url, invite_config - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except HttpResponseException as e: - logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, - e, - ) - - if data is 
None: - # Some identity servers may only support application/x-www-form-urlencoded - # types. This is especially true with old instances of Sydent, see - # https://github.com/matrix-org/sydent/pull/170 - try: - data = await self.blacklisting_http_client.post_urlencoded_get_json( - url, invite_config - ) - except HttpResponseException as e: - logger.warning( - "Error calling /store-invite on %s%s with fallback " - "encoding: %s", - id_server_scheme, - id_server, - e, - ) - raise e - - # TODO: Check for success + except RequestTimedOutError: + raise SynapseError(500, "Timed out contacting identity server") + token = data["token"] public_keys = data.get("public_keys", []) if "public_key" in data: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 85b472f250..860c82c110 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -143,8 +143,8 @@ class InitialSyncHandler: joined_rooms, to_key=int(now_token.receipt_key), ) - if self.hs.config.experimental.msc2285_enabled: - receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) + + receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -309,18 +309,18 @@ class InitialSyncHandler: if blocked: raise SynapseError(403, "This room has been blocked on this server") - user_id = requester.user.to_string() - ( membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( room_id, - user_id, + requester, allow_departed_users=True, ) is_peeking = member_event_id is None + user_id = requester.user.to_string() + if membership == Membership.JOIN: result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking @@ -456,11 +456,8 @@ class InitialSyncHandler: ) if not receipts: return [] - if self.hs.config.experimental.msc2285_enabled: - receipts = ReceiptEventSource.filter_out_private_receipts( - receipts, 
user_id - ) - return receipts + + return ReceiptEventSource.filter_out_private_receipts(receipts, user_id) presence, receipts, (messages, token) = await make_deferred_yieldable( gather_results( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1980e37dae..e07cda133a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -41,6 +41,7 @@ from synapse.api.errors import ( NotFoundError, ShadowBanError, SynapseError, + UnstableSpecAuthError, UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -51,6 +52,7 @@ from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler +from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet @@ -102,7 +104,7 @@ class MessageHandler: async def get_room_data( self, - user_id: str, + requester: Requester, room_id: str, event_type: str, state_key: str, @@ -110,7 +112,7 @@ class MessageHandler: """Get data from a room. Args: - user_id + requester: The user who did the request. room_id event_type state_key @@ -123,7 +125,7 @@ class MessageHandler: membership, membership_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership == Membership.JOIN: @@ -149,17 +151,20 @@ class MessageHandler: "Attempted to retrieve data from a room for a user that has never been in it. " "This should not have happened." 
) - raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN) + raise UnstableSpecAuthError( + 403, + "User not in room", + errcode=Codes.NOT_JOINED, + ) return data async def get_state_events( self, - user_id: str, + requester: Requester, room_id: str, state_filter: Optional[StateFilter] = None, at_token: Optional[StreamToken] = None, - is_guest: bool = False, ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -168,14 +173,13 @@ class MessageHandler: visible. Args: - user_id: The user requesting state events. + requester: The user requesting state events. room_id: The room ID to get all state events from. state_filter: The state filter used to fetch state from the database. at_token: the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -185,6 +189,7 @@ class MessageHandler: members of this room. """ state_filter = state_filter or StateFilter.all() + user_id = requester.user.to_string() if at_token: last_event_id = ( @@ -217,7 +222,7 @@ class MessageHandler: membership, membership_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership == Membership.JOIN: @@ -311,30 +316,42 @@ class MessageHandler: Returns: A dict of user_id to profile info """ - user_id = requester.user.to_string() if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. 
membership, _ = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership != Membership.JOIN: - raise NotImplementedError( - "Getting joined members after leaving is not implemented" + raise SynapseError( + code=403, + errcode=Codes.FORBIDDEN, + msg="Getting joined members while not being a current member of the room is forbidden.", ) - users_with_profile = await self.store.get_users_in_room_with_profiles(room_id) + users_with_profile = ( + await self._state_storage_controller.get_users_in_room_with_profiles( + room_id + ) + ) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there # is a user in the room that the AS is "interested in" - if requester.app_service and user_id not in users_with_profile: + if ( + requester.app_service + and requester.user.to_string() not in users_with_profile + ): for uid in users_with_profile: if requester.app_service.is_interested_in_user(uid): break else: # Loop fell through, AS has no interested users in room - raise AuthError(403, "Appservice not in room") + raise UnstableSpecAuthError( + 403, + "Appservice not in room", + errcode=Codes.NOT_JOINED, + ) return { user_id: { @@ -463,6 +480,7 @@ class EventCreationHandler: ) self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state @@ -734,18 +752,12 @@ class EventCreationHandler: if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return await self._is_server_notices_room(builder.room_id) + return await self.store.is_server_notice_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return 
builder.state_key == requester.user.to_string() return False - async def _is_server_notices_room(self, room_id: str) -> bool: - if self.config.servernotices.server_notices_mxid is None: - return False - user_ids = await self.store.get_users_in_room(room_id) - return self.config.servernotices.server_notices_mxid in user_ids - async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy @@ -1134,6 +1146,10 @@ class EventCreationHandler: context = await self.state.compute_event_context( event, state_ids_before_event=state_map_for_event, + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + partial_state=False, ) else: context = await self.state.compute_event_context(event) @@ -1358,9 +1374,10 @@ class EventCreationHandler: # and `state_groups` because they have `prev_events` that aren't persisted yet # (historical messages persisted in reverse-chronological order). if not event.internal_metadata.is_historical(): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + with opentracing.start_active_span("calculate_push_actions"): + await self._bulk_push_rule_evaluator.action_for_event_by_user( + event, context + ) try: # If we're a worker we need to hit out to the master. @@ -1444,7 +1461,13 @@ class EventCreationHandler: if state_entry.state_group in self._external_cache_joined_hosts_updates: return - joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() + ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) # Note that the expiry times must be larger than the expiry time in # _external_cache_joined_hosts_updates. 
@@ -1545,6 +1568,16 @@ class EventCreationHandler: requester, is_admin_redaction=is_admin_redaction ) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + event.state_key, event.room_id + ) + if current_membership != Membership.JOIN: + self._notifier.notify_user_joined_room(event.event_id, event.room_id) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: @@ -1844,13 +1877,8 @@ class EventCreationHandler: # For each room we need to find a joined member we can use to send # the dummy event with. - latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = await self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids - ) + members = await self.store.get_local_users_in_room(room_id) for user_id in members: - if not self.hs.is_mine_id(user_id): - continue requester = create_requester(user_id, authenticated_entity=self.server_name) try: event, context = await self.create_event( @@ -1861,7 +1889,6 @@ class EventCreationHandler: "room_id": room_id, "sender": user_id, }, - prev_event_ids=latest_event_ids, ) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 6262a35822..1f83bab836 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -24,7 +24,9 @@ from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig from synapse.handlers.room import ShutdownRoomResponse +from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import 
JsonDict, Requester, StreamKeyType @@ -158,11 +160,9 @@ class PaginationHandler: self._retention_allowed_lifetime_max = ( hs.config.retention.retention_allowed_lifetime_max ) + self._is_master = hs.config.worker.worker_app is None - if ( - hs.config.worker.run_background_tasks - and hs.config.retention.retention_enabled - ): + if hs.config.retention.retention_enabled and self._is_master: # Run the purge jobs described in the configuration file. for job in hs.config.retention.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) @@ -416,6 +416,7 @@ class PaginationHandler: await self._storage_controllers.purge_events.purge_room(room_id) + @trace async def get_messages( self, requester: Requester, @@ -423,6 +424,7 @@ class PaginationHandler: pagin_config: PaginationConfig, as_client_event: bool = True, event_filter: Optional[Filter] = None, + use_admin_priviledge: bool = False, ) -> JsonDict: """Get messages in a room. @@ -432,10 +434,16 @@ class PaginationHandler: pagin_config: The pagination config rules to apply, if any. as_client_event: True to get events in client-server format. event_filter: Filter to apply to results or None + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. 
Returns: Pagination API results """ + if use_admin_priviledge: + await assert_user_is_admin(self.auth, requester) + user_id = requester.user.to_string() if pagin_config.from_token: @@ -458,12 +466,14 @@ class PaginationHandler: room_token = from_token.room_key async with self.pagination_lock.read(room_id): - ( - membership, - member_event_id, - ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True - ) + (membership, member_event_id) = (None, None) + if not use_admin_priviledge: + ( + membership, + member_event_id, + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) if pagin_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -475,7 +485,7 @@ class PaginationHandler: room_id, room_token.stream ) - if membership == Membership.LEAVE: + if not use_admin_priviledge and membership == Membership.LEAVE: # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. 
@@ -528,12 +538,13 @@ class PaginationHandler: if event_filter: events = await event_filter.filter(events) - events = await filter_events_for_client( - self._storage_controllers, - user_id, - events, - is_peeking=(member_event_id is None), - ) + if not use_admin_priviledge: + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) # if after the filter applied there are no more events # return immediately - but there might be more in next_token batch diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 895ea63ed3..4e575ffbaa 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -34,7 +34,6 @@ from typing import ( Callable, Collection, Dict, - FrozenSet, Generator, Iterable, List, @@ -42,7 +41,6 @@ from typing import ( Set, Tuple, Type, - Union, ) from prometheus_client import Counter @@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore from synapse.streams import EventSource from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # doesn't return. C.f. #5503. return [], max_token - # Figure out which other users this user should receive updates for - users_interested_in = await self._get_interested_in(user, explicit_room_id) + # Figure out which other users this user should explicitly receive + # updates for + additional_users_interested_in = ( + await self.get_presence_router().get_interested_users(user.to_string()) + ) # We have a set of users that we're interested in the presence of. We want to # cross-reference that with the users that have actually changed their presence. 
# Check whether this user should see all user updates - if users_interested_in == PresenceRouter.ALL_USERS: + if additional_users_interested_in == PresenceRouter.ALL_USERS: # Provide presence state for all users presence_updates = await self._filter_all_presence_updates_for_user( user_id, include_offline, from_key @@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): return presence_updates, max_token # Make mypy happy. users_interested_in should now be a set - assert not isinstance(users_interested_in, str) + assert not isinstance(additional_users_interested_in, str) + + # We always care about our own presence. + additional_users_interested_in.add(user_id) + + if explicit_room_id: + user_ids = await self.store.get_users_in_room(explicit_room_id) + additional_users_interested_in.update(user_ids) # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set() + interested_and_updated_users: Collection[str] if from_key is not None: # First get all users that have had a presence update updated_users = stream_change_cache.get_all_entities_changed(from_key) # Cross-reference users we're interested in with those that have had updates. - # Use a slightly-optimised method for processing smaller sets of updates. - if updated_users is not None and len(updated_users) < 500: - # For small deltas, it's quicker to get all changes and then - # cross-reference with the users we're interested in + if updated_users is not None: + # If we have the full list of changes for presence we can + # simply check which ones share a room with the user. 
get_updates_counter.labels("stream").inc() - for other_user_id in updated_users: - if other_user_id in users_interested_in: - # mypy thinks this variable could be a FrozenSet as it's possibly set - # to one in the `get_entities_changed` call below, and `add()` is not - # method on a FrozenSet. That doesn't affect us here though, as - # `interested_and_updated_users` is clearly a set() above. - interested_and_updated_users.add(other_user_id) # type: ignore + + sharing_users = await self.store.do_users_share_a_room( + user_id, updated_users + ) + + interested_and_updated_users = ( + sharing_users.union(additional_users_interested_in) + ).intersection(updated_users) + else: # Too many possible updates. Find all users we can see and check # if any of them have changed. get_updates_counter.labels("full").inc() + users_interested_in = ( + await self.store.get_users_who_share_room_with_user(user_id) + ) + users_interested_in.update(additional_users_interested_in) + interested_and_updated_users = ( stream_change_cache.get_entities_changed( users_interested_in, from_key @@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): else: # No from_key has been specified. 
Return the presence for all users # this user is interested in - interested_and_updated_users = users_interested_in + interested_and_updated_users = ( + await self.store.get_users_who_share_room_with_user(user_id) + ) + interested_and_updated_users.update(additional_users_interested_in) # Retrieve the current presence state for each user users_to_state = await self.get_presence_handler().current_state_for_users( @@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): def get_current_key(self) -> int: return self.store.get_current_presence_token() - @cached(num_args=2, cache_context=True) - async def _get_interested_in( - self, - user: UserID, - explicit_room_id: Optional[str] = None, - cache_context: Optional[_CacheContext] = None, - ) -> Union[Set[str], str]: - """Returns the set of users that the given user should see presence - updates for. - - Args: - user: The user to retrieve presence updates for. - explicit_room_id: The users that are in the room will be returned. - - Returns: - A set of user IDs to return presence updates for, or "ALL" to return all - known updates. - """ - user_id = user.to_string() - users_interested_in = set() - users_interested_in.add(user_id) # So that we receive our own presence - - # cache_context isn't likely to ever be None due to the @cached decorator, - # but we can't have a non-optional argument after the optional argument - # explicit_room_id either. Assert cache_context is not None so we can use it - # without mypy complaining. 
- assert cache_context - - # Check with the presence router whether we should poll additional users for - # their presence information - additional_users = await self.get_presence_router().get_interested_users( - user.to_string() - ) - if additional_users == PresenceRouter.ALL_USERS: - # If the module requested that this user see the presence updates of *all* - # users, then simply return that instead of calculating what rooms this - # user shares - return PresenceRouter.ALL_USERS - - # Add the additional users from the router - users_interested_in.update(additional_users) - - # Find the users who share a room with this user - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id, on_invalidate=cache_context.invalidate - ) - users_interested_in.update(users_who_share_room) - - if explicit_room_id: - user_ids = await self.store.get_users_in_room( - explicit_room_id, on_invalidate=cache_context.invalidate - ) - users_interested_in.update(user_ids) - - return users_interested_in - def handle_timeouts( user_states: List[UserPresenceState], @@ -2091,8 +2051,7 @@ async def get_interested_remotes( ) for room_id, states in room_ids_to_states.items(): - user_ids = await store.get_users_in_room(room_id) - hosts = {get_domain_from_id(user_id) for user_id in user_ids} + hosts = await store.get_current_hosts_in_room(room_id) for host in hosts: hosts_and_states.setdefault(host, set()).update(states) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 43d2882b0a..d2bdb9c8be 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -256,10 +256,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]): room_ids, from_key=from_key, to_key=to_key ) - if self.config.experimental.msc2285_enabled: - events = ReceiptEventSource.filter_out_private_receipts( - events, user.to_string() - ) + events = ReceiptEventSource.filter_out_private_receipts( + events, user.to_string() + ) return events, to_key diff --git 
a/synapse/handlers/register.py b/synapse/handlers/register.py index c77d181722..20ec22105a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -29,7 +29,13 @@ from synapse.api.constants import ( JoinRules, LoginType, ) -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + InvalidClientTokenError, + SynapseError, +) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict @@ -180,10 +186,7 @@ class RegistrationHandler: ) if guest_access_token: user_data = await self.auth.get_user_by_access_token(guest_access_token) - if ( - not user_data.is_guest - or UserID.from_string(user_data.user_id).localpart != localpart - ): + if not user_data.is_guest or user_data.user.localpart != localpart: raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -618,7 +621,7 @@ class RegistrationHandler: user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) if not service: - raise AuthError(403, "Invalid application service token.") + raise InvalidClientTokenError() if not service.is_interested_in_user(user_id): raise SynapseError( 400, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0b63cd2186..28d7093f08 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -19,6 +19,7 @@ import attr from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event +from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import _RelatedEvent from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -73,7 +74,6 @@ class RelationsHandler: room_id: str, relation_type: 
Optional[str] = None, event_type: Optional[str] = None, - aggregation_key: Optional[str] = None, limit: int = 5, direction: str = "b", from_token: Optional[StreamToken] = None, @@ -89,7 +89,6 @@ class RelationsHandler: room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. - aggregation_key: Only fetch events with this aggregation key, if given. limit: Only fetch the most recent `limit` events. direction: Whether to fetch the most recent first (`"b"`) or the oldest first (`"f"`). @@ -104,7 +103,7 @@ class RelationsHandler: # TODO Properly handle a user leaving a room. (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) # This gets the original event and checks that a) the event exists and @@ -122,7 +121,6 @@ class RelationsHandler: room_id=room_id, relation_type=relation_type, event_type=event_type, - aggregation_key=aggregation_key, limit=limit, direction=direction, from_token=from_token, @@ -364,6 +362,7 @@ class RelationsHandler: return results + @trace async def get_bundled_aggregations( self, events: Iterable[EventBase], user_id: str ) -> Dict[str, BundledAggregations]: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a54f163c0a..33e9a87002 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -19,6 +19,7 @@ import math import random import string from collections import OrderedDict +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -60,7 +61,6 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers.federation import get_domains_from_state from 
synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin @@ -705,8 +705,8 @@ class RoomCreationHandler: was, requested, `room_alias`. Secondly, the stream_id of the last persisted event. Raises: - SynapseError if the room ID couldn't be stored, or something went - horribly wrong. + SynapseError if the room ID couldn't be stored, 3pid invitation config + validation failed, or something went horribly wrong. ResourceLimitError if server is blocked to some resource being exceeded """ @@ -721,7 +721,7 @@ class RoomCreationHandler: # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) # Let the third party rules modify the room creation config if needed, or abort # the room creation entirely with an exception. @@ -732,6 +732,19 @@ class RoomCreationHandler: invite_3pid_list = config.get("invite_3pid", []) invite_list = config.get("invite", []) + # validate each entry for correctness + for invite_3pid in invite_3pid_list: + if not all( + key in invite_3pid + for key in ("medium", "address", "id_server", "id_access_token") + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "all of `medium`, `address`, `id_server` and `id_access_token` " + "are required when making a 3pid invite", + Codes.MISSING_PARAM, + ) + if not is_requester_admin: spam_check = await self.spam_checker.user_may_create_room(user_id) if spam_check != NOT_SPAM: @@ -889,7 +902,11 @@ class RoomCreationHandler: # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - last_stream_id = await self._send_events_for_new_room( + ( + last_stream_id, + last_sent_event_id, + depth, + ) = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -905,7 +922,7 
@@ class RoomCreationHandler: if "name" in config: name = config["name"] ( - _, + name_event, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, @@ -917,12 +934,16 @@ class RoomCreationHandler: "content": {"name": name}, }, ratelimit=False, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = name_event.event_id + depth += 1 if "topic" in config: topic = config["topic"] ( - _, + topic_event, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, @@ -934,7 +955,11 @@ class RoomCreationHandler: "content": {"topic": topic}, }, ratelimit=False, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = topic_event.event_id + depth += 1 # we avoid dropping the lock between invites, as otherwise joins can # start coming in and making the createRoom slow. @@ -949,7 +974,7 @@ class RoomCreationHandler: for invitee in invite_list: ( - _, + member_event_id, last_stream_id, ) = await self.room_member_handler.update_membership_locked( requester, @@ -959,16 +984,23 @@ class RoomCreationHandler: ratelimit=False, content=content, new_room=True, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = member_event_id + depth += 1 for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] - id_access_token = invite_3pid.get("id_access_token") # optional + id_access_token = invite_3pid["id_access_token"] address = invite_3pid["address"] medium = invite_3pid["medium"] # Note that do_3pid_invite can raise a ShadowBanError, but this was # handled above by emptying invite_3pid_list. 
- last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( + ( + member_event_id, + last_stream_id, + ) = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -977,7 +1009,11 @@ class RoomCreationHandler: requester, txn_id=None, id_access_token=id_access_token, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = member_event_id + depth += 1 result = {"room_id": room_id} @@ -1005,20 +1041,22 @@ class RoomCreationHandler: power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, ratelimit: bool = True, - ) -> int: + ) -> Tuple[int, str, int]: """Sends the initial events into a new room. `power_level_content_override` doesn't apply when initial state has power level state event content. Returns: - The stream_id of the last event persisted. + A tuple containing the stream ID, event ID and depth of the last + event sent to the room. """ creator_id = creator.user.to_string() event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} + depth = 1 last_sent_event_id: Optional[str] = None def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: @@ -1031,6 +1069,7 @@ class RoomCreationHandler: async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: nonlocal last_sent_event_id + nonlocal depth event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) @@ -1047,9 +1086,11 @@ class RoomCreationHandler: # Note: we don't pass state_event_ids here because this triggers # an additional query per event to look them up from the events table. 
prev_event_ids=[last_sent_event_id] if last_sent_event_id else [], + depth=depth, ) last_sent_event_id = sent_event.event_id + depth += 1 return last_stream_id @@ -1075,6 +1116,7 @@ class RoomCreationHandler: content=creator_join_profile, new_room=True, prev_event_ids=[last_sent_event_id], + depth=depth, ) last_sent_event_id = member_event_id @@ -1168,7 +1210,7 @@ class RoomCreationHandler: content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, ) - return last_sent_stream_id + return last_sent_stream_id, last_sent_event_id, depth def _generate_room_id(self) -> str: """Generates a random room ID. @@ -1250,13 +1292,16 @@ class RoomContextHandler: """ user = requester.user if use_admin_priviledge: - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = await self.store.get_users_in_room(room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: @@ -1355,6 +1400,7 @@ class TimestampLookupHandler: self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() + self.federation_event_handler = hs.get_federation_event_handler() self._storage_controllers = hs.get_storage_controllers() async def get_event_for_timestamp( @@ -1429,17 +1475,16 @@ class TimestampLookupHandler: timestamp, ) - # Find other homeservers from the given state in the room - curr_state = await self._storage_controllers.state.get_current_state( - room_id + likely_domains = ( + await self._storage_controllers.state.get_current_hosts_in_room(room_id) ) - curr_domains = get_domains_from_state(curr_state) - 
likely_domains = [ - domain for domain, depth in curr_domains if domain != self.server_name - ] # Loop through each homeserver candidate until we get a succesful response for domain in likely_domains: + # We don't want to ask our own server for information we don't have + if domain == self.server_name: + continue + try: remote_response = await self.federation_client.timestamp_to_event( domain, room_id, timestamp, direction @@ -1450,38 +1495,68 @@ class TimestampLookupHandler: remote_response, ) - # TODO: Do we want to persist this as an extremity? - # TODO: I think ideally, we would try to backfill from - # this event and run this whole - # `get_event_for_timestamp` function again to make sure - # they didn't give us an event from their gappy history. remote_event_id = remote_response.event_id - origin_server_ts = remote_response.origin_server_ts + remote_origin_server_ts = remote_response.origin_server_ts + + # Backfill this event so we can get a pagination token for + # it with `/context` and paginate `/messages` from this + # point. + # + # TODO: The requested timestamp may lie in a part of the + # event graph that the remote server *also* didn't have, + # in which case they will have returned another event + # which may be nowhere near the requested timestamp. In + # the future, we may need to reconcile that gap and ask + # other homeservers, and/or extend `/timestamp_to_event` + # to return events on *both* sides of the timestamp to + # help reconcile the gap faster. + remote_event = ( + await self.federation_event_handler.backfill_event_id( + domain, room_id, remote_event_id + ) + ) + + # XXX: When we see that the remote server is not trustworthy, + # maybe we should not ask them first in the future. + if remote_origin_server_ts != remote_event.origin_server_ts: + logger.info( + "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). 
Their claims are dubious and we should consider not trusting them.", + domain, + remote_event_id, + remote_origin_server_ts, + remote_event.origin_server_ts, + ) # Only return the remote event if it's closer than the local event if not local_event or ( - abs(origin_server_ts - timestamp) + abs(remote_event.origin_server_ts - timestamp) < abs(local_event.origin_server_ts - timestamp) ): - return remote_event_id, origin_server_ts + logger.info( + "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)", + remote_event_id, + remote_event.origin_server_ts, + timestamp, + local_event.event_id if local_event else None, + local_event.origin_server_ts if local_event else None, + ) + return remote_event_id, remote_origin_server_ts except (HttpResponseException, InvalidResponseError) as ex: # Let's not put a high priority on some other homeserver # failing to respond or giving a random response logger.debug( - "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", domain, type(ex).__name__, ex, ex.args, ) - except Exception as ex: + except Exception: # But we do want to see some exceptions in our code logger.warning( - "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s", + "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception", domain, - type(ex).__name__, - ex, - ex.args, + exc_info=True, ) # To appease mypy, we have to add both of these conditions to check for diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 29868eb743..bb0bdb8e6f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -182,7 +182,7 @@ class RoomListHandler: == HistoryVisibility.WORLD_READABLE, "guest_can_join": room["guest_access"] == "can_join", "join_rule": room["join_rules"], - 
"org.matrix.msc3827.room_type": room["room_type"], + "room_type": room["room_type"], } # Filter out Nones – rather omit the field altogether diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 04c44b2ccb..8d01f4bf2b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -32,6 +32,7 @@ from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN +from synapse.logging import opentracing from synapse.module_api import NOT_SPAM from synapse.storage.state import StateFilter from synapse.types import ( @@ -94,12 +95,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) + # Tracks joins from local users to rooms this server isn't a member of. + # I.e. joins this server makes by requesting /make_join /send_join from + # another server. self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + # TODO: find a better place to keep this Ratelimiter. + # It needs to be + # - written to by event persistence code + # - written to by something which can snoop on replication streams + # - read by the RoomMemberHandler to rate limit joins from local users + # - read by the FederationServer to rate limit make_joins and send_joins from + # other homeservers + # I wonder if a homeserver-wide collection of rate limiters might be cleaner? 
+ self._join_rate_per_room_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + ) # Ratelimiter for invites, keyed by room (across all issuers, all # recipients). @@ -136,6 +154,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) self.request_ratelimiter = hs.get_request_ratelimiter() + hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room) + + def _on_user_joined_room(self, event_id: str, room_id: str) -> None: + """Notify the rate limiter that a room join has occurred. + + Use this to inform the RoomMemberHandler about joins that have either + - taken place on another homeserver, or + - on another worker in this homeserver. + Joins actioned by this worker should use the usual `ratelimit` method, which + checks the limit and increments the counter in one go. + """ + self._join_rate_per_room_limiter.record_action(requester=None, key=room_id) @abc.abstractmethod async def _remote_join( @@ -149,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """Try and join a room that this server is not in Args: - requester + requester: The user making the request, according to the access token. remote_room_hosts: List of servers that can be used to join via. room_id: Room that we are trying to join user: User who is trying to join @@ -285,6 +315,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, @@ -315,6 +346,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_events are set so we need to set them ourself via this argument. This should normally be left as None, which will cause the auth_event_ids to be calculated based on the room state at the prev_events. 
+ depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. txn_id: ratelimit: @@ -370,6 +404,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, + depth=depth, require_consent=require_consent, outlier=outlier, historical=historical, @@ -391,14 +426,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # up blocking profile updates. if newly_joined and ratelimit: await self._join_rate_limiter_local.ratelimit(requester) - - result_event = await self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target], - ratelimit=ratelimit, - ) + await self._join_rate_per_room_limiter.ratelimit( + requester, key=room_id, update=False + ) + with opentracing.start_active_span("handle_new_client_event"): + result_event = await self.event_creation_handler.handle_new_client_event( + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, + ) if event.membership == Membership.LEAVE: if prev_member_event_id: @@ -466,6 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -501,6 +540,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_events are set so we need to set them ourself via this argument. This should normally be left as None, which will cause the auth_event_ids to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. Returns: A tuple of the new event ID and stream ID. 
@@ -523,24 +565,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # by application services), and then by room ID. async with self.member_as_limiter.queue(as_id): async with self.member_linearizer.queue(key): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - historical=historical, - allow_no_prev_events=allow_no_prev_events, - prev_event_ids=prev_event_ids, - state_event_ids=state_event_ids, - ) + with opentracing.start_active_span("update_membership_locked"): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + historical=historical, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + state_event_ids=state_event_ids, + depth=depth, + ) return result @@ -562,6 +606,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -599,10 +644,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_events are set so we need to set them ourself via this argument. This should normally be left as None, which will cause the auth_event_ids to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. Returns: A tuple of the new event ID and stream ID. 
""" + content_specified = bool(content) if content is None: content = {} @@ -640,7 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): errcode=Codes.BAD_JSON, ) - if "avatar_url" in content: + if "avatar_url" in content and content.get("avatar_url") is not None: if not await self.profile_handler.check_avatar_size_and_mime_type( content["avatar_url"], ): @@ -695,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = True else: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) if not is_requester_admin: if self.config.server.block_non_admin_invites: @@ -732,6 +781,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, + depth=depth, content=content, require_consent=require_consent, outlier=outlier, @@ -740,14 +790,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = await self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids + state_before_join = await self.state_handler.compute_state_after_events( + room_id, latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions # between old and new state, with specific error messages for some # transitions and generic otherwise - old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + old_state_id = state_before_join.get((EventTypes.Member, target.to_string())) if old_state_id: old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None @@ -787,7 +837,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = await self._is_server_notice_room(room_id) 
+ is_blocked = await self.store.is_server_notice_room(room_id) if is_blocked: raise SynapseError( HTTPStatus.FORBIDDEN, @@ -798,11 +848,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = await self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(state_before_join) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = await self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(state_before_join) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -818,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): bypass_spam_checker = True else: - bypass_spam_checker = await self.auth.is_server_admin(requester.user) + bypass_spam_checker = await self.auth.is_server_admin(requester) inviter = await self._get_inviter(target.to_string(), room_id) if ( @@ -840,13 +890,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Check if a remote join should be performed. 
remote_join, remote_room_hosts = await self._should_perform_remote_join( - target.to_string(), room_id, remote_room_hosts, content, is_host_in_room + target.to_string(), + room_id, + remote_room_hosts, + content, + is_host_in_room, + state_before_join, ) if remote_join: if ratelimit: await self._join_rate_limiter_remote.ratelimit( requester, ) + await self._join_rate_per_room_limiter.ratelimit( + requester, + key=room_id, + update=False, + ) inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): @@ -967,6 +1027,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit=ratelimit, prev_event_ids=latest_event_ids, state_event_ids=state_event_ids, + depth=depth, content=content, require_consent=require_consent, outlier=outlier, @@ -979,6 +1040,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): remote_room_hosts: List[str], content: JsonDict, is_host_in_room: bool, + state_before_join: StateMap[str], ) -> Tuple[bool, List[str]]: """ Check whether the server should do a remote join (as opposed to a local @@ -998,6 +1060,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: The content to use as the event body of the join. This may be modified. is_host_in_room: True if the host is in the room. + state_before_join: The state before the join event (i.e. the resolution of + the states after its parent events). Returns: A tuple of: @@ -1014,20 +1078,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) - current_state_ids = await self._storage_controllers.state.get_current_state_ids( - room_id - ) # If restricted join rules are not being used, a local join can always # be used. 
if not await self.event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version + state_before_join, room_version ): return False, [] # If the user is invited to the room or already joined, the join # event can always be issued locally. - prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None) prev_member_event = None if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) @@ -1042,10 +1103,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # # If not, generate a new list of remote hosts based on which # can issue invites. - event_map = await self.store.get_events(current_state_ids.values()) + event_map = await self.store.get_events(state_before_join.values()) current_state = { state_key: event_map[event_id] - for state_key, event_id in current_state_ids.items() + for state_key, event_id in state_before_join.items() } allowed_servers = get_servers_from_users( get_users_which_can_issue_invite(current_state) @@ -1059,7 +1120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Ensure the member should be allowed access via membership in a room. 
await self.event_auth_handler.check_restricted_join_rules( - current_state_ids, room_version, user_id, prev_member_event + state_before_join, room_version, user_id, prev_member_event ) # If this is going to be a local join, additional information must @@ -1069,7 +1130,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): EventContentFields.AUTHORISING_USER ] = await self.event_auth_handler.get_user_which_could_invite( room_id, - current_state_ids, + state_before_join, ) return False, [] @@ -1321,8 +1382,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): id_server: str, requester: Requester, txn_id: Optional[str], - id_access_token: Optional[str] = None, - ) -> int: + id_access_token: str, + prev_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, + ) -> Tuple[str, int]: """Invite a 3PID to a room. Args: @@ -1334,16 +1397,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): requester: The user making the request. txn_id: The transaction ID this is part of, or None if this is not part of a transaction. - id_access_token: The optional identity server access token. + id_access_token: Identity server access token. + prev_event_ids: The event IDs to use as the prev events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. Returns: - The new stream ID. + Tuple of event ID and stream ordering position Raises: ShadowBanError if the requester has been shadow-banned.
""" if self.config.server.block_non_admin_invites: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -1383,7 +1450,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We don't check the invite against the spamchecker(s) here (through # user_may_invite) because we'll do it further down the line anyway (in # update_membership_locked). - _, stream_id = await self.update_membership( + event_id, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: @@ -1402,7 +1469,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): additional_fields=spam_check[1], ) - stream_id = await self._make_and_store_3pid_invite( + event, stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -1411,9 +1478,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): inviter, txn_id=txn_id, id_access_token=id_access_token, + prev_event_ids=prev_event_ids, + depth=depth, ) + event_id = event.event_id - return stream_id + return event_id, stream_id async def _make_and_store_3pid_invite( self, @@ -1424,8 +1494,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: str, user: UserID, txn_id: Optional[str], - id_access_token: Optional[str] = None, - ) -> int: + id_access_token: str, + prev_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, + ) -> Tuple[EventBase, int]: room_state = await self._storage_controllers.state.get_current_state( room_id, StateFilter.from_types( @@ -1518,8 +1590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): }, ratelimit=False, txn_id=txn_id, + prev_event_ids=prev_event_ids, + depth=depth, ) - return stream_id + return event, stream_id async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: # Have we just created the room, and is 
this about to be the very @@ -1543,12 +1617,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return False - async def _is_server_notice_room(self, room_id: str) -> bool: - if self._server_notices_mxid is None: - return False - user_ids = await self.store.get_users_in_room(room_id) - return self._server_notices_mxid in user_ids - class RoomMemberMasterHandler(RoomMemberHandler): def __init__(self, hs: "HomeServer"): @@ -1608,14 +1676,18 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise SynapseError( + 404, + "Can't join remote room because no servers " + "that are in the room have been provided.", + ) check_complexity = self.hs.config.server.limit_remote_rooms.enabled if ( check_complexity and self.hs.config.server.limit_remote_rooms.admins_can_join ): - check_complexity = not await self.auth.is_server_admin(user) + check_complexity = not await self.store.is_server_admin(user) if check_complexity: # Fetch the room complexity @@ -1845,8 +1917,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): ]: raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) - if membership: - await self.store.forget(user_id, room_id) + # In normal case this call is only required if `membership` is not `None`. + # But: After the last member had left the room, the background update + # `_background_remove_left_rooms` is deleting rows related to this room from + # the table `current_state_events` and `get_current_state_events` is `None`. 
+ await self.store.forget(user_id, room_id) def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 13098f56ed..ebd445adca 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -28,11 +28,11 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.errors import ( - AuthError, Codes, NotFoundError, StoreError, SynapseError, + UnstableSpecAuthError, UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter @@ -175,10 +175,11 @@ class RoomSummaryHandler: # First of all, check that the room is accessible. if not await self._is_local_room_accessible(requested_room_id, requester): - raise AuthError( + raise UnstableSpecAuthError( 403, "User %s not in room %s, and room previews are disabled" % (requester, requested_room_id), + errcode=Codes.NOT_JOINED, ) # If this is continuing a previous session, pull the persisted data. 
@@ -452,7 +453,6 @@ class RoomSummaryHandler: "type": e.type, "state_key": e.state_key, "content": e.content, - "room_id": e.room_id, "sender": e.sender, "origin_server_ts": e.origin_server_ts, } diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index a305a66860..e2844799e8 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -23,10 +23,12 @@ from pkg_resources import parse_version import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP +from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.ssl import optionsForClientTLS from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable +from synapse.types import ISynapseReactor if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender): async def _sendmail( - reactor: IReactorTCP, + reactor: ISynapseReactor, smtphost: str, smtpport: int, from_addr: str, @@ -59,6 +61,7 @@ async def _sendmail( require_auth: bool = False, require_tls: bool = False, enable_tls: bool = True, + force_tls: bool = False, ) -> None: """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests @@ -73,8 +76,9 @@ async def _sendmail( password: password to give when authenticating require_auth: if auth is not offered, fail the request require_tls: if TLS is not offered, fail the reqest - enable_tls: True to enable TLS. If this is False and require_tls is True, + enable_tls: True to enable STARTTLS. If this is False and require_tls is True, the request will fail. + force_tls: True to enable Implicit TLS. """ msg = BytesIO(msg_bytes) d: "Deferred[object]" = Deferred() @@ -105,13 +109,23 @@ async def _sendmail( # set to enable TLS. 
factory = build_sender_factory(hostname=smtphost if enable_tls else None) - reactor.connectTCP( - smtphost, - smtpport, - factory, - timeout=30, - bindAddress=None, - ) + if force_tls: + reactor.connectSSL( + smtphost, + smtpport, + factory, + optionsForClientTLS(smtphost), + timeout=30, + bindAddress=None, + ) + else: + reactor.connectTCP( + smtphost, + smtpport, + factory, + timeout=30, + bindAddress=None, + ) await make_deferred_yieldable(d) @@ -132,6 +146,7 @@ class SendEmailHandler: self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None self._require_transport_security = hs.config.email.require_transport_security self._enable_tls = hs.config.email.enable_smtp_tls + self._force_tls = hs.config.email.force_tls self._sendmail = _sendmail @@ -189,4 +204,5 @@ class SendEmailHandler: require_auth=self._smtp_user is not None, require_tls=self._require_transport_security, enable_tls=self._enable_tls, + force_tls=self._force_tls, ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d42a414c90..5293fa4d0e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,7 +13,20 @@ # limitations under the License. import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Collection, + Dict, + FrozenSet, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, +) import attr from prometheus_client import Counter @@ -89,7 +102,7 @@ class SyncConfig: @attr.s(slots=True, frozen=True, auto_attribs=True) class TimelineBatch: prev_batch: StreamToken - events: List[EventBase] + events: Sequence[EventBase] limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. 
@@ -507,10 +520,17 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `recents`, so partial state is only a problem when a membership + # event turns up in `recents` but has not made it into the current + # state. current_state_ids_map = ( - await self._state_storage_controller.get_current_state_ids( - room_id - ) + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -579,7 +599,13 @@ class SyncHandler: if any(e.is_state() for e in loaded_recents): # FIXME(faster_joins): We use the partial state here as # we don't want to block `/sync` on finishing a lazy join. - # Is this the correct way of doing it? + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `loaded_recents`, so partial state is only a problem when a + # membership event turns up in `loaded_recents` but has not made it + # into the current state. 
current_state_ids_map = ( await self.store.get_partial_current_state_ids(room_id) ) @@ -627,7 +653,10 @@ class SyncHandler: ) async def get_state_after_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the room state after the given event @@ -635,9 +664,14 @@ class SyncHandler: Args: event_id: event of interest state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. """ state_ids = await self._state_storage_controller.get_state_ids_for_event( - event_id, state_filter=state_filter or StateFilter.all() + event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) # using get_metadata_for_events here (instead of get_event) sidesteps an issue @@ -660,6 +694,7 @@ class SyncHandler: room_id: str, stream_position: StreamToken, state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -667,6 +702,9 @@ class SyncHandler: room_id: room for which to get state stream_position: point at which to get state state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the last event in the room before `stream_position` and + `state_filter` is not satisfied by partial state. Defaults to `True`. 
""" # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time @@ -678,7 +716,9 @@ class SyncHandler: if last_event_id: state = await self.get_state_after_event( - last_event_id, state_filter=state_filter or StateFilter.all() + last_event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) else: @@ -852,16 +892,26 @@ class SyncHandler: now_token: StreamToken, full_state: bool, ) -> MutableStateMap[EventBase]: - """Works out the difference in state between the start of the timeline - and the previous sync. + """Works out the difference in state between the end of the previous sync and + the start of the timeline. Args: room_id: batch: The timeline batch for the room that will be sent to the user. sync_config: - since_token: Token of the end of the previous batch. May be None. + since_token: Token of the end of the previous batch. May be `None`. now_token: Token of the end of the current batch. full_state: Whether to force returning the full state. + `lazy_load_members` still applies when `full_state` is `True`. + + Returns: + The state to return in the sync response for the room. + + Clients will overlay this onto the state at the end of the previous sync to + arrive at the state at the start of the timeline. + + Clients will then overlay state events in the timeline to arrive at the + state at the end of the timeline, in preparation for the next sync. """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -869,8 +919,17 @@ class SyncHandler: # TODO(mjark) Check for new redactions in the state events. with Measure(self.clock, "compute_state_delta"): + # The memberships needed for events in the timeline. + # Only calculated when `lazy_load_members` is on. 
+ members_to_fetch: Optional[Set[str]] = None + + # A dictionary mapping user IDs to the first event in the timeline sent by + # them. Only calculated when `lazy_load_members` is on. + first_event_by_sender_map: Optional[Dict[str, EventBase]] = None - members_to_fetch = None + # The contribution to the room state from state events in the timeline. + # Only contains the last event for any given state key. + timeline_state: StateMap[str] lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -881,10 +940,23 @@ class SyncHandler: # We only request state for the members needed to display the # timeline: - members_to_fetch = { - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - } + timeline_state = {} + + members_to_fetch = set() + first_event_by_sender_map = {} + for event in batch.events: + # Build the map from user IDs to the first timeline event they sent. + if event.sender not in first_event_by_sender_map: + first_event_by_sender_map[event.sender] = event + + # We need the event's sender, unless their membership was in a + # previous timeline event. + if (EventTypes.Member, event.sender) not in timeline_state: + members_to_fetch.add(event.sender) + # FIXME: we also care about invite targets etc. + + if event.is_state(): + timeline_state[(event.type, event.state_key)] = event.event_id if full_state: # always make sure we LL ourselves so we know we're in the room @@ -894,55 +966,80 @@ class SyncHandler: members_to_fetch.add(sync_config.user.to_string()) state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + + # We are happy to use partial state to compute the `/sync` response. + # Since partial state may not include the lazy-loaded memberships we + # require, we fix up the state response afterwards with memberships from + # auth events. 
+ await_full_state = False else: + timeline_state = { + (event.type, event.state_key): event.event_id + for event in batch.events + if event.is_state() + } + state_filter = StateFilter.all() + await_full_state = True - timeline_state = { - (event.type, event.state_key): event.event_id - for event in batch.events - if event.is_state() - } + # Now calculate the state to return in the sync response for the room. + # This is more or less the change in state between the end of the previous + # sync's timeline and the start of the current sync's timeline. + # See the docstring above for details. + state_ids: StateMap[str] if full_state: if batch: - current_state_ids = ( + state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) - state_ids = ( + state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: - current_state_ids = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) - state_ids = current_state_ids + state_at_timeline_start = state_at_timeline_end state_ids = _calculate_state( timeline_contains=timeline_state, - timeline_start=state_ids, - previous={}, - current=current_state_ids, + timeline_start=state_at_timeline_start, + timeline_end=state_at_timeline_end, + previous_timeline_end={}, lazy_load_members=lazy_load_members, ) elif batch.limited: if batch: state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + 
batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_start = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) # for now, we disable LL for gappy syncs - see @@ -964,28 +1061,35 @@ class SyncHandler: # is indeed the case. assert since_token is not None state_at_previous_sync = await self.get_state_at( - room_id, stream_position=since_token, state_filter=state_filter + room_id, + stream_position=since_token, + state_filter=state_filter, + await_full_state=await_full_state, ) if batch: - current_state_ids = ( + state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: - # Its not clear how we get here, but empirically we do - # (#5407). Logging has been added elsewhere to try and - # figure out where this state comes from. - current_state_ids = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + # We can get here if the user has ignored the senders of all + # the recent events. 
+ state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, - previous=state_at_previous_sync, - current=current_state_ids, + timeline_end=state_at_timeline_end, + previous_timeline_end=state_at_previous_sync, # we have to include LL members in case LL initial sync missed them lazy_load_members=lazy_load_members, ) @@ -1008,8 +1112,30 @@ class SyncHandler: (EventTypes.Member, member) for member in members_to_fetch ), + await_full_state=False, ) + # If we only have partial state for the room, `state_ids` may be missing the + # memberships we wanted. We attempt to find some by digging through the auth + # events of timeline events. + if lazy_load_members and await self.store.is_partial_state_room(room_id): + assert members_to_fetch is not None + assert first_event_by_sender_map is not None + + additional_state_ids = ( + await self._find_missing_partial_state_memberships( + room_id, members_to_fetch, first_event_by_sender_map, state_ids + ) + ) + state_ids = {**state_ids, **additional_state_ids} + + # At this point, if `lazy_load_members` is enabled, `state_ids` includes + # the memberships of all event senders in the timeline. This is because we + # may not have sent the memberships in a previous sync. + + # When `include_redundant_members` is on, we send all the lazy-loaded + # memberships of event senders. Otherwise we make an effort to limit the set + # of memberships we send to those that we have not already sent to this client. 
if lazy_load_members and not include_redundant_members: cache_key = (sync_config.user.to_string(), sync_config.device_id) cache = self.get_lazy_loaded_members_cache(cache_key) @@ -1051,6 +1177,99 @@ class SyncHandler: if e.type != EventTypes.Aliases # until MSC2261 or alternative solution } + async def _find_missing_partial_state_memberships( + self, + room_id: str, + members_to_fetch: Collection[str], + events_with_membership_auth: Mapping[str, EventBase], + found_state_ids: StateMap[str], + ) -> StateMap[str]: + """Finds missing memberships from a set of auth events and returns them as a + state map. + + Args: + room_id: The partial state room to find the remaining memberships for. + members_to_fetch: The memberships to find. + events_with_membership_auth: A mapping from user IDs to events whose auth + events are known to contain their membership. + found_state_ids: A dict from (type, state_key) -> state_event_id, containing + memberships that have been previously found. Entries in + `members_to_fetch` that have a membership in `found_state_ids` are + ignored. + + Returns: + A dict from ("m.room.member", state_key) -> state_event_id, containing the + memberships missing from `found_state_ids`. + + Raises: + KeyError: if `events_with_membership_auth` does not have an entry for a + missing membership. Memberships in `found_state_ids` do not need an + entry in `events_with_membership_auth`. + """ + additional_state_ids: MutableStateMap[str] = {} + + # Tracks the missing members for logging purposes. + missing_members = set() + + # Identify memberships missing from `found_state_ids` and pick out the auth + # events in which to look for them. 
+ auth_event_ids: Set[str] = set() + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + missing_members.add(member) + event_with_membership_auth = events_with_membership_auth[member] + auth_event_ids.update(event_with_membership_auth.auth_event_ids()) + + auth_events = await self.store.get_events(auth_event_ids) + + # Run through the missing memberships once more, picking out the memberships + # from the pile of auth events we have just fetched. + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + event_with_membership_auth = events_with_membership_auth[member] + + # Dig through the auth events to find the desired membership. + for auth_event_id in event_with_membership_auth.auth_event_ids(): + # We only store events once we have all their auth events, + # so the auth event must be in the pile we have just + # fetched. + auth_event = auth_events[auth_event_id] + + if ( + auth_event.type == EventTypes.Member + and auth_event.state_key == member + ): + missing_members.remove(member) + additional_state_ids[ + (EventTypes.Member, member) + ] = auth_event.event_id + break + + if missing_members: + # There really shouldn't be any missing memberships now. Either: + # * we couldn't find an auth event, which shouldn't happen because we do + # not persist events without persisting their auth events first, or + # * the set of auth events did not contain a membership we wanted, which + # means our caller didn't compute the events in `members_to_fetch` + # correctly, or we somehow accepted an event whose auth events were + # dodgy.
+ logger.error( + "Failed to find memberships for %s in partial state room " + "%s in the auth events of %s.", + missing_members, + room_id, + [ + events_with_membership_auth[member].event_id + for member in missing_members + ], + ) + + return additional_state_ids + async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig ) -> NotifCounts: @@ -1195,10 +1414,10 @@ class SyncHandler: async def _generate_sync_entry_for_device_list( self, sync_result_builder: "SyncResultBuilder", - newly_joined_rooms: Set[str], - newly_joined_or_invited_or_knocked_users: Set[str], - newly_left_rooms: Set[str], - newly_left_users: Set[str], + newly_joined_rooms: AbstractSet[str], + newly_joined_or_invited_or_knocked_users: AbstractSet[str], + newly_left_rooms: AbstractSet[str], + newly_left_users: AbstractSet[str], ) -> DeviceListUpdates: """Generate the DeviceListUpdates section of sync @@ -1216,8 +1435,7 @@ class SyncHandler: user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token - # We're going to mutate these fields, so lets copy them rather than - # assume they won't get used later. + # Take a copy since these fields will be mutated later. newly_joined_or_invited_or_knocked_users = set( newly_joined_or_invited_or_knocked_users ) @@ -1417,8 +1635,8 @@ class SyncHandler: async def _generate_sync_entry_for_presence( self, sync_result_builder: "SyncResultBuilder", - newly_joined_rooms: Set[str], - newly_joined_or_invited_users: Set[str], + newly_joined_rooms: AbstractSet[str], + newly_joined_or_invited_users: AbstractSet[str], ) -> None: """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. 
@@ -1476,7 +1694,7 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", account_data_by_room: Dict[str, Dict[str, JsonDict]], - ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]: + ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1536,15 +1754,13 @@ class SyncHandler: ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( - sync_result_builder, ignored_users, self.rooms_to_exclude + sync_result_builder, ignored_users ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - room_changes = await self._get_all_rooms( - sync_result_builder, ignored_users, self.rooms_to_exclude - ) + room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1623,13 +1839,14 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str], - excluded_rooms: List[str], ) -> _RoomChanges: """Determine the changes in rooms to report to the user. This function is a first pass at generating the rooms part of the sync response. It determines which rooms have changed during the sync period, and categorises - them into four buckets: "knock", "invite", "join" and "leave". + them into four buckets: "knock", "invite", "join" and "leave". It also excludes + from that list any room that appears in the list of rooms to exclude from sync + results in the server configuration. 1. Finds all membership changes for the user in the sync period (from `since_token` up to `now_token`). @@ -1655,7 +1872,7 @@ class SyncHandler: # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. 
membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key, excluded_rooms + user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} @@ -1696,7 +1913,11 @@ class SyncHandler: continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]), + ) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: @@ -1722,7 +1943,13 @@ class SyncHandler: newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types( + [(EventTypes.Member, user_id)] + ), + ) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) @@ -1862,7 +2089,6 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str], - ignored_rooms: List[str], ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1884,7 +2110,7 @@ class SyncHandler: room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=Membership.LIST, - excluded_rooms=ignored_rooms, + excluded_rooms=self.rooms_to_exclude, ) room_entries = [] @@ -2150,7 +2376,9 @@ class SyncHandler: raise Exception("Unrecognized rtype: %r", room_builder.rtype) async def get_rooms_for_user_at( - self, user_id: str, room_key: RoomStreamToken + self, + user_id: str, + room_key: RoomStreamToken, ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. 
@@ -2176,7 +2404,12 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. + # We also need to check whether the room should be excluded from sync + # responses as per the homeserver config. for joined_room in joined_rooms: + if joined_room.room_id in self.rooms_to_exclude: + continue + if not joined_room.event_pos.persisted_after(room_key): joined_room_ids.add(joined_room.room_id) continue @@ -2188,10 +2421,10 @@ class SyncHandler: joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room( + user_ids_in_room = await self.state.get_current_user_ids_in_room( joined_room.room_id, extrems ) - if user_id in users_in_room: + if user_id in user_ids_in_room: joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) @@ -2211,8 +2444,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool: def _calculate_state( timeline_contains: StateMap[str], timeline_start: StateMap[str], - previous: StateMap[str], - current: StateMap[str], + timeline_end: StateMap[str], + previous_timeline_end: StateMap[str], lazy_load_members: bool, ) -> StateMap[str]: """Works out what state to include in a sync response. @@ -2220,45 +2453,50 @@ def _calculate_state( Args: timeline_contains: state in the timeline timeline_start: state at the start of the timeline - previous: state at the end of the previous sync (or empty dict + timeline_end: state at the end of the timeline + previous_timeline_end: state at the end of the previous sync (or empty dict if this is an initial sync) - current: state at the end of the timeline lazy_load_members: whether to return members from timeline_start or not. assumes that timeline_start has already been filtered to include only the members the client needs to know about. 
""" - event_id_to_key = { - e: key - for key, e in itertools.chain( + event_id_to_state_key = { + event_id: state_key + for state_key, event_id in itertools.chain( timeline_contains.items(), - previous.items(), timeline_start.items(), - current.items(), + timeline_end.items(), + previous_timeline_end.items(), ) } - c_ids = set(current.values()) - ts_ids = set(timeline_start.values()) - p_ids = set(previous.values()) - tc_ids = set(timeline_contains.values()) + timeline_end_ids = set(timeline_end.values()) + timeline_start_ids = set(timeline_start.values()) + previous_timeline_end_ids = set(previous_timeline_end.values()) + timeline_contains_ids = set(timeline_contains.values()) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, # as we may not have sent them to the client before. We find these membership # events by filtering them out of timeline_start, which has already been filtered # to only include membership events for the senders in the timeline. - # In practice, we can do this by removing them from the p_ids list, - # which is the list of relevant state we know we have already sent to the client. + # In practice, we can do this by removing them from the previous_timeline_end_ids + # list, which is the list of relevant state we know we have already sent to the + # client. 
# see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 if lazy_load_members: - p_ids.difference_update( + previous_timeline_end_ids.difference_update( e for t, e in timeline_start.items() if t[0] == EventTypes.Member ) - state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids + state_ids = ( + (timeline_end_ids | timeline_start_ids) + - previous_timeline_end_ids + - timeline_contains_ids + ) - return {event_id_to_key[e]: e for e in state_ids} + return {event_id_to_state_key[e]: e for e in state_ids} @attr.s(slots=True, auto_attribs=True) @@ -2296,7 +2534,7 @@ class SyncResultBuilder: archived: List[ArchivedSyncResult] = attr.Factory(list) to_device: List[JsonDict] = attr.Factory(list) - def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: + def calculate_user_changes(self) -> Tuple[AbstractSet[str], AbstractSet[str]]: """Work out which other users have joined or left rooms we are joined to. This data only is only useful for an incremental sync. 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d104ea07fe..a4cd8b8f0c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, StreamKeyType, UserID from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler): self, target_user: UserID, requester: Requester, room_id: str, timeout: int ) -> None: target_user_id = target_user.to_string() - auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") - if target_user_id != auth_user_id: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's typing state") if requester.shadow_banned: @@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler): await self.clock.sleep(random.randint(1, 10)) raise ShadowBanError() - await self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, requester) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler): self, target_user: UserID, requester: Requester, room_id: str ) -> None: target_user_id = target_user.to_string() - auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") - if target_user_id != auth_user_id: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's typing state") if 
requester.shadow_banned: @@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler): await self.clock.sleep(random.randint(1, 10)) raise ShadowBanError() - await self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, requester) logger.debug("%s has stopped typing in %s", target_user_id, room_id) @@ -364,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler): ) return - users = await self.store.get_users_in_room(room_id) - domains = {get_domain_from_id(u) for u in users} + domains = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) @@ -489,8 +488,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): handler = self.get_typing_handler() events = [] - for room_id in handler._room_serials.keys(): - if handler._room_serials[room_id] <= from_key: + + # Work on a copy of things here as these may change in the handler while + # waiting for the AS `is_interested_in_room` call to complete. + # Shallow copy is safe as no nested data is present. 
+ latest_room_serial = handler._latest_room_serial + room_serials = handler._room_serials.copy() + + for room_id, serial in room_serials.items(): + if serial <= from_key: continue if not await service.is_interested_in_room(room_id, self._main_store): @@ -498,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): events.append(self._make_event_for(room_id)) - return events, handler._latest_room_serial + return events, latest_room_serial async def get_new_events( self, diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 05cebb5d4d..a744d68c64 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.util import json_decoder if TYPE_CHECKING: @@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker: logger.info("Getting validated threepid. 
threepidcreds: %r", (threepid_creds,)) - # msisdns are currently always ThreepidBehaviour.REMOTE + # msisdns are currently always verified via the IS if medium == "msisdn": if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( @@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker: threepid_creds, ) elif medium == "email": - if ( - self.hs.config.email.threepid_behaviour_email - == ThreepidBehaviour.REMOTE - ): - assert self.hs.config.registration.account_threepid_delegate_email - threepid = await identity_handler.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_email, - threepid_creds, - ) - elif ( - self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): + if self.hs.config.email.can_verify_email: threepid = None row = await self.store.get_threepid_validation_session( medium, @@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return self.hs.config.email.threepid_behaviour_email in ( - ThreepidBehaviour.REMOTE, - ThreepidBehaviour.LOCAL, - ) + return self.hs.config.email.can_verify_email async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("email", authdict) |