diff options
Diffstat (limited to 'synapse')
65 files changed, 1835 insertions, 1178 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 65c1f5aa3f..f2d3ac68eb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.23.0" +__version__ = "1.24.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index da0996edbc..dfe26dea6d 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -37,7 +37,7 @@ def request_registration( exit=sys.exit, ): - url = "%s/_matrix/client/r0/admin/register" % (server_location,) + url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) # Get the nonce r = requests.get(url, verify=False) diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index d8fafd7cb8..9c227218e0 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Optional from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -33,24 +35,47 @@ class AuthBlocking: self._max_mau_value = hs.config.max_mau_value self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + self._server_name = hs.hostname - async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking( + self, + user_id: Optional[str] = None, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + requester: Optional[Requester] = None, + ): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag Args: - user_id(str|None): If present, checks for presence against existing + user_id: If present, checks for presence against existing MAU cohort - threepid(dict|None): If present, checks for presence against configured + threepid: If present, checks for presence against configured reserved threepid. Used in cases where the user is trying register with a MAU blocked server, normally they would be rejected but their threepid is on the reserved list. user_id and threepid should never be set at the same time. - user_type(str|None): If present, is used to decide whether to check against + user_type: If present, is used to decide whether to check against certain blocking reasons like MAU. + + requester: If present, and the authenticated entity is a user, checks for + presence against existing MAU cohort. Passing in both a `user_id` and + `requester` is an error. """ + if requester and user_id: + raise Exception( + "Passed in both 'user_id' and 'requester' to 'check_auth_blocking'" + ) + + if requester: + if requester.authenticated_entity.startswith("@"): + user_id = requester.authenticated_entity + elif requester.authenticated_entity == self._server_name: + # We never block the server from doing actions on behalf of + # users. + return # Never fail an auth check for the server notices users or support user # This can be a problem where event creation is prohibited due to blocking diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9c8dc785c6..895b38ae76 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -32,6 +32,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.server import ListenerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.util.async_helpers import Linearizer from synapse.util.daemonize import daemonize_process from synapse.util.rlimit import change_resource_limit @@ -244,6 +245,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): + @wrap_as_background_process("sighup") def handle_sighup(*args, **kwargs): # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. @@ -254,7 +256,13 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): sdnotify(b"READY=1") - signal.signal(signal.SIGHUP, handle_sighup) + # We defer running the sighup handlers until next reactor tick. This + # is so that we're in a sane state, e.g. flushing the logs may fail + # if the sighup happens in the middle of writing a log entry. + def run_sighup(*args, **kwargs): + hs.get_clock().call_later(0, handle_sighup, *args, **kwargs) + + signal.signal(signal.SIGHUP, run_sighup) register_sighup(refresh_certificate, hs) diff --git a/synapse/config/push.py b/synapse/config/push.py index a1f3752c8a..3adbfb73e6 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -21,8 +21,11 @@ class PushConfig(Config): section = "push" def read_config(self, config, **kwargs): - push_config = config.get("push", {}) + push_config = config.get("push") or {} self.push_include_content = push_config.get("include_content", True) + self.push_group_unread_count_by_room = push_config.get( + "group_unread_count_by_room", True + ) pusher_instances = config.get("pusher_instances") or [] self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) @@ -49,18 +52,33 @@ class PushConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ - # Clients requesting push notifications can either have the body of - # the message sent in the notification poke along with other details - # like the sender, or just the event ID and room ID (`event_id_only`). - # If clients choose the former, this option controls whether the - # notification request includes the content of the event (other details - # like the sender are still included). For `event_id_only` push, it - # has no effect. - # - # For modern android devices the notification content will still appear - # because it is loaded by the app. iPhone, however will send a - # notification saying only that a message arrived and who it came from. - # - #push: - # include_content: true + ## Push ## + + push: + # Clients requesting push notifications can either have the body of + # the message sent in the notification poke along with other details + # like the sender, or just the event ID and room ID (`event_id_only`). + # If clients choose the former, this option controls whether the + # notification request includes the content of the event (other details + # like the sender are still included). For `event_id_only` push, it + # has no effect. + # + # For modern android devices the notification content will still appear + # because it is loaded by the app. iPhone, however will send a + # notification saying only that a message arrived and who it came from. + # + # The default value is "true" to include message details. Uncomment to only + # include the event ID and room ID in push notification payloads. + # + #include_content: false + + # When a push notification is received, an unread count is also sent. + # This number can either be calculated as the number of unread messages + # for the user, or the number of *rooms* the user has unread messages in. + # + # The default value is "true", meaning push clients will see the number of + # rooms with unread messages in them. Uncomment to instead send the number + # of unread messages. + # + #group_unread_count_by_room: false """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py index b0a77a2e43..cc5f75123c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -347,8 +347,9 @@ class RegistrationConfig(Config): # email will be globally disabled. # # Additionally, if `msisdn` is not set, registration and password resets via msisdn - # will be disabled regardless. This is due to Synapse currently not supporting any - # method of sending SMS messages on its own. + # will be disabled regardless, and users will not be able to associate an msisdn + # identifier to their account. This is due to Synapse currently not supporting + # any method of sending SMS messages on its own. # # To enable using an identity server for operations regarding a particular third-party # identifier type, set the value to the URL of that identity server as shown in the diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 2ff7dfb311..c1b8e98ae0 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -90,6 +90,8 @@ class SAML2Config(Config): "grandfathered_mxid_source_attribute", "uid" ) + self.saml2_idp_entityid = saml2_config.get("idp_entityid", None) + # user_mapping_provider may be None if the key is present but has no value ump_dict = saml2_config.get("user_mapping_provider") or {} @@ -256,6 +258,12 @@ class SAML2Config(Config): # remote: # - url: https://our_idp/metadata.xml + # Allowed clock difference in seconds between the homeserver and IdP. + # + # Uncomment the below to increase the accepted time difference from 0 to 3 seconds. + # + #accepted_time_diff: 3 + # By default, the user has to go to our login page first. If you'd like # to allow IdP-initiated login, set 'allow_unsolicited: true' in a # 'service.sp' section: @@ -377,6 +385,14 @@ class SAML2Config(Config): # value: "staff" # - attribute: department # value: "sales" + + # If the metadata XML contains multiple IdP entities then the `idp_entityid` + # option must be set to the entity to redirect users to. + # + # Most deployments only have a single IdP entity and so should omit this + # option. + # + #idp_entityid: 'https://our_idp/entityid' """ % { "config_dir_path": config_dir_path } diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index bad18f7fdf..936896656a 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,13 +15,12 @@ # limitations under the License. import inspect -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection -MYPY = False -if MYPY: +if TYPE_CHECKING: import synapse.events import synapse.server diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 23278e36b7..4b6ab470d0 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -391,7 +392,7 @@ class FederationServer(FederationBase): TRANSACTION_CONCURRENCY_LIMIT, ) - async def on_context_state_request( + async def on_room_state_request( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) @@ -514,11 +515,12 @@ class FederationServer(FederationBase): return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict, room_id: str + self, origin: str, content: JsonDict ) -> Dict[str, Any]: logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -547,12 +549,11 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request( - self, origin: str, content: JsonDict, room_id: str - ) -> dict: + async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict: logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -748,12 +749,8 @@ class FederationServer(FederationBase): ) return ret - async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: Dict - ): - ret = await self.handler.on_exchange_third_party_invite_request( - room_id, event_dict - ) + async def on_exchange_third_party_invite_request(self, event_dict: Dict): + ret = await self.handler.on_exchange_third_party_invite_request(event_dict) return ret async def check_server_matches_acl(self, server_name: str, room_id: str): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index a0933fae88..b53e7a20ec 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -440,13 +440,13 @@ class FederationEventServlet(BaseFederationServlet): class FederationStateV1Servlet(BaseFederationServlet): - PATH = "/state/(?P<context>[^/]*)/?" + PATH = "/state/(?P<room_id>[^/]*)/?" - # This is when someone asks for all data for a given context. - async def on_GET(self, origin, content, query, context): - return await self.handler.on_context_state_request( + # This is when someone asks for all data for a given room. + async def on_GET(self, origin, content, query, room_id): + return await self.handler.on_room_state_request( origin, - context, + room_id, parse_string_from_args(query, "event_id", None, required=False), ) @@ -463,16 +463,16 @@ class FederationStateIdsServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet): - PATH = "/backfill/(?P<context>[^/]*)/?" + PATH = "/backfill/(?P<room_id>[^/]*)/?" - async def on_GET(self, origin, content, query, context): + async def on_GET(self, origin, content, query, room_id): versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: return 400, {"error": "Did not include limit param"} - return await self.handler.on_backfill_request(origin, context, versions, limit) + return await self.handler.on_backfill_request(origin, room_id, versions, limit) class FederationQueryServlet(BaseFederationServlet): @@ -487,9 +487,9 @@ class FederationQueryServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet): - PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" + PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, _content, query, context, user_id): + async def on_GET(self, origin, _content, query, room_id, user_id): """ Args: origin (unicode): The authenticated server_name of the calling server @@ -511,16 +511,16 @@ class FederationMakeJoinServlet(BaseFederationServlet): supported_versions = ["1"] content = await self.handler.on_make_join_request( - origin, context, user_id, supported_versions=supported_versions + origin, room_id, user_id, supported_versions=supported_versions ) return 200, content class FederationMakeLeaveServlet(BaseFederationServlet): - PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)" + PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, content, query, context, user_id): - content = await self.handler.on_make_leave_request(origin, context, user_id) + async def on_GET(self, origin, content, query, room_id, user_id): + content = await self.handler.on_make_leave_request(origin, room_id, user_id) return 200, content @@ -528,7 +528,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) + content = await self.handler.on_send_leave_request(origin, content) return 200, (200, content) @@ -538,43 +538,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) + content = await self.handler.on_send_leave_request(origin, content) return 200, content class FederationEventAuthServlet(BaseFederationServlet): - PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_GET(self, origin, content, query, context, event_id): - return await self.handler.on_event_auth(origin, context, event_id) + async def on_GET(self, origin, content, query, room_id, event_id): + return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServlet): - PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, context) + content = await self.handler.on_send_join_request(origin, content) return 200, (200, content) class FederationV2SendJoinServlet(BaseFederationServlet): - PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, context) + content = await self.handler.on_send_join_request(origin, content) return 200, content class FederationV1InviteServlet(BaseFederationServlet): - PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, context, event_id): + async def on_PUT(self, origin, content, query, room_id, event_id): # We don't get a room version, so we have to assume its EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing @@ -589,12 +589,12 @@ class FederationV1InviteServlet(BaseFederationServlet): class FederationV2InviteServlet(BaseFederationServlet): - PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content room_version = content["room_version"] @@ -616,9 +616,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id): - content = await self.handler.on_exchange_third_party_invite_request( - room_id, content - ) + content = await self.handler.on_exchange_third_party_invite_request(content) return 200, content diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index bd8e71ae56..bb81c0e81d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -169,7 +169,9 @@ class BaseHandler: # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = synapse.types.create_requester(target_user, is_guest=True) + requester = synapse.types.create_requester( + target_user, is_guest=True, authenticated_entity=self.server_name + ) handler = self.hs.get_room_member_handler() await handler.update_membership( requester, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9fc8444228..5c6458eb52 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -226,7 +226,7 @@ class ApplicationServicesHandler: new_token: Optional[int], users: Collection[Union[str, UserID]], ): - logger.info("Checking interested services for %s" % (stream_key)) + logger.debug("Checking interested services for %s" % (stream_key)) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: # Only handle typing if we have the latest token diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 213baea2e3..c7dc07008a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +26,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Tuple, Union, @@ -181,17 +183,12 @@ class AuthHandler(BaseHandler): # better way to break the loop account_handler = ModuleApi(hs, self) - self.password_providers = [] - for module, config in hs.config.password_providers: - try: - self.password_providers.append( - module(config=config, account_handler=account_handler) - ) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise + self.password_providers = [ + PasswordProvider.load(module, config, account_handler) + for module, config in hs.config.password_providers + ] - logger.info("Extra password_providers: %r", self.password_providers) + logger.info("Extra password_providers: %s", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() @@ -205,15 +202,23 @@ class AuthHandler(BaseHandler): # type in the list. (NB that the spec doesn't require us to do so and # clients which favour types that they don't understand over those that # they do are technically broken) + + # start out by assuming PASSWORD is enabled; we will remove it later if not. login_types = [] - if self._password_enabled: + if hs.config.password_localdb_enabled: login_types.append(LoginType.PASSWORD) + for provider in self.password_providers: if hasattr(provider, "get_supported_login_types"): for t in provider.get_supported_login_types().keys(): if t not in login_types: login_types.append(t) + + if not self._password_enabled: + login_types.remove(LoginType.PASSWORD) + self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not # necessarily identical. Login types have SSO (and other login types) # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. @@ -230,6 +235,13 @@ class AuthHandler(BaseHandler): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) + # Ratelimitier for failed /login attempts + self._failed_login_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) + self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. @@ -642,14 +654,8 @@ class AuthHandler(BaseHandler): res = await checker.check_auth(authdict, clientip=clientip) return res - # build a v1-login-style dict out of the authdict and fall back to the - # v1 code - user_id = authdict.get("user") - - if user_id is None: - raise SynapseError(400, "", Codes.MISSING_PARAM) - - (canonical_id, callback) = await self.validate_login(user_id, authdict) + # fall back to the v1 login flow + canonical_id, _ = await self.validate_login(authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -698,8 +704,12 @@ class AuthHandler(BaseHandler): } async def get_access_token_for_user_id( - self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] - ): + self, + user_id: str, + device_id: Optional[str], + valid_until_ms: Optional[int], + puppets_user_id: Optional[str] = None, + ) -> str: """ Creates a new access token for the user with the given user ID. @@ -725,13 +735,25 @@ class AuthHandler(BaseHandler): fmt_expiry = time.strftime( " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0) ) - logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) + + if puppets_user_id: + logger.info( + "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry + ) + else: + logger.info( + "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry + ) await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) await self.store.add_access_token_to_user( - user_id, access_token, device_id, valid_until_ms + user_id=user_id, + token=access_token, + device_id=device_id, + valid_until_ms=valid_until_ms, + puppets_user_id=puppets_user_id, ) # the device *should* have been registered before we got here; however, @@ -808,17 +830,17 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, username: str, login_submission: Dict[str, Any] + self, login_submission: Dict[str, Any], ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API - Also used by the user-interactive auth flow to validate - m.login.password auth types. + Also used by the user-interactive auth flow to validate auth types which don't + have an explicit UIA handler, including m.password.auth. Args: - username: username supplied by the user login_submission: the whole of the login submission (including 'type' and other relevant fields) + ratelimit: whether to apply the failed_login_attempt ratelimiter Returns: A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued @@ -827,38 +849,160 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - - if username.startswith("@"): - qualified_user_id = username - else: - qualified_user_id = UserID(username, self.hs.hostname).to_string() - login_type = login_submission.get("type") - known_login_type = False + if not isinstance(login_type, str): + raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM) + + # ideally, we wouldn't be checking the identifier unless we know we have a login + # method which uses it (https://github.com/matrix-org/synapse/issues/8836) + # + # But the auth providers' check_auth interface requires a username, so in + # practice we can only support login methods which we can map to a username + # anyway. # special case to check for "password" for the check_password interface # for the auth providers password = login_submission.get("password") - if login_type == LoginType.PASSWORD: if not self._password_enabled: raise SynapseError(400, "Password login has been disabled.") - if not password: - raise SynapseError(400, "Missing parameter: password") + if not isinstance(password, str): + raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM) - for provider in self.password_providers: - if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: - known_login_type = True - is_valid = await provider.check_password(qualified_user_id, password) - if is_valid: - return qualified_user_id, None + # map old-school login fields into new-school "identifier" fields. + identifier_dict = convert_client_dict_legacy_fields_to_identifier( + login_submission + ) - if not hasattr(provider, "get_supported_login_types") or not hasattr( - provider, "check_auth" - ): - # this password provider doesn't understand custom login types - continue + # convert phone type identifiers to generic threepids + if identifier_dict["type"] == "m.id.phone": + identifier_dict = login_id_phone_to_thirdparty(identifier_dict) + + # convert threepid identifiers to user IDs + if identifier_dict["type"] == "m.id.thirdparty": + address = identifier_dict.get("address") + medium = identifier_dict.get("medium") + + if medium is None or address is None: + raise SynapseError(400, "Invalid thirdparty identifier") + + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) + + # We also apply account rate limiting using the 3PID as a key, as + # otherwise using 3PID bypasses the ratelimiting based on user ID. + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + (medium, address), update=False + ) + + # Check for login providers that support 3pid login types + if login_type == LoginType.PASSWORD: + # we've already checked that there is a (valid) password field + assert isinstance(password, str) + ( + canonical_user_id, + callback_3pid, + ) = await self.check_password_provider_3pid(medium, address, password) + if canonical_user_id: + # Authentication through password provider and 3pid succeeded + return canonical_user_id, callback_3pid + + # No password providers were able to handle this 3pid + # Check local store + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + medium, address + ) + if not user_id: + logger.warning( + "unknown 3pid identifier medium %s, address %r", medium, address + ) + # We mark that we've failed to log in here, as + # `check_password_provider_3pid` might have returned `None` due + # to an incorrect password, rather than the account not + # existing. + # + # If it returned None but the 3PID was bound then we won't hit + # this code path, which is fine as then the per-user ratelimit + # will kick in below. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + (medium, address) + ) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + identifier_dict = {"type": "m.id.user", "user": user_id} + # by this point, the identifier should be an m.id.user: if it's anything + # else, we haven't understood it. + if identifier_dict["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + + username = identifier_dict.get("user") + if not username: + raise SynapseError(400, "User identifier is missing 'user' key") + + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + # Check if we've hit the failed ratelimit (but don't update it) + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + qualified_user_id.lower(), update=False + ) + + try: + return await self._validate_userid_login(username, login_submission) + except LoginError: + # The user has failed to log in, so we need to update the rate + # limiter. Using `can_do_action` avoids us raising a ratelimit + # exception and masking the LoginError. The actual ratelimiting + # should have happened above. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + qualified_user_id.lower() + ) + raise + + async def _validate_userid_login( + self, username: str, login_submission: Dict[str, Any], + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: + """Helper for validate_login + + Handles login, once we've mapped 3pids onto userids + + Args: + username: the username, from the identifier dict + login_submission: the whole of the login submission + (including 'type' and other relevant fields) + Returns: + A tuple of the canonical user id, and optional callback + to be called once the access token and device id are issued + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + login_type = login_submission.get("type") + # we already checked that we have a valid login type + assert isinstance(login_type, str) + + known_login_type = False + + for provider in self.password_providers: supported_login_types = provider.get_supported_login_types() if login_type not in supported_login_types: # this password provider doesn't understand this login type @@ -883,15 +1027,17 @@ class AuthHandler(BaseHandler): result = await provider.check_auth(username, login_type, login_dict) if result: - if isinstance(result, str): - result = (result, None) return result if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True + # we've already checked that there is a (valid) password field + password = login_submission["password"] + assert isinstance(password, str) + canonical_user_id = await self._check_local_password( - qualified_user_id, password # type: ignore + qualified_user_id, password ) if canonical_user_id: @@ -922,19 +1068,9 @@ class AuthHandler(BaseHandler): unsuccessful, `user_id` and `callback` are both `None`. """ for provider in self.password_providers: - if hasattr(provider, "check_3pid_auth"): - # This function is able to return a deferred that either - # resolves None, meaning authentication failure, or upon - # success, to a str (which is the user_id) or a tuple of - # (user_id, callback_func), where callback_func should be run - # after we've finished everything else - result = await provider.check_3pid_auth(medium, address, password) - if result: - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - result = (result, None) - return result + result = await provider.check_3pid_auth(medium, address, password) + if result: + return result return None, None @@ -992,16 +1128,11 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - # This might return an awaitable, if it does block the log out - # until it completes. - result = provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) - if inspect.isawaitable(result): - await result + await provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1030,11 +1161,10 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - for token, token_id, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + for token, token_id, device_id in tokens_and_devices: + await provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1358,3 +1488,127 @@ class MacaroonGenerator: macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon + + +class PasswordProvider: + """Wrapper for a password auth provider module + + This class abstracts out all of the backwards-compatibility hacks for + password providers, to provide a consistent interface. + """ + + @classmethod + def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": + try: + pp = module(config=config, account_handler=module_api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + return cls(pp, module_api) + + def __init__(self, pp, module_api: ModuleApi): + self._pp = pp + self._module_api = module_api + + self._supported_login_types = {} + + # grandfather in check_password support + if hasattr(self._pp, "check_password"): + self._supported_login_types[LoginType.PASSWORD] = ("password",) + + g = getattr(self._pp, "get_supported_login_types", None) + if g: + self._supported_login_types.update(g()) + + def __str__(self): + return str(self._pp) + + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: + """Get the login types supported by this password provider + + Returns a map from a login type identifier (such as m.login.password) to an + iterable giving the fields which must be provided by the user in the submission + to the /login API. + + This wrapper adds m.login.password to the list if the underlying password + provider supports the check_password() api. + """ + return self._supported_login_types + + async def check_auth( + self, username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + """Check if the user has presented valid login credentials + + This wrapper also calls check_password() if the underlying password provider + supports the check_password() api and the login type is m.login.password. + + Args: + username: user id presented by the client. Either an MXID or an unqualified + username. + + login_type: the login type being attempted - one of the types returned by + get_supported_login_types() + + login_dict: the dictionary of login secrets passed by the client. + + Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the + user, and `callback` is an optional callback which will be called with the + result from the /login call (including access_token, device_id, etc.) + """ + # first grandfather in a call to check_password + if login_type == LoginType.PASSWORD: + g = getattr(self._pp, "check_password", None) + if g: + qualified_user_id = self._module_api.get_qualified_user_id(username) + is_valid = await self._pp.check_password( + qualified_user_id, login_dict["password"] + ) + if is_valid: + return qualified_user_id, None + + g = getattr(self._pp, "check_auth", None) + if not g: + return None + result = await g(username, login_type, login_dict) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def check_3pid_auth( + self, medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + g = getattr(self._pp, "check_3pid_auth", None) + if not g: + return None + + # This function is able to return a deferred that either + # resolves None, meaning authentication failure, or upon + # success, to a str (which is the user_id) or a tuple of + # (user_id, callback_func), where callback_func should be run + # after we've finished everything else + result = await g(medium, address, password) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def on_logged_out( + self, user_id: str, device_id: Optional[str], access_token: str + ) -> None: + g = getattr(self._pp, "on_logged_out", None) + if not g: + return + + # This might return an awaitable, if it does block the log out + # until it completes. + result = g(user_id=user_id, device_id=device_id, access_token=access_token,) + if inspect.isawaitable(result): + await result diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 048a3b3c0b..f4ea0a9767 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError @@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,10 +34,10 @@ class CasHandler: Utility class for to handle the response from a CAS SSO service. Args: - hs (synapse.server.HomeServer) + hs """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -200,27 +203,57 @@ class CasHandler: args["session"] = session username, user_display_name = await self._validate_ticket(ticket, args) - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) + # Pull out the user-agent and IP from the request. + user_agent = request.get_user_agent("") + ip_address = self.hs.get_ip_from_request(request) + + # Get the matrix ID from the CAS username. + user_id = await self._map_cas_user_to_matrix_user( + username, user_display_name, user_agent, ip_address + ) if session: await self._auth_handler.complete_sso_ui_auth( - registered_user_id, session, request, + user_id, session, request, ) - else: - if not registered_user_id: - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=user_display_name, - user_agent_ips=(user_agent, ip_address), - ) + # If this not a UI auth request than there must be a redirect URL. + assert client_redirect_url await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url + user_id, request, client_redirect_url ) + + async def _map_cas_user_to_matrix_user( + self, + remote_user_id: str, + display_name: Optional[str], + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a CAS username, retrieve the user ID for it and possibly register the user. + + Args: + remote_user_id: The username from the CAS response. + display_name: The display name from the CAS response. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + """ + + localpart = map_username_to_mxid_localpart(remote_user_id) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + # If the user does not exist, register it. + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=display_name, + user_agent_ips=[(user_agent, ip_address)], + ) + + return registered_user_id diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 4efe6c530a..e808142365 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler): self._room_member_handler = hs.get_room_member_handler() self._identity_handler = hs.get_identity_handler() self.user_directory_handler = hs.get_user_directory_handler() + self._server_name = hs.hostname # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False @@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler): for room in pending_invites: try: await self._room_member_handler.update_membership( - create_requester(user), + create_requester(user, authenticated_entity=self._server_name), user, room.room_id, "leave", @@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler): logger.info("User parter parting %r from %r", user_id, room_id) try: await self._room_member_handler.update_membership( - create_requester(user), + create_requester(user, authenticated_entity=self._server_name), user, room_id, "leave", diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c386957706..b9799090f7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -55,6 +55,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.handlers._base import BaseHandler +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -67,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, - ReplicationStoreRoomOnInviteRestServlet, + ReplicationStoreRoomOnOutlierMembershipRestServlet, ) from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -152,12 +153,14 @@ class FederationHandler(BaseHandler): self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( hs ) - self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( hs ) else: self._device_list_updater = hs.get_device_handler().device_list_updater - self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite + self._maybe_store_room_on_outlier_membership = ( + self.store.maybe_store_room_on_outlier_membership + ) # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. @@ -1617,7 +1620,7 @@ class FederationHandler(BaseHandler): # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). - await self._maybe_store_room_on_invite( + await self._maybe_store_room_on_outlier_membership( room_id=event.room_id, room_version=room_version ) @@ -2686,7 +2689,7 @@ class FederationHandler(BaseHandler): ) async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: JsonDict + self, event_dict: JsonDict ) -> None: """Handle an exchange_third_party_invite request from a remote server @@ -2694,12 +2697,11 @@ class FederationHandler(BaseHandler): into a normal m.room.member invite. Args: - room_id: The ID of the room. - - event_dict (dict[str, Any]): Dictionary containing the event body. + event_dict: Dictionary containing the event body. """ - room_version = await self.store.get_room_version_id(room_id) + assert_params_in_dict(event_dict, ["room_id"]) + room_version = await self.store.get_room_version_id(event_dict["room_id"]) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index bc3e9607ca..9b3c6b4551 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "An error was encountered when sending the email") token_expires = ( - self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime + self.hs.get_clock().time_msec() + + self.hs.config.email_validation_token_lifetime ) await self.store.start_or_continue_validation_session( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c6791fb912..96843338ae 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -472,7 +472,7 @@ class EventCreationHandler: Returns: Tuple of created event, Context """ - await self.auth.check_auth_blocking(requester.user.to_string()) + await self.auth.check_auth_blocking(requester=requester) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version = event_dict["content"]["room_version"] @@ -619,7 +619,13 @@ class EventCreationHandler: if requester.app_service is not None: return - user_id = requester.user.to_string() + user_id = requester.authenticated_entity + if not user_id.startswith("@"): + # The authenticated entity might not be a user, e.g. if it's the + # server puppetting the user. + return + + user = UserID.from_string(user_id) # exempt the system notices user if ( @@ -639,9 +645,7 @@ class EventCreationHandler: if u["consent_version"] == self.config.user_consent_version: return - consent_uri = self._consent_uri_builder.build_user_consent_uri( - requester.user.localpart - ) + consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) @@ -1252,7 +1256,7 @@ class EventCreationHandler: for user_id in members: if not self.hs.is_mine_id(user_id): continue - requester = create_requester(user_id) + requester = create_requester(user_id, authenticated_entity=self.server_name) try: event, context = await self.create_event( requester, @@ -1273,11 +1277,6 @@ class EventCreationHandler: requester, event, context, ratelimit=False, ignore_shadow_ban=True, ) return True - except ConsentNotGivenError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of consent. Will try another user" % (room_id, user_id) - ) except AuthError: logger.info( "Failed to send dummy event into room %s for user %s due to " diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 331d4e7e96..c605f7082a 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode @@ -34,7 +35,8 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.http.server import respond_with_html +from synapse.handlers._base import BaseHandler +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -83,17 +85,12 @@ class OidcError(Exception): return self.error -class MappingException(Exception): - """Used to catch errors when mapping the UserInfo object - """ - - -class OidcHandler: +class OidcHandler(BaseHandler): """Handles requests related to the OpenID Connect login flow. """ def __init__(self, hs: "HomeServer"): - self.hs = hs + super().__init__(hs) self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._user_profile_method = hs.config.oidc_user_profile_method # type: str @@ -120,36 +117,13 @@ class OidcHandler: self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() - self._datastore = hs.get_datastore() - self._clock = hs.get_clock() - self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" - def _render_error( - self, request, error: str, error_description: Optional[str] = None - ) -> None: - """Render the error template and respond to the request with it. - - This is used to show errors to the user. The template of this page can - be found under `synapse/res/templates/sso_error.html`. - - Args: - request: The incoming request from the browser. - We'll respond with an HTML page describing the error. - error: A technical identifier for this error. Those include - well-known OAuth2/OIDC error types like invalid_request or - access_denied. - error_description: A human-readable description of the error. - """ - html = self._error_template.render( - error=error, error_description=error_description - ) - respond_with_html(request, 400, html) + self._sso_handler = hs.get_sso_handler() def _validate_metadata(self): """Verifies the provider metadata. @@ -571,7 +545,7 @@ class OidcHandler: Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. Instead, we call - ``self._render_error`` which displays an HTML page for the error. + ``self._sso_handler.render_error`` which displays an HTML page for the error. Most of the OpenID Connect logic happens here: @@ -609,7 +583,7 @@ class OidcHandler: if error != "access_denied": logger.error("Error from the OIDC provider: %s %s", error, description) - self._render_error(request, error, description) + self._sso_handler.render_error(request, error, description) return # otherwise, it is presumably a successful response. see: @@ -619,7 +593,9 @@ class OidcHandler: session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] if session is None: logger.info("No session cookie found") - self._render_error(request, "missing_session", "No session cookie found") + self._sso_handler.render_error( + request, "missing_session", "No session cookie found" + ) return # Remove the cookie. There is a good chance that if the callback failed @@ -637,7 +613,9 @@ class OidcHandler: # Check for the state query parameter if b"state" not in request.args: logger.info("State parameter is missing") - self._render_error(request, "invalid_request", "State parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "State parameter is missing" + ) return state = request.args[b"state"][0].decode() @@ -651,17 +629,19 @@ class OidcHandler: ) = self._verify_oidc_session_token(session, state) except MacaroonDeserializationException as e: logger.exception("Invalid session") - self._render_error(request, "invalid_session", str(e)) + self._sso_handler.render_error(request, "invalid_session", str(e)) return except MacaroonInvalidSignatureException as e: logger.exception("Could not verify session") - self._render_error(request, "mismatching_session", str(e)) + self._sso_handler.render_error(request, "mismatching_session", str(e)) return # Exchange the code with the provider if b"code" not in request.args: logger.info("Code parameter is missing") - self._render_error(request, "invalid_request", "Code parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "Code parameter is missing" + ) return logger.debug("Exchanging code") @@ -670,7 +650,7 @@ class OidcHandler: token = await self._exchange_code(code) except OidcError as e: logger.exception("Could not exchange code") - self._render_error(request, e.error, e.error_description) + self._sso_handler.render_error(request, e.error, e.error_description) return logger.debug("Successfully obtained OAuth2 access token") @@ -683,7 +663,7 @@ class OidcHandler: userinfo = await self._fetch_userinfo(token) except Exception as e: logger.exception("Could not fetch userinfo") - self._render_error(request, "fetch_error", str(e)) + self._sso_handler.render_error(request, "fetch_error", str(e)) return else: logger.debug("Extracting userinfo from id_token") @@ -691,7 +671,7 @@ class OidcHandler: userinfo = await self._parse_id_token(token, nonce=nonce) except Exception as e: logger.exception("Invalid id_token") - self._render_error(request, "invalid_token", str(e)) + self._sso_handler.render_error(request, "invalid_token", str(e)) return # Pull out the user-agent and IP from the request. @@ -705,7 +685,7 @@ class OidcHandler: ) except MappingException as e: logger.exception("Could not map user") - self._render_error(request, "mapping_error", str(e)) + self._sso_handler.render_error(request, "mapping_error", str(e)) return # Mapping providers might not have get_extra_attributes: only call this @@ -770,7 +750,7 @@ class OidcHandler: macaroon.add_first_party_caveat( "ui_auth_session_id = %s" % (ui_auth_session_id,) ) - now = self._clock.time_msec() + now = self.clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) @@ -845,7 +825,7 @@ class OidcHandler: if not caveat.startswith(prefix): return False expiry = int(caveat[len(prefix) :]) - now = self._clock.time_msec() + now = self.clock.time_msec() return now < expiry async def _map_userinfo_to_user( @@ -885,71 +865,77 @@ class OidcHandler: # to be strings. remote_user_id = str(remote_user_id) - logger.info( - "Looking for existing mapping for user %s:%s", - self._auth_provider_id, - remote_user_id, - ) - - registered_user_id = await self._datastore.get_user_by_external_id( - self._auth_provider_id, remote_user_id, + # Older mapping providers don't accept the `failures` argument, so we + # try and detect support. + mapper_signature = inspect.signature( + self._user_mapping_provider.map_user_attributes ) + supports_failures = "failures" in mapper_signature.parameters - if registered_user_id is not None: - logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id - - try: - attributes = await self._user_mapping_provider.map_user_attributes( - userinfo, token - ) - except Exception as e: - raise MappingException( - "Could not extract user attributes from OIDC response: " + str(e) - ) + async def oidc_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Call the mapping provider to map the OIDC userinfo and token to user attributes. - logger.debug( - "Retrieved user attributes from user mapping provider: %r", attributes - ) + This is backwards compatibility for abstraction for the SSO handler. + """ + if supports_failures: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token, failures + ) + else: + # If the mapping provider does not support processing failures, + # do not continually generate the same Matrix ID since it will + # continue to already be in use. Note that the error raised is + # arbitrary and will get turned into a MappingException. + if failures: + raise MappingException( + "Mapping provider does not support de-duplicating Matrix IDs" + ) - if not attributes["localpart"]: - raise MappingException("localpart is empty") + attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + userinfo, token + ) - localpart = map_username_to_mxid_localpart(attributes["localpart"]) + return UserAttributes(**attributes) - user_id = UserID(localpart, self._hostname).to_string() - users = await self._datastore.get_users_by_id_case_insensitive(user_id) - if users: + async def grandfather_existing_users() -> Optional[str]: if self._allow_existing_users: - if len(users) == 1: - registered_user_id = next(iter(users)) - elif user_id in users: - registered_user_id = user_id - else: - raise MappingException( - "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( - user_id, list(users.keys()) + # If allowing existing users we want to generate a single localpart + # and attempt to match it. + attributes = await oidc_response_to_user_attributes(failures=0) + + user_id = UserID(attributes.localpart, self.server_name).to_string() + users = await self.store.get_users_by_id_case_insensitive(user_id) + if users: + # If an existing matrix ID is returned, then use it. + if len(users) == 1: + previously_registered_user_id = next(iter(users)) + elif user_id in users: + previously_registered_user_id = user_id + else: + # Do not attempt to continue generating Matrix IDs. + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, users + ) ) - ) - else: - # This mxid is taken - raise MappingException("mxid '{}' is already taken".format(user_id)) - else: - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id, + + return previously_registered_user_id + + return None + + return await self._sso_handler.get_mxid_from_sso( + self._auth_provider_id, + remote_user_id, + user_agent, + ip_address, + oidc_response_to_user_attributes, + grandfather_existing_users, ) - return registered_user_id -UserAttribute = TypedDict( - "UserAttribute", {"localpart": str, "display_name": Optional[str]} +UserAttributeDict = TypedDict( + "UserAttributeDict", {"localpart": str, "display_name": Optional[str]} ) C = TypeVar("C") @@ -992,13 +978,15 @@ class OidcMappingProvider(Generic[C]): raise NotImplementedError() async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider token: A dict with the tokens returned by the provider + failures: How many times a call to this function with this + UserInfo has resulted in a failure. Returns: A dict containing the ``localpart`` and (optionally) the ``display_name`` @@ -1098,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return userinfo[self._config.subject_claim] async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: localpart = self._config.localpart_template.render(user=userinfo).strip() + # Ensure only valid characters are included in the MXID. + localpart = map_username_to_mxid_localpart(localpart) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" + display_name = None # type: Optional[str] if self._config.display_name_template is not None: display_name = self._config.display_name_template.render( @@ -1111,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - return UserAttribute(localpart=localpart, display_name=display_name) + return UserAttributeDict(localpart=localpart, display_name=display_name) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 426b58da9e..5372753707 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -299,17 +299,22 @@ class PaginationHandler: """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id: str) -> None: - """Purge the given room from the database""" + async def purge_room(self, room_id: str, force: bool = False) -> None: + """Purge the given room from the database. + + Args: + room_id: room to be purged + force: set true to skip checking for joined users. + """ with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) # first check that we have no users in this room - joined = await self.store.is_host_joined(room_id, self._server_name) - - if joined: - raise SynapseError(400, "Users are still joined to this room") + if not force: + joined = await self.store.is_host_joined(room_id, self._server_name) + if joined: + raise SynapseError(400, "Users are still joined to this room") await self.storage.purge_events.purge_room(room_id) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8e014c9bb5..22d1e9d35c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,7 +25,7 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager @@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -MYPY = False -if MYPY: +if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 74a1ddd780..dee0ef45e7 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler): # the join event to update the displayname in the rooms. # This must be done by the target user himself. if by_admin: - requester = create_requester(target_user) + requester = create_requester( + target_user, authenticated_entity=requester.authenticated_entity, + ) await self.store.set_profile_displayname( target_user.localpart, displayname_to_set @@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler): # Same like set_displayname if by_admin: - requester = create_requester(target_user) + requester = create_requester( + target_user, authenticated_entity=requester.authenticated_entity + ) await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index c242c409cf..153cbae7b9 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -158,7 +158,8 @@ class ReceiptEventSource: if from_key == to_key: return [], to_key - # We first need to fetch all new receipts + # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered + # by most recent. rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms( from_key=from_key, to_key=to_key ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ed1ff62599..0d85fd0868 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,10 +15,12 @@ """Contains functions for registering clients.""" import logging +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler): self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self._server_name = hs.hostname self.spam_checker = hs.get_spam_checker() @@ -70,7 +71,10 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime async def check_username( - self, localpart, guest_access_token=None, assigned_user_id=None + self, + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, ): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -139,39 +143,45 @@ class RegistrationHandler(BaseHandler): async def register_user( self, - localpart=None, - password_hash=None, - guest_access_token=None, - make_guest=False, - admin=False, - threepid=None, - user_type=None, - default_display_name=None, - address=None, - bind_emails=[], - by_admin=False, - user_agent_ips=None, - ): + localpart: Optional[str] = None, + password_hash: Optional[str] = None, + guest_access_token: Optional[str] = None, + make_guest: bool = False, + admin: bool = False, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + default_display_name: Optional[str] = None, + address: Optional[str] = None, + bind_emails: List[str] = [], + by_admin: bool = False, + user_agent_ips: Optional[List[Tuple[str, str]]] = None, + ) -> str: """Registers a new client on the server. Args: localpart: The local part of the user ID to register. If None, one will be generated. - password_hash (str|None): The hashed password to assign to this user so they can + password_hash: The hashed password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). - user_type (str|None): type of user. One of the values from + guest_access_token: The access token used when this was a guest + account. + make_guest: True if the the new user should be guest, + false to add a regular user account. + admin: True if the user should be registered as a server admin. + threepid: The threepid used for registering, if any. + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - default_display_name (unicode|None): if set, the new user's displayname + default_display_name: if set, the new user's displayname will be set to this. Defaults to 'localpart'. - address (str|None): the IP address used to perform the registration. - bind_emails (List[str]): list of emails to bind to this account. - by_admin (bool): True if this registration is being made via the + address: the IP address used to perform the registration. + bind_emails: list of emails to bind to this account. + by_admin: True if this registration is being made via the admin api, otherwise False. - user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. Returns: - str: user_id + The registere user_id. Raises: SynapseError if there was a problem registering. """ @@ -235,8 +245,10 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID fail_count = 0 - user = None - while not user: + # If a default display name is not given, generate one. + generate_display_name = default_display_name is None + # This breaks on successful registration *or* errors after 10 failures. + while True: # Fail after being unable to find a suitable ID a few times if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") @@ -245,7 +257,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) - if default_display_name is None: + if generate_display_name: default_display_name = localpart try: await self.register_with_store( @@ -261,8 +273,6 @@ class RegistrationHandler(BaseHandler): break except SynapseError: # if user id is taken, just generate another - user = None - user_id = None fail_count += 1 if not self.hs.config.user_consent_at_registration: @@ -294,7 +304,7 @@ class RegistrationHandler(BaseHandler): return user_id - async def _create_and_join_rooms(self, user_id: str): + async def _create_and_join_rooms(self, user_id: str) -> None: """ Create the auto-join rooms and join or invite the user to them. @@ -317,7 +327,8 @@ class RegistrationHandler(BaseHandler): requires_join = False if self.hs.config.registration.auto_join_user_id: fake_requester = create_requester( - self.hs.config.registration.auto_join_user_id + self.hs.config.registration.auto_join_user_id, + authenticated_entity=self._server_name, ) # If the room requires an invite, add the user to the list of invites. @@ -329,7 +340,9 @@ class RegistrationHandler(BaseHandler): # being necessary this will occur after the invite was sent. requires_join = True else: - fake_requester = create_requester(user_id) + fake_requester = create_requester( + user_id, authenticated_entity=self._server_name + ) # Choose whether to federate the new room. if not self.hs.config.registration.autocreate_auto_join_rooms_federated: @@ -362,7 +375,9 @@ class RegistrationHandler(BaseHandler): # created it, then ensure the first user joins it. if requires_join: await room_member_handler.update_membership( - requester=create_requester(user_id), + requester=create_requester( + user_id, authenticated_entity=self._server_name + ), target=UserID.from_string(user_id), room_id=info["room_id"], # Since it was just created, there are no remote hosts. @@ -370,15 +385,10 @@ class RegistrationHandler(BaseHandler): action="join", ratelimit=False, ) - - except ConsentNotGivenError as e: - # Technically not necessary to pull out this error though - # moving away from bare excepts is a good thing to do. - logger.error("Failed to join new user to %r: %r", r, e) except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _join_rooms(self, user_id: str): + async def _join_rooms(self, user_id: str) -> None: """ Join or invite the user to the auto-join rooms. @@ -424,9 +434,13 @@ class RegistrationHandler(BaseHandler): # Send the invite, if necessary. if requires_invite: + # If an invite is required, there must be a auto-join user ID. + assert self.hs.config.registration.auto_join_user_id + await room_member_handler.update_membership( requester=create_requester( - self.hs.config.registration.auto_join_user_id + self.hs.config.registration.auto_join_user_id, + authenticated_entity=self._server_name, ), target=UserID.from_string(user_id), room_id=room_id, @@ -437,7 +451,9 @@ class RegistrationHandler(BaseHandler): # Send the join. await room_member_handler.update_membership( - requester=create_requester(user_id), + requester=create_requester( + user_id, authenticated_entity=self._server_name + ), target=UserID.from_string(user_id), room_id=room_id, remote_room_hosts=remote_room_hosts, @@ -452,7 +468,7 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _auto_join_rooms(self, user_id: str): + async def _auto_join_rooms(self, user_id: str) -> None: """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -475,16 +491,16 @@ class RegistrationHandler(BaseHandler): else: await self._join_rooms(user_id) - async def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id: str) -> None: """A series of registration actions that can only be carried out once consent has been granted Args: - user_id (str): The user to join + user_id: The user to join """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart, as_token): + async def appservice_register(self, user_localpart: str, as_token: str) -> str: user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -509,7 +525,9 @@ class RegistrationHandler(BaseHandler): ) return user_id - def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): + def check_user_id_not_appservice_exclusive( + self, user_id: str, allowed_appservice: Optional[ApplicationService] = None + ) -> None: # don't allow people to register the server notices mxid if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: @@ -533,12 +551,12 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address): + def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address Args: - address (str|None): the IP address used to perform the registration. If this is + address: the IP address used to perform the registration. If this is None, no ratelimiting will be performed. Raises: @@ -549,42 +567,39 @@ class RegistrationHandler(BaseHandler): self.ratelimiter.ratelimit(address) - def register_with_store( + async def register_with_store( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - address=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + address: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Register user in the datastore. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Optional. Whether this is a guest account being upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, + make_guest: True if the the new user should be guest, false to add a regular user account. - appservice_id (str|None): The ID of the appservice registering the user. - create_profile_with_displayname (unicode|None): Optionally create a + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - address (str|None): the IP address used to perform the registration. - shadow_banned (bool): Whether to shadow-ban the user - - Returns: - Awaitable + address: the IP address used to perform the registration. + shadow_banned: Whether to shadow-ban the user """ if self.hs.config.worker_app: - return self._register_client( + await self._register_client( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -597,7 +612,7 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) else: - return self.store.register_user( + await self.store.register_user( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -610,22 +625,24 @@ class RegistrationHandler(BaseHandler): ) async def register_device( - self, user_id, device_id, initial_display_name, is_guest=False - ): + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool = False, + ) -> Tuple[str, str]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate - a new one. - initial_display_name (str|None): An optional display name for the - device. - is_guest (bool): Whether this is a guest account + user_id: full canonical @user:id + device_id: The device ID to check, or None to generate a new one. + initial_display_name: An optional display name for the device. + is_guest: Whether this is a guest account Returns: - tuple[str, str]: Tuple of device ID and access token + Tuple of device ID and access token """ if self.hs.config.worker_app: @@ -645,7 +662,7 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = await self.device_handler.check_device_registered( + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -655,20 +672,21 @@ class RegistrationHandler(BaseHandler): ) else: access_token = await self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms ) - return (device_id, access_token) + return (registered_device_id, access_token) - async def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions( + self, user_id: str, auth_result: dict, access_token: Optional[str] + ) -> None: """A user has completed registration Args: - user_id (str): The user ID that consented - auth_result (dict): The authenticated credentials of the newly - registered user. - access_token (str|None): The access token of the newly logged in - device, or None if `inhibit_login` enabled. + user_id: The user ID that consented + auth_result: The authenticated credentials of the newly registered user. + access_token: The access token of the newly logged in device, or + None if `inhibit_login` enabled. """ if self.hs.config.worker_app: await self._post_registration_client( @@ -694,19 +712,20 @@ class RegistrationHandler(BaseHandler): if auth_result and LoginType.TERMS in auth_result: await self._on_user_consented(user_id, self.hs.config.user_consent_version) - async def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id: str, consent_version: str) -> None: """A user consented to the terms on registration Args: - user_id (str): The user ID that consented. - consent_version (str): version of the policy the user has - consented to. + user_id: The user ID that consented. + consent_version: version of the policy the user has consented to. """ logger.info("%s has consented to the privacy policy", user_id) await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid(self, user_id, threepid, token): + async def _register_email_threepid( + self, user_id: str, threepid: dict, token: Optional[str] + ) -> None: """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -715,10 +734,9 @@ class RegistrationHandler(BaseHandler): Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.email.identity auth response - token (str|None): access_token for the user, or None if not logged - in. + user_id: id of user + threepid: m.login.email.identity auth response + token: access_token for the user, or None if not logged in. """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -744,6 +762,8 @@ class RegistrationHandler(BaseHandler): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) + # The token better still exist. + assert user_tuple token_id = user_tuple.token_id await self.pusher_pool.add_pusher( @@ -758,14 +778,14 @@ class RegistrationHandler(BaseHandler): data={}, ) - async def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None: """Add a phone number as a 3pid identifier Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.msisdn auth response + user_id: id of user + threepid: m.login.msisdn auth response """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e73031475f..930047e730 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - await self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(requester=requester) if ( self._server_notices_mxid is not None @@ -1257,7 +1257,9 @@ class RoomShutdownHandler: 400, "User must be our own: %s" % (new_room_user_id,) ) - room_creator_requester = create_requester(new_room_user_id) + room_creator_requester = create_requester( + new_room_user_id, authenticated_entity=requester_user_id + ) info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, @@ -1297,7 +1299,9 @@ class RoomShutdownHandler: try: # Kick users from room - target_requester = create_requester(user_id) + target_requester = create_requester( + user_id, authenticated_entity=requester_user_id + ) _, stream_id = await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7cd858b7db..c002886324 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -31,7 +31,6 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room @@ -347,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # later on. content = dict(content) - if not self.allow_per_room_profiles or requester.shadow_banned: + # allow the server notices mxid to set room-level profile + is_requester_server_notices_user = ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ) + + if ( + not self.allow_per_room_profiles and not is_requester_server_notices_user + ) or requester.shadow_banned: # Strip profile data, knowing that new profile data will be added to the # event's content in event_creation_handler.create_event() using the target's # global profile. @@ -515,10 +522,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - invite = await self.store.get_invite_for_local_user_in_room( - user_id=target.to_string(), room_id=room_id - ) # type: Optional[RoomsForUser] - if not invite: + ( + current_membership_type, + current_membership_event_id, + ) = await self.store.get_local_current_membership_for_user_in_room( + target.to_string(), room_id + ) + if ( + current_membership_type != Membership.INVITE + or not current_membership_event_id + ): logger.info( "%s sent a leave request to %s, but that is not an active room " "on this server, and there is no pending invite", @@ -528,6 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(404, "Not a known room") + invite = await self.store.get_event(current_membership_event_id) logger.info( "%s rejects invite to %s from %s", target, room_id, invite.sender ) @@ -965,6 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor = hs.get_distributor() self.distributor.declare("user_left_room") + self._server_name = hs.hostname async def _is_remote_room_too_complex( self, room_id: str, remote_room_hosts: List[str] @@ -1059,7 +1074,9 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id # The room is too large. Leave. - requester = types.create_requester(user, None, False, False, None) + requester = types.create_requester( + user, authenticated_entity=self._server_name + ) await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) @@ -1104,32 +1121,34 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - return await self._locally_reject_invite( + return await self._generate_local_out_of_band_leave( invite_event, txn_id, requester, content ) - async def _locally_reject_invite( + async def _generate_local_out_of_band_leave( self, - invite_event: EventBase, + previous_membership_event: EventBase, txn_id: Optional[str], requester: Requester, content: JsonDict, ) -> Tuple[str, int]: - """Generate a local invite rejection + """Generate a local leave event for a room - This is called after we fail to reject an invite via a remote server. It - generates an out-of-band membership event locally. + This can be called after we e.g fail to reject an invite via a remote server. + It generates an out-of-band membership event locally. Args: - invite_event: the invite to be rejected + previous_membership_event: the previous membership event for this user txn_id: optional transaction ID supplied by the client - requester: user making the rejection request, according to the access token - content: additional content to include in the rejection event. + requester: user making the request, according to the access token + content: additional content to include in the leave event. Normally an empty dict. - """ - room_id = invite_event.room_id - target_user = invite_event.state_key + Returns: + A tuple containing (event_id, stream_id of the leave event) + """ + room_id = previous_membership_event.room_id + target_user = previous_membership_event.state_key content["membership"] = Membership.LEAVE @@ -1141,12 +1160,12 @@ class RoomMemberMasterHandler(RoomMemberHandler): "state_key": target_user, } - # the auth events for the new event are the same as that of the invite, plus - # the invite itself. + # the auth events for the new event are the same as that of the previous event, plus + # the event itself. # - # the prev_events are just the invite. - prev_event_ids = [invite_event.event_id] - auth_event_ids = invite_event.auth_event_ids() + prev_event_ids + # the prev_events consist solely of the previous membership event. + prev_event_ids = [previous_membership_event.event_id] + auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids event, context = await self.event_creation_handler.create_event( requester, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index fd6c5e9ea8..76d4169fe2 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -24,7 +24,8 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement -from synapse.http.server import respond_with_html +from synapse.handlers._base import BaseHandler +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) -class MappingException(Exception): - """Used to catch errors when mapping the SAML2 response to a user.""" - - @attr.s(slots=True) class Saml2SessionData: """Data we track about SAML2 sessions""" @@ -57,17 +54,14 @@ class Saml2SessionData: ui_auth_session_id = attr.ib(type=Optional[str], default=None) -class SamlHandler: - def __init__(self, hs: "synapse.server.HomeServer"): - self.hs = hs +class SamlHandler(BaseHandler): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._auth = hs.get_auth() + self._saml_idp_entityid = hs.config.saml2_idp_entityid self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() - self._clock = hs.get_clock() - self._datastore = hs.get_datastore() - self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute @@ -88,26 +82,9 @@ class SamlHandler: self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings - self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) - - def _render_error( - self, request, error: str, error_description: Optional[str] = None - ) -> None: - """Render the error template and respond to the request with it. - - This is used to show errors to the user. The template of this page can - be found under `synapse/res/templates/sso_error.html`. + self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock) - Args: - request: The incoming request from the browser. - We'll respond with an HTML page describing the error. - error: A technical identifier for this error. - error_description: A human-readable description of the error. - """ - html = self._error_template.render( - error=error, error_description=error_description - ) - respond_with_html(request, 400, html) + self._sso_handler = hs.get_sso_handler() def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None @@ -124,13 +101,13 @@ class SamlHandler: URL to redirect to """ reqid, info = self._saml_client.prepare_for_authenticate( - relay_state=client_redirect_url + entityid=self._saml_idp_entityid, relay_state=client_redirect_url ) # Since SAML sessions timeout it is useful to log when they were created. logger.info("Initiating a new SAML session: %s" % (reqid,)) - now = self._clock.time_msec() + now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, ) @@ -171,12 +148,12 @@ class SamlHandler: # in the (user-visible) exception message, so let's log the exception here # so we can track down the session IDs later. logger.warning(str(e)) - self._render_error( + self._sso_handler.render_error( request, "unsolicited_response", "Unexpected SAML2 login." ) return except Exception as e: - self._render_error( + self._sso_handler.render_error( request, "invalid_response", "Unable to parse SAML2 response: %s." % (e,), @@ -184,7 +161,7 @@ class SamlHandler: return if saml2_auth.not_signed: - self._render_error( + self._sso_handler.render_error( request, "unsigned_respond", "SAML2 response was not signed." ) return @@ -210,7 +187,7 @@ class SamlHandler: # attributes. for requirement in self._saml2_attribute_requirements: if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._render_error( + self._sso_handler.render_error( request, "unauthorised", "You are not authorised to log in here." ) return @@ -226,7 +203,7 @@ class SamlHandler: ) except MappingException as e: logger.exception("Could not map user") - self._render_error(request, "mapping_error", str(e)) + self._sso_handler.render_error(request, "mapping_error", str(e)) return # Complete the interactive auth session or the login. @@ -272,20 +249,26 @@ class SamlHandler: "Failed to extract remote user id from SAML response" ) - with (await self._mapping_lock.queue(self._auth_provider_id)): - # first of all, check if we already have a mapping for this user - logger.info( - "Looking for existing mapping for user %s:%s", - self._auth_provider_id, - remote_user_id, + async def saml_response_to_remapped_user_attributes( + failures: int, + ) -> UserAttributes: + """ + Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form. + + This is backwards compatibility for abstraction for the SSO handler. + """ + # Call the mapping provider. + result = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, failures, client_redirect_url ) - registered_user_id = await self._datastore.get_user_by_external_id( - self._auth_provider_id, remote_user_id + # Remap some of the results. + return UserAttributes( + localpart=result.get("mxid_localpart"), + display_name=result.get("displayname"), + emails=result.get("emails", []), ) - if registered_user_id is not None: - logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + async def grandfather_existing_users() -> Optional[str]: # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid if ( @@ -294,75 +277,35 @@ class SamlHandler: ): attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] user_id = UserID( - map_username_to_mxid_localpart(attrval), self._hostname + map_username_to_mxid_localpart(attrval), self.server_name ).to_string() - logger.info( + + logger.debug( "Looking for existing account based on mapped %s %s", self._grandfathered_mxid_source_attribute, user_id, ) - users = await self._datastore.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: registered_user_id = list(users.keys())[0] logger.info("Grandfathering mapping to %s", registered_user_id) - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id - ) return registered_user_id - # Map saml response to user attributes using the configured mapping provider - for i in range(1000): - attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( - saml2_auth, i, client_redirect_url=client_redirect_url, - ) - - logger.debug( - "Retrieved SAML attributes from user mapping provider: %s " - "(attempt %d)", - attribute_dict, - i, - ) - - localpart = attribute_dict.get("mxid_localpart") - if not localpart: - raise MappingException( - "Error parsing SAML2 response: SAML mapping provider plugin " - "did not return a mxid_localpart value" - ) - - displayname = attribute_dict.get("displayname") - emails = attribute_dict.get("emails", []) - - # Check if this mxid already exists - if not await self._datastore.get_users_by_id_case_insensitive( - UserID(localpart, self._hostname).to_string() - ): - # This mxid is free - break - else: - # Unable to generate a username in 1000 iterations - # Break and return error to the user - raise MappingException( - "Unable to generate a Matrix ID from the SAML response" - ) + return None - logger.info("Mapped SAML user to local part %s", localpart) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=displayname, - bind_emails=emails, - user_agent_ips=(user_agent, ip_address), - ) - - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id + with (await self._mapping_lock.queue(self._auth_provider_id)): + return await self._sso_handler.get_mxid_from_sso( + self._auth_provider_id, + remote_user_id, + user_agent, + ip_address, + saml_response_to_remapped_user_attributes, + grandfather_existing_users, ) - return registered_user_id def expire_sessions(self): - expire_before = self._clock.time_msec() - self._saml2_session_lifetime + expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() for reqid, data in self._outstanding_requests_dict.items(): if data.creation_time < expire_before: @@ -474,11 +417,11 @@ class DefaultSamlMappingProvider: ) # Use the configured mapper for this mxid_source - base_mxid_localpart = self._mxid_mapper(mxid_source) + localpart = self._mxid_mapper(mxid_source) # Append suffix integer if last call to this function failed to produce - # a usable mxid - localpart = base_mxid_localpart + (str(failures) if failures else "") + # a usable mxid. + localpart += str(failures) if failures else "" # Retrieve the display name from the saml response # If displayname is None, the mxid_localpart will be used instead diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py new file mode 100644 index 0000000000..47ad96f97e --- /dev/null +++ b/synapse/handlers/sso.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional + +import attr + +from synapse.api.errors import RedirectException +from synapse.handlers._base import BaseHandler +from synapse.http.server import respond_with_html +from synapse.types import UserID, contains_invalid_mxid_characters + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MappingException(Exception): + """Used to catch errors when mapping an SSO response to user attributes. + + Note that the msg that is raised is shown to end-users. + """ + + +@attr.s +class UserAttributes: + localpart = attr.ib(type=str) + display_name = attr.ib(type=Optional[str], default=None) + emails = attr.ib(type=List[str], default=attr.Factory(list)) + + +class SsoHandler(BaseHandler): + # The number of attempts to ask the mapping provider for when generating an MXID. + _MAP_USERNAME_RETRIES = 1000 + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self._registration_handler = hs.get_registration_handler() + self._error_template = hs.config.sso_error_template + + def render_error( + self, request, error: str, error_description: Optional[str] = None + ) -> None: + """Renders the error template and responds with it. + + This is used to show errors to the user. The template of this page can + be found under `synapse/res/templates/sso_error.html`. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. + error_description: A human-readable description of the error. + """ + html = self._error_template.render( + error=error, error_description=error_description + ) + respond_with_html(request, 400, html) + + async def get_sso_user_by_remote_user_id( + self, auth_provider_id: str, remote_user_id: str + ) -> Optional[str]: + """ + Maps the user ID of a remote IdP to a mxid for a previously seen user. + + If the user has not been seen yet, this will return None. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + remote_user_id: The user ID according to the remote IdP. This might + be an e-mail address, a GUID, or some other form. It must be + unique and immutable. + + Returns: + The mxid of a previously seen user. + """ + logger.debug( + "Looking for existing mapping for user %s:%s", + auth_provider_id, + remote_user_id, + ) + + # Check if we already have a mapping for this user. + previously_registered_user_id = await self.store.get_user_by_external_id( + auth_provider_id, remote_user_id, + ) + + # A match was found, return the user ID. + if previously_registered_user_id is not None: + logger.info( + "Found existing mapping for IdP '%s' and remote_user_id '%s': %s", + auth_provider_id, + remote_user_id, + previously_registered_user_id, + ) + return previously_registered_user_id + + # No match. + return None + + async def get_mxid_from_sso( + self, + auth_provider_id: str, + remote_user_id: str, + user_agent: str, + ip_address: str, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], + ) -> str: + """ + Given an SSO ID, retrieve the user ID for it and possibly register the user. + + This first checks if the SSO ID has previously been linked to a matrix ID, + if it has that matrix ID is returned regardless of the current mapping + logic. + + If a callable is provided for grandfathering users, it is called and can + potentially return a matrix ID to use. If it does, the SSO ID is linked to + this matrix ID for subsequent calls. + + The mapping function is called (potentially multiple times) to generate + a localpart for the user. + + If an unused localpart is generated, the user is registered from the + given user-agent and IP address and the SSO ID is linked to this matrix + ID for subsequent calls. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + remote_user_id: The unique identifier from the SSO provider. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + sso_to_matrix_id_mapper: A callable to generate the user attributes. + The only parameter is an integer which represents the amount of + times the returned mxid localpart mapping has failed. + + It is expected that the mapper can raise two exceptions, which + will get passed through to the caller: + + MappingException if there was a problem mapping the response + to the user. + RedirectException to redirect to an additional page (e.g. + to prompt the user for more information). + grandfather_existing_users: A callable which can return an previously + existing matrix ID. The SSO ID is then linked to the returned + matrix ID. + + Returns: + The user ID associated with the SSO response. + + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: if the mapping provider needs to redirect the user + to an additional page. (e.g. to prompt for more information) + + """ + # first of all, check if we already have a mapping for this user + previously_registered_user_id = await self.get_sso_user_by_remote_user_id( + auth_provider_id, remote_user_id, + ) + if previously_registered_user_id: + return previously_registered_user_id + + # Check for grandfathering of users. + if grandfather_existing_users: + previously_registered_user_id = await grandfather_existing_users() + if previously_registered_user_id: + # Future logins should also match this user ID. + await self.store.record_user_external_id( + auth_provider_id, remote_user_id, previously_registered_user_id + ) + return previously_registered_user_id + + # Otherwise, generate a new user. + for i in range(self._MAP_USERNAME_RETRIES): + try: + attributes = await sso_to_matrix_id_mapper(i) + except (RedirectException, MappingException): + # Mapping providers are allowed to issue a redirect (e.g. to ask + # the user for more information) and can issue a mapping exception + # if a name cannot be generated. + raise + except Exception as e: + # Any other exception is unexpected. + raise MappingException( + "Could not extract user attributes from SSO response." + ) from e + + logger.debug( + "Retrieved user attributes from user mapping provider: %r (attempt %d)", + attributes, + i, + ) + + if not attributes.localpart: + raise MappingException( + "Error parsing SSO response: SSO mapping provider plugin " + "did not return a localpart value" + ) + + # Check if this mxid already exists + user_id = UserID(attributes.localpart, self.server_name).to_string() + if not await self.store.get_users_by_id_case_insensitive(user_id): + # This mxid is free + break + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise MappingException( + "Unable to generate a Matrix ID from the SSO response" + ) + + # Since the localpart is provided via a potentially untrusted module, + # ensure the MXID is valid before registering. + if contains_invalid_mxid_characters(attributes.localpart): + raise MappingException("localpart is invalid: %s" % (attributes.localpart,)) + + logger.debug("Mapped SSO user to local part %s", attributes.localpart) + registered_user_id = await self._registration_handler.register_user( + localpart=attributes.localpart, + default_display_name=attributes.display_name, + bind_emails=attributes.emails, + user_agent_ips=[(user_agent, ip_address)], + ) + + await self.store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 32e53c2d25..9827c7eb8d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -31,6 +31,7 @@ from synapse.types import ( Collection, JsonDict, MutableStateMap, + Requester, RoomStreamToken, StateMap, StreamToken, @@ -260,6 +261,7 @@ class SyncHandler: async def wait_for_sync_for_user( self, + requester: Requester, sync_config: SyncConfig, since_token: Optional[StreamToken] = None, timeout: int = 0, @@ -273,7 +275,7 @@ class SyncHandler: # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - await self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(requester=requester) res = await self.response_cache.wrap( sync_config.request_key, diff --git a/synapse/http/client.py b/synapse/http/client.py index f409368802..e5b13593f2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib +import urllib.parse from io import BytesIO from typing import ( + TYPE_CHECKING, Any, BinaryIO, Dict, @@ -31,7 +32,7 @@ from typing import ( import treq from canonicaljson import encode_canonical_json -from netaddr import IPAddress +from netaddr import IPAddress, IPSet from prometheus_client import Counter from zope.interface import implementer, provider @@ -39,6 +40,8 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.interfaces import ( + IAddress, + IHostResolution, IReactorPluggableNameResolver, IResolutionReceiver, ) @@ -53,7 +56,7 @@ from twisted.web.client import ( ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IAgent, IBodyProducer, IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri @@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) @@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] -def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): +def check_against_blacklist( + ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet +) -> bool: """ + Compares an IP address to allowed and disallowed IP sets. + Args: - ip_address (netaddr.IPAddress) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + ip_address: The IP address to check + ip_whitelist: Allowed IP addresses. + ip_blacklist: Disallowed IP addresses. + + Returns: + True if the IP address is in the blacklist and not in the whitelist. """ if ip_address in ip_blacklist: if ip_whitelist is None or ip_address not in ip_whitelist: @@ -118,23 +131,30 @@ class IPBlacklistingResolver: addresses, preventing DNS rebinding attacks on URL preview. """ - def __init__(self, reactor, ip_whitelist, ip_blacklist): + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): """ Args: - reactor (twisted.internet.reactor) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + reactor: The twisted reactor. + ip_whitelist: IP addresses to allow. + ip_blacklist: IP addresses to disallow. """ self._reactor = reactor self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - def resolveHostName(self, recv, hostname, portNumber=0): + def resolveHostName( + self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 + ) -> IResolutionReceiver: r = recv() - addresses = [] + addresses = [] # type: List[IAddress] - def _callback(): + def _callback() -> None: r.resolutionBegan(None) has_bad_ip = False @@ -161,15 +181,15 @@ class IPBlacklistingResolver: @provider(IResolutionReceiver) class EndpointReceiver: @staticmethod - def resolutionBegan(resolutionInProgress): + def resolutionBegan(resolutionInProgress: IHostResolution) -> None: pass @staticmethod - def addressResolved(address): + def addressResolved(address: IAddress) -> None: addresses.append(address) @staticmethod - def resolutionComplete(): + def resolutionComplete() -> None: _callback() self._reactor.nameResolver.resolveHostName( @@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent): directly (without an IP address lookup). """ - def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + agent: IAgent, + ip_whitelist: Optional[IPSet] = None, + ip_blacklist: Optional[IPSet] = None, + ): """ Args: - agent (twisted.web.client.Agent): The Agent to wrap. - reactor (twisted.internet.reactor) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + agent: The Agent to wrap. + ip_whitelist: IP addresses to allow. + ip_blacklist: IP addresses to disallow. """ self._agent = agent self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: h = urllib.parse.urlparse(uri.decode("ascii")) try: @@ -226,23 +256,23 @@ class SimpleHttpClient: def __init__( self, - hs, - treq_args={}, - ip_whitelist=None, - ip_blacklist=None, - http_proxy=None, - https_proxy=None, + hs: "HomeServer", + treq_args: Dict[str, Any] = {}, + ip_whitelist: Optional[IPSet] = None, + ip_blacklist: Optional[IPSet] = None, + http_proxy: Optional[bytes] = None, + https_proxy: Optional[bytes] = None, ): """ Args: - hs (synapse.server.HomeServer) - treq_args (dict): Extra keyword arguments to be given to treq.request. - ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + hs + treq_args: Extra keyword arguments to be given to treq.request. + ip_blacklist: The IP addresses that are blacklisted that we may not request. - ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + ip_whitelist: The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. - http_proxy (bytes): proxy server to use for http connections. host[:port] - https_proxy (bytes): proxy server to use for https connections. host[:port] + http_proxy: proxy server to use for http connections. host[:port] + https_proxy: proxy server to use for https connections. host[:port] """ self.hs = hs @@ -306,7 +336,6 @@ class SimpleHttpClient: # by the DNS resolution. self.agent = BlacklistingAgentWrapper( self.agent, - self.reactor, ip_whitelist=self._ip_whitelist, ip_blacklist=self._ip_blacklist, ) @@ -397,7 +426,7 @@ class SimpleHttpClient: async def post_urlencoded_get_json( self, uri: str, - args: Mapping[str, Union[str, List[str]]] = {}, + args: Optional[Mapping[str, Union[str, List[str]]]] = None, headers: Optional[RawHeaders] = None, ) -> Any: """ @@ -422,9 +451,7 @@ class SimpleHttpClient: # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( - "utf8" - ) + query_bytes = encode_query_args(args) actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -432,7 +459,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "POST", uri, headers=Headers(actual_headers), data=query_bytes @@ -479,7 +506,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "POST", uri, headers=Headers(actual_headers), data=json_str @@ -495,7 +522,10 @@ class SimpleHttpClient: ) async def get_json( - self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None, + self, + uri: str, + args: Optional[QueryParams] = None, + headers: Optional[RawHeaders] = None, ) -> Any: """Gets some json from the given URI. @@ -516,7 +546,7 @@ class SimpleHttpClient: """ actual_headers = {b"Accept": [b"application/json"]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore body = await self.get_raw(uri, args, headers=headers) return json_decoder.decode(body.decode("utf-8")) @@ -525,7 +555,7 @@ class SimpleHttpClient: self, uri: str, json_body: Any, - args: QueryParams = {}, + args: Optional[QueryParams] = None, headers: RawHeaders = None, ) -> Any: """Puts some json to the given URI. @@ -546,9 +576,9 @@ class SimpleHttpClient: ValueError: if the response was not JSON """ - if len(args): - query_bytes = urllib.parse.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) + if args: + query_str = urllib.parse.urlencode(args, True) + uri = "%s?%s" % (uri, query_str) json_str = encode_canonical_json(json_body) @@ -558,7 +588,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "PUT", uri, headers=Headers(actual_headers), data=json_str @@ -574,7 +604,10 @@ class SimpleHttpClient: ) async def get_raw( - self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None + self, + uri: str, + args: Optional[QueryParams] = None, + headers: Optional[RawHeaders] = None, ) -> bytes: """Gets raw text from the given URI. @@ -592,13 +625,13 @@ class SimpleHttpClient: HttpResponseException on a non-2xx HTTP response. """ - if len(args): - query_bytes = urllib.parse.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) + if args: + query_str = urllib.parse.urlencode(args, True) + uri = "%s?%s" % (uri, query_str) actual_headers = {b"User-Agent": [self.user_agent]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request("GET", uri, headers=Headers(actual_headers)) @@ -641,7 +674,7 @@ class SimpleHttpClient: actual_headers = {b"User-Agent": [self.user_agent]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request("GET", url, headers=Headers(actual_headers)) @@ -649,12 +682,13 @@ class SimpleHttpClient: if ( b"Content-Length" in resp_headers + and max_size and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (max_size,)) raise SynapseError( 502, - "Requested file is too large > %r bytes" % (self.max_size,), + "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) @@ -668,7 +702,7 @@ class SimpleHttpClient: try: length = await make_deferred_yieldable( - _readBodyToFile(response, output_stream, max_size) + readBodyToFile(response, output_stream, max_size) ) except SynapseError: # This can happen e.g. because the body is too large. @@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure): return f -# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. -# The two should be factored out. - - class _ReadBodyToFileProtocol(protocol.Protocol): - def __init__(self, stream, deferred, max_size): + def __init__( + self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] + ): self.stream = stream self.deferred = deferred self.length = 0 self.max_size = max_size - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: @@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred = defer.Deferred() self.transport.loseConnection() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): @@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred.errback(reason) -# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. -# The two should be factored out. +def readBodyToFile( + response: IResponse, stream: BinaryIO, max_size: Optional[int] +) -> defer.Deferred: + """ + Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. + Args: + response: The HTTP response to read from. + stream: The file-object to write to. + max_size: The maximum file size to allow. + + Returns: + A Deferred which resolves to the length of the read body. + """ -def _readBodyToFile(response, stream, max_size): d = defer.Deferred() response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) return d -def encode_urlencode_args(args): - return {k: encode_urlencode_arg(v) for k, v in args.items()} +def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes: + """ + Encodes a map of query arguments to bytes which can be appended to a URL. + Args: + args: The query arguments, a mapping of string to string or list of strings. + + Returns: + The query arguments encoded as bytes. + """ + if args is None: + return b"" -def encode_urlencode_arg(arg): - if isinstance(arg, str): - return arg.encode("utf-8") - elif isinstance(arg, list): - return [encode_urlencode_arg(i) for i in arg] - else: - return arg + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, str): + vs = [vs] + encoded_args[k] = [v.encode("utf8") for v in vs] + query_str = urllib.parse.urlencode(encoded_args, True) -def _print_ex(e): - if hasattr(e, "reasons") and e.reasons: - for ex in e.reasons: - _print_ex(ex) - else: - logger.exception(e) + return query_str.encode("utf8") class InsecureInterceptableContextFactory(ssl.ContextFactory): diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 83d6196d4a..e77f9587d0 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -12,21 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -import urllib -from typing import List +import urllib.parse +from typing import List, Optional from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import Agent, HTTPConnectionPool +from twisted.internet.interfaces import ( + IProtocolFactory, + IReactorCore, + IStreamClientEndpoint, +) +from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IAgentEndpointFactory +from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -44,30 +48,30 @@ class MatrixFederationAgent: Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) Args: - reactor (IReactor): twisted reactor to use for underlying requests + reactor: twisted reactor to use for underlying requests - tls_client_options_factory (FederationPolicyForHTTPS|None): + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. - user_agent (bytes): + user_agent: The user agent header to use for federation requests. - _srv_resolver (SrvResolver|None): - SRVResolver impl to use for looking up SRV records. None to use a default - implementation. + _srv_resolver: + SrvResolver implementation to use for looking up SRV records. None + to use a default implementation. - _well_known_resolver (WellKnownResolver|None): + _well_known_resolver: WellKnownResolver to use to perform well-known lookups. None to use a default implementation. """ def __init__( self, - reactor, - tls_client_options_factory, - user_agent, - _srv_resolver=None, - _well_known_resolver=None, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + user_agent: bytes, + _srv_resolver: Optional[SrvResolver] = None, + _well_known_resolver: Optional[WellKnownResolver] = None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -99,15 +103,20 @@ class MatrixFederationAgent: self._well_known_resolver = _well_known_resolver @defer.inlineCallbacks - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: """ Args: - method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): - HTTP headers to send with the request, or None to - send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): + method: HTTP method: GET/POST/etc + uri: Absolute URI to be retrieved + headers: + HTTP headers to send with the request, or None to send no extra headers. + bodyProducer: An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have @@ -123,6 +132,9 @@ class MatrixFederationAgent: # explicit port. parsed_uri = urllib.parse.urlparse(uri) + # There must be a valid hostname. + assert parsed_uri.hostname + # If this is a matrix:// URI check if the server has delegated matrix # traffic using well-known delegation. # @@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory: """Factory for MatrixHostnameEndpoint for parsing to an Agent. """ - def __init__(self, reactor, tls_client_options_factory, srv_resolver): + def __init__( + self, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + srv_resolver: Optional[SrvResolver], + ): self._reactor = reactor self._tls_client_options_factory = tls_client_options_factory @@ -203,15 +220,20 @@ class MatrixHostnameEndpoint: resolution (i.e. via SRV). Does not check for well-known delegation. Args: - reactor (IReactor) - tls_client_options_factory (ClientTLSOptionsFactory|None): + reactor: twisted reactor to use for underlying requests + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. - srv_resolver (SrvResolver): The SRV resolver to use - parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting - to connect to. + srv_resolver: The SRV resolver to use + parsed_uri: The parsed URI that we're wanting to connect to. """ - def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + def __init__( + self, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + srv_resolver: SrvResolver, + parsed_uri: URI, + ): self._reactor = reactor self._parsed_uri = parsed_uri @@ -231,13 +253,13 @@ class MatrixHostnameEndpoint: self._srv_resolver = srv_resolver - def connect(self, protocol_factory): + def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: """Implements IStreamClientEndpoint interface """ return run_in_background(self._do_connect, protocol_factory) - async def _do_connect(self, protocol_factory): + async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: first_exception = None server_list = await self._resolve_server() @@ -303,20 +325,20 @@ class MatrixHostnameEndpoint: return [Server(host, 8448)] -def _is_ip_literal(host): +def _is_ip_literal(host: bytes) -> bool: """Test if the given host name is either an IPv4 or IPv6 literal. Args: - host (bytes) + host: The host name to check Returns: - bool + True if the hostname is an IP address literal. """ - host = host.decode("ascii") + host_str = host.decode("ascii") try: - IPAddress(host) + IPAddress(host_str) return True except AddrFormatError: return False diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 1cc666fbf6..5e08ef1664 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random import time @@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple import attr from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IAgent, IResponse from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder @@ -81,11 +81,11 @@ class WellKnownResolver: def __init__( self, - reactor, - agent, - user_agent, - well_known_cache=None, - had_well_known_cache=None, + reactor: IReactorTime, + agent: IAgent, + user_agent: bytes, + well_known_cache: Optional[TTLCache] = None, + had_well_known_cache: Optional[TTLCache] = None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -127,7 +127,7 @@ class WellKnownResolver: with Measure(self._clock, "get_well_known"): result, cache_period = await self._fetch_well_known( server_name - ) # type: Tuple[Optional[bytes], float] + ) # type: Optional[bytes], float except _FetchWellKnownFailure as e: if prev_result and e.temporary: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 7e17cdb73e..4e27f93b7a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,8 +17,9 @@ import cgi import logging import random import sys -import urllib +import urllib.parse from io import BytesIO +from typing import Callable, Dict, List, Optional, Tuple, Union import attr import treq @@ -27,25 +28,27 @@ from prometheus_client import Counter from signedjson.sign import sign_json from zope.interface import implementer -from twisted.internet import defer, protocol +from twisted.internet import defer from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime from twisted.internet.task import _EPSILON, Cooperator -from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils from synapse.api.errors import ( - Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, - SynapseError, ) from synapse.http import QuieterFileBodyProducer -from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver +from synapse.http.client import ( + BlacklistingAgentWrapper, + IPBlacklistingResolver, + encode_query_args, + readBodyToFile, +) from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import ( @@ -54,6 +57,7 @@ from synapse.logging.opentracing import ( start_active_span, tags, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -76,47 +80,44 @@ MAXINT = sys.maxsize _next_id = 1 +QueryArgs = Dict[str, Union[str, List[str]]] + + @attr.s(slots=True, frozen=True) class MatrixFederationRequest: - method = attr.ib() + method = attr.ib(type=str) """HTTP method - :type: str """ - path = attr.ib() + path = attr.ib(type=str) """HTTP path - :type: str """ - destination = attr.ib() + destination = attr.ib(type=str) """The remote server to send the HTTP request to. - :type: str""" + """ - json = attr.ib(default=None) + json = attr.ib(default=None, type=Optional[JsonDict]) """JSON to send in the body. - :type: dict|None """ - json_callback = attr.ib(default=None) + json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]]) """A callback to generate the JSON. - :type: func|None """ - query = attr.ib(default=None) + query = attr.ib(default=None, type=Optional[dict]) """Query arguments. - :type: dict|None """ - txn_id = attr.ib(default=None) + txn_id = attr.ib(default=None, type=Optional[str]) """Unique ID for this request (for logging) - :type: str|None """ uri = attr.ib(init=False, type=bytes) """The URI of this request """ - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: global _next_id txn_id = "%s-O-%s" % (self.method, _next_id) _next_id = (_next_id + 1) % (MAXINT - 1) @@ -136,7 +137,7 @@ class MatrixFederationRequest: ) object.__setattr__(self, "uri", uri) - def get_json(self): + def get_json(self) -> Optional[JsonDict]: if self.json_callback: return self.json_callback() return self.json @@ -148,7 +149,7 @@ async def _handle_json_response( request: MatrixFederationRequest, response: IResponse, start_ms: int, -): +) -> JsonDict: """ Reads the JSON body of a response, with a timeout @@ -160,7 +161,7 @@ async def _handle_json_response( start_ms: Timestamp when request was made Returns: - dict: parsed JSON response + The parsed JSON response """ try: check_content_type_is_json(response.headers) @@ -250,9 +251,7 @@ class MatrixFederationHttpClient: # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, - self.reactor, - ip_blacklist=hs.config.federation_ip_range_blacklist, + self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() @@ -266,27 +265,29 @@ class MatrixFederationHttpClient: self._cooperator = Cooperator(scheduler=schedule) async def _send_request_with_optional_trailing_slash( - self, request, try_trailing_slash_on_400=False, **send_request_args - ): + self, + request: MatrixFederationRequest, + try_trailing_slash_on_400: bool = False, + **send_request_args + ) -> IResponse: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 due to #3622. Args: - request (MatrixFederationRequest): details of request to be sent - try_trailing_slash_on_400 (bool): Whether on receiving a 400 + request: details of request to be sent + try_trailing_slash_on_400: Whether on receiving a 400 'M_UNRECOGNIZED' from the server to retry the request with a trailing slash appended to the request path. - send_request_args (Dict): A dictionary of arguments to pass to - `_send_request()`. + send_request_args: A dictionary of arguments to pass to `_send_request()`. Raises: HttpResponseException: If we get an HTTP response code >= 300 (except 429). Returns: - Dict: Parsed JSON response body. + Parsed JSON response body. """ try: response = await self._send_request(request, **send_request_args) @@ -313,24 +314,26 @@ class MatrixFederationHttpClient: async def _send_request( self, - request, - retry_on_dns_fail=True, - timeout=None, - long_retries=False, - ignore_backoff=False, - backoff_on_404=False, - ): + request: MatrixFederationRequest, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + long_retries: bool = False, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + ) -> IResponse: """ Sends a request to the given server. Args: - request (MatrixFederationRequest): details of request to be sent + request: details of request to be sent + + retry_on_dns_fail: true if the request should be retied on DNS failures - timeout (int|None): number of milliseconds to wait for the response headers + timeout: number of milliseconds to wait for the response headers (including connecting to the server), *for each attempt*. 60s by default. - long_retries (bool): whether to use the long retry algorithm. + long_retries: whether to use the long retry algorithm. The regular retry algorithm makes 4 attempts, with intervals [0.5s, 1s, 2s]. @@ -346,14 +349,13 @@ class MatrixFederationHttpClient: NB: the long retry algorithm takes over 20 minutes to complete, with a default timeout of 60s! - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - backoff_on_404 (bool): Back off if we get a 404 + backoff_on_404: Back off if we get a 404 Returns: - twisted.web.client.Response: resolves with the HTTP - response object on success. + Resolves with the HTTP response object on success. Raises: HttpResponseException: If we get an HTTP response code >= 300 @@ -404,7 +406,7 @@ class MatrixFederationHttpClient: ) # Inject the span into the headers - headers_dict = {} + headers_dict = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -435,7 +437,7 @@ class MatrixFederationHttpClient: data = encode_canonical_json(json) producer = QuieterFileBodyProducer( BytesIO(data), cooperator=self._cooperator - ) + ) # type: Optional[IBodyProducer] else: producer = None auth_headers = self.build_auth_headers( @@ -524,14 +526,16 @@ class MatrixFederationHttpClient: ) body = None - e = HttpResponseException(response.code, response_phrase, body) + exc = HttpResponseException( + response.code, response_phrase, body + ) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException if response.code == 429: - raise RequestSendFailed(e, can_retry=True) from e + raise RequestSendFailed(exc, can_retry=True) from exc else: - raise e + raise exc break except RequestSendFailed as e: @@ -582,22 +586,27 @@ class MatrixFederationHttpClient: return response def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None - ): + self, + destination: Optional[bytes], + method: bytes, + url_bytes: bytes, + content: Optional[JsonDict] = None, + destination_is: Optional[bytes] = None, + ) -> List[bytes]: """ Builds the Authorization headers for a federation request Args: - destination (bytes|None): The destination homeserver of the request. + destination: The destination homeserver of the request. May be None if the destination is an identity server, in which case destination_is must be non-None. - method (bytes): The HTTP method of the request - url_bytes (bytes): The URI path of the request - content (object): The body of the request - destination_is (bytes): As 'destination', but if the destination is an + method: The HTTP method of the request + url_bytes: The URI path of the request + content: The body of the request + destination_is: As 'destination', but if the destination is an identity server Returns: - list[bytes]: a list of headers to be added as "Authorization:" headers + A list of headers to be added as "Authorization:" headers """ request = { "method": method.decode("ascii"), @@ -629,33 +638,32 @@ class MatrixFederationHttpClient: async def put_json( self, - destination, - path, - args={}, - data={}, - json_data_callback=None, - long_retries=False, - timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False, - ): + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + ) -> Union[JsonDict, list]: """ Sends the specified json data using PUT Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. - args (dict): query params - data (dict): A dict containing the data that will be used as + destination: The remote server to send the HTTP request to. + path: The HTTP path. + args: query params + data: A dict containing the data that will be used as the request body. This will be encoded as JSON. - json_data_callback (callable): A callable returning the dict to + json_data_callback: A callable returning the dict to use as the request body. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -663,19 +671,19 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - backoff_on_404 (bool): True if we should count a 404 response as + backoff_on_404: True if we should count a 404 response as a failure of the server (and should therefore back off future requests). - try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. This will be attempted before backing off if backing off has been enabled. Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -721,29 +729,28 @@ class MatrixFederationHttpClient: async def post_json( self, - destination, - path, - data={}, - long_retries=False, - timeout=None, - ignore_backoff=False, - args={}, - ): + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryArgs] = None, + ) -> Union[JsonDict, list]: """ Sends the specified json data using POST Args: - destination (str): The remote server to send the HTTP request - to. + destination: The remote server to send the HTTP request to. - path (str): The HTTP path. + path: The HTTP path. - data (dict): A dict containing the data that will be used as + data: A dict containing the data that will be used as the request body. This will be encoded as JSON. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -751,10 +758,10 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data and + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - args (dict): query params + args: query params Returns: dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -795,26 +802,25 @@ class MatrixFederationHttpClient: async def get_json( self, - destination, - path, - args=None, - retry_on_dns_fail=True, - timeout=None, - ignore_backoff=False, - try_trailing_slash_on_400=False, - ): + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + ) -> Union[JsonDict, list]: """ GETs some json from the given host homeserver and path Args: - destination (str): The remote server to send the HTTP request - to. + destination: The remote server to send the HTTP request to. - path (str): The HTTP path. + path: The HTTP path. - args (dict|None): A dictionary used to create query strings, defaults to + args: A dictionary used to create query strings, defaults to None. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -822,14 +828,14 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -870,24 +876,23 @@ class MatrixFederationHttpClient: async def delete_json( self, - destination, - path, - long_retries=False, - timeout=None, - ignore_backoff=False, - args={}, - ): + destination: str, + path: str, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryArgs] = None, + ) -> Union[JsonDict, list]: """Send a DELETE request to the remote expecting some json response Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. + destination: The remote server to send the HTTP request to. + path: The HTTP path. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -895,12 +900,12 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data and + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - args (dict): query params + args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -938,25 +943,25 @@ class MatrixFederationHttpClient: async def get_file( self, - destination, - path, + destination: str, + path: str, output_stream, - args={}, - retry_on_dns_fail=True, - max_size=None, - ignore_backoff=False, - ): + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + max_size: Optional[int] = None, + ignore_backoff: bool = False, + ) -> Tuple[int, Dict[bytes, List[bytes]]]: """GETs a file from a given homeserver Args: - destination (str): The remote server to send the HTTP request to. - path (str): The HTTP path to GET. - output_stream (file): File to write the response body to. - args (dict): Optional dictionary used to create the query string. - ignore_backoff (bool): true to ignore the historical backoff data + destination: The remote server to send the HTTP request to. + path: The HTTP path to GET. + output_stream: File to write the response body to. + args: Optional dictionary used to create the query string. + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - tuple[int, dict]: Resolves with an (int,dict) tuple of + Resolves with an (int,dict) tuple of the file length and a dict of the response headers. Raises: @@ -980,7 +985,7 @@ class MatrixFederationHttpClient: headers = dict(response.headers.getAllRawHeaders()) try: - d = _readBodyToFile(response, output_stream, max_size) + d = readBodyToFile(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) length = await make_deferred_yieldable(d) except Exception as e: @@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient: return (length, headers) -class _ReadBodyToFileProtocol(protocol.Protocol): - def __init__(self, stream, deferred, max_size): - self.stream = stream - self.deferred = deferred - self.length = 0 - self.max_size = max_size - - def dataReceived(self, data): - self.stream.write(data) - self.length += len(data) - if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback( - SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - ) - ) - self.deferred = defer.Deferred() - self.transport.loseConnection() - - def connectionLost(self, reason): - if reason.check(ResponseDone): - self.deferred.callback(self.length) - else: - self.deferred.errback(reason) - - -def _readBodyToFile(response, stream, max_size): - d = defer.Deferred() - response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) - return d - - def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( @@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e): return repr(e) -def check_content_type_is_json(headers): +def check_content_type_is_json(headers: Headers) -> None: """ Check that a set of HTTP headers have a Content-Type header, and that it is application/json. Args: - headers (twisted.web.http_headers.Headers): headers to check + headers: headers to check Raises: RequestSendFailed: if the Content-Type header is missing or isn't JSON @@ -1078,18 +1049,3 @@ def check_content_type_is_json(headers): ), can_retry=False, ) - - -def encode_query_args(args): - if args is None: - return b"" - - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, str): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.parse.urlencode(encoded_args, True) - - return query_bytes.encode("utf8") diff --git a/synapse/http/server.py b/synapse/http/server.py index c0919f8cb7..6a4e429a6c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,7 +25,7 @@ from io import BytesIO from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from canonicaljson import iterencode_canonical_json from zope.interface import implementer from twisted.internet import defer, interfaces @@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, - error_code, - error_dict, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), + request, error_code, error_dict, send_cors=True, ) @@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource): code, response_object, send_cors=True, - pretty_print=_request_user_agent_is_curl(request), canonical_json=self.canonical_json, ) @@ -587,7 +582,6 @@ def respond_with_json( code: int, json_object: Any, send_cors: bool = False, - pretty_print: bool = False, canonical_json: bool = True, ): """Sends encoded JSON in response to the given request. @@ -598,8 +592,6 @@ def respond_with_json( json_object: The object to serialize to JSON. send_cors: Whether to send Cross-Origin Resource Sharing headers https://fetch.spec.whatwg.org/#http-cors-protocol - pretty_print: Whether to include indentation and line-breaks in the - resulting JSON bytes. canonical_json: Whether to use the canonicaljson algorithm when encoding the JSON bytes. @@ -615,13 +607,10 @@ def respond_with_json( ) return None - if pretty_print: - encoder = iterencode_pretty_printed_json + if canonical_json: + encoder = iterencode_canonical_json else: - if canonical_json: - encoder = iterencode_canonical_json - else: - encoder = _encode_json_bytes + encoder = _encode_json_bytes request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -685,7 +674,7 @@ def set_cors_headers(request: Request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization", + b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date", ) @@ -759,11 +748,3 @@ def finish_request(request: Request): request.finish() except RuntimeError as e: logger.info("Connection disconnected before response was written: %r", e) - - -def _request_user_agent_is_curl(request: Request) -> bool: - user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) - for user_agent in user_agents: - if b"curl" in user_agent: - return True - return False diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 0142542852..72ab5750cc 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -49,6 +49,7 @@ class ModuleApi: self._store = hs.get_datastore() self._auth = hs.get_auth() self._auth_handler = auth_handler + self._server_name = hs.hostname # We expose these as properties below in order to attach a helpful docstring. self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient @@ -336,7 +337,9 @@ class ModuleApi: SynapseError if the event was not allowed. """ # Create a requester object - requester = create_requester(event_dict["sender"]) + requester = create_requester( + event_dict["sender"], authenticated_entity=self._server_name + ) # Create and send the event ( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 793d0db2d9..eff0975b6a 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -75,6 +75,7 @@ class HttpPusher: self.failing_since = pusherdict["failing_since"] self.timed_call = None self._is_processing = False + self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we @@ -136,7 +137,11 @@ class HttpPusher: async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count( + self.hs.get_datastore(), + self.user_id, + group_by_room=self._group_unread_count_by_room, + ) await self._send_badge(badge) def on_timer(self): @@ -283,7 +288,11 @@ class HttpPusher: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count( + self.hs.get_datastore(), + self.user_id, + group_by_room=self._group_unread_count_by_room, + ) event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..6e7c880dc0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage +from synapse.storage.databases.main import DataStore -async def get_badge_count(store, user_id): +async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) @@ -34,9 +34,15 @@ async def get_badge_count(store, user_id): room_id, user_id, last_unread_event_id ) ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + if notifs["notify_count"] == 0: + continue + + if group_by_room: + # return one badge count per conversation + badge += 1 + else: + # increment the badge count by the number of unread messages in the room + badge += notifs["notify_count"] return badge diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index aab77fc453..c97e0df1f5 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -40,6 +40,10 @@ logger = logging.getLogger(__name__) # Note that these both represent runtime dependencies (and the versions # installed are checked at runtime). # +# Also note that we replicate these constraints in the Synapse Dockerfile while +# pre-installing dependencies. If these constraints are updated here, the same +# change should be made in the Dockerfile. +# # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ @@ -69,14 +73,7 @@ REQUIREMENTS = [ "msgpack>=0.5.2", "phonenumbers>=8.2.0", # we use GaugeHistogramMetric, which was added in prom-client 0.4.0. - # prom-client has a history of breaking backwards compatibility between - # minor versions (https://github.com/prometheus/client_python/issues/317), - # so we also pin the minor version. - # - # Note that we replicate these constraints in the Synapse Dockerfile while - # pre-installing dependencies. If these constraints are updated here, the - # same change should be made in the Dockerfile. - "prometheus_client>=0.4.0,<0.9.0", + "prometheus_client>=0.4.0", # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 # is out in November.) @@ -99,7 +96,11 @@ CONDITIONAL_REQUIREMENTS = { # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 'eliot<1.8.0;python_version<"3.5.3"', ], - "saml2": ["pysaml2>=4.5.0"], + "saml2": [ + # pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749) + "pysaml2>=4.5.0,<6.4.0;python_version<'3.6'", + "pysaml2>=4.5.0;python_version>='3.6'", + ], "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index b4f4a68b5c..7a0dbb5b1a 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): return 200, {} -class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): +class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): """Called to clean up any data in DB for a given room, ready for the server to join the room. Request format: - POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id + POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id { "room_version": "1", } """ - NAME = "store_room_on_invite" + NAME = "store_room_on_outlier_membership" PATH_ARGS = ("room_id",) def __init__(self, hs): @@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): async def _handle_request(self, request, room_id): content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] - await self.store.maybe_store_room_on_invite(room_id, room_version) + await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} @@ -291,4 +291,4 @@ def register_servlets(hs, http_server): ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server) - ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) + ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index f0c37eaf5e..84e002f934 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.http import Request from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( - requester, room_id, user_id, remote_room_hosts, content - ): + async def _serialize_payload( # type: ignore + requester: Requester, + room_id: str, + user_id: str, + remote_room_hosts: List[str], + content: JsonDict, + ) -> JsonDict: """ Args: - requester(Requester) - room_id (str) - user_id (str) - remote_room_hosts (list[str]): Servers to try and join via - content(dict): The event content to use for the join event + requester: The user making the request according to the access token + room_id: The ID of the room. + user_id: The ID of the user. + remote_room_hosts: Servers to try and join via + content: The event content to use for the join event + + Returns: + A dict representing the payload of the request. """ return { "requester": requester.serialize(), @@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, room_id, user_id): + async def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): txn_id: Optional[str], requester: Requester, content: JsonDict, - ): + ) -> JsonDict: """ Args: - invite_event_id: ID of the invite to be rejected - txn_id: optional transaction ID supplied by the client - requester: user making the rejection request, according to the access token - content: additional content to include in the rejection event. + invite_event_id: The ID of the invite to be rejected. + txn_id: Optional transaction ID supplied by the client + requester: User making the rejection request, according to the access token + content: Additional content to include in the rejection event. Normally an empty dict. + + Returns: + A dict representing the payload of the request. """ return { "txn_id": txn_id, @@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, invite_event_id): + async def _handle_request( # type: ignore + self, request: Request, invite_event_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) txn_id = content["txn_id"] @@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - async def _serialize_payload(room_id, user_id, change): + async def _serialize_payload( # type: ignore + room_id: str, user_id: str, change: str + ) -> JsonDict: """ Args: - room_id (str) - user_id (str) - change (str): "left" + room_id: The ID of the room. + user_id: The ID of the user. + change: "left" + + Returns: + A dict representing the payload of the request. """ assert change == "left" return {} - def _handle_request(self, request, room_id, user_id, change): + def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str, change: str + ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) user = UserID.from_string(user_id) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 2a4f7a1740..55ddebb4fe 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -21,11 +21,7 @@ import synapse from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.admin._base import ( - admin_patterns, - assert_requester_is_admin, - historical_admin_path_patterns, -) +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, DeviceRestServlet, @@ -61,6 +57,7 @@ from synapse.rest.admin.users import ( UserRestServletV2, UsersRestServlet, UsersRestServletV2, + UserTokenRestServlet, WhoisRestServlet, ) from synapse.types import RoomStreamToken @@ -83,7 +80,7 @@ class VersionServlet(RestServlet): class PurgeHistoryRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns( + PATTERNS = admin_patterns( "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" ) @@ -168,9 +165,7 @@ class PurgeHistoryRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns( - "/purge_history_status/(?P<purge_id>[^/]+)" - ) + PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") def __init__(self, hs): """ @@ -223,6 +218,7 @@ def register_servlets(hs, http_server): UserAdminServlet(hs).register(http_server) UserMediaRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) + UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index db9fea263a..e09234c644 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -22,28 +22,6 @@ from synapse.api.errors import AuthError from synapse.types import UserID -def historical_admin_path_patterns(path_regex): - """Returns the list of patterns for an admin endpoint, including historical ones - - This is a backwards-compatibility hack. Previously, the Admin API was exposed at - various paths under /_matrix/client. This function returns a list of patterns - matching those paths (as well as the new one), so that existing scripts which rely - on the endpoints being available there are not broken. - - Note that this should only be used for existing endpoints: new ones should just - register for the /_synapse/admin path. - """ - return [ - re.compile(prefix + path_regex) - for prefix in ( - "^/_synapse/admin/v1", - "^/_matrix/client/api/v1/admin", - "^/_matrix/client/unstable/admin", - "^/_matrix/client/r0/admin", - ) - ] - - def admin_patterns(path_regex: str, version: str = "v1"): """Returns the list of patterns for an admin endpoint diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index 0b54ca09f4..d0c86b204a 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -16,10 +16,7 @@ import logging from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from synapse.rest.admin._base import ( - assert_user_is_admin, - historical_admin_path_patterns, -) +from synapse.rest.admin._base import admin_patterns, assert_user_is_admin logger = logging.getLogger(__name__) @@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups """ - PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") def __init__(self, hs): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index ba50cb876d..c82b4f87d6 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -22,7 +22,6 @@ from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, - historical_admin_path_patterns, ) logger = logging.getLogger(__name__) @@ -34,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = ( - historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine") + admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine") + # This path kept around for legacy reasons - historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") + admin_patterns("/quarantine_media/(?P<room_id>[^/]+)") ) def __init__(self, hs): @@ -63,9 +62,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = historical_admin_path_patterns( - "/user/(?P<user_id>[^/]+)/media/quarantine" - ) + PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine") def __init__(self, hs): self.store = hs.get_datastore() @@ -90,7 +87,7 @@ class QuarantineMediaByID(RestServlet): it via this server. """ - PATTERNS = historical_admin_path_patterns( + PATTERNS = admin_patterns( "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" ) @@ -116,7 +113,7 @@ class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ - PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media") + PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media") def __init__(self, hs): self.store = hs.get_datastore() @@ -134,7 +131,7 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/purge_media_cache") + PATTERNS = admin_patterns("/purge_media_cache") def __init__(self, hs): self.media_repository = hs.get_media_repository() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f5304ff43d..25f89e4685 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -29,7 +29,6 @@ from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, - historical_admin_path_patterns, ) from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester @@ -44,7 +43,7 @@ class ShutdownRoomRestServlet(RestServlet): joined to the new room. """ - PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") + PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)") def __init__(self, hs): self.hs = hs @@ -71,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet): class DeleteRoomRestServlet(RestServlet): - """Delete a room from server. It is a combination and improvement of - shut down and purge room. + """Delete a room from server. + + It is a combination and improvement of shutdown and purge room. + Shuts down a room by removing all local users from the room. Blocking all future invites and joins to the room is optional. + If desired any local aliases will be repointed to a new room - created by `new_room_user_id` and kicked users will be auto + created by `new_room_user_id` and kicked users will be auto- joined to the new room. - It will remove all trace of a room from the database. + + If 'purge' is true, it will remove all traces of a room from the database. """ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$") @@ -111,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -122,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet): # Purge room if purge: - await self.pagination_handler.purge_room(room_id) + await self.pagination_handler.purge_room(room_id, force=force_purge) return (200, ret) @@ -309,7 +320,9 @@ class JoinRoomAliasServlet(RestServlet): 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - fake_requester = create_requester(target_user) + fake_requester = create_requester( + target_user, authenticated_entity=requester.authenticated_entity + ) # send invite if room has "JoinRules.INVITE" room_state = await self.state_handler.get_current_state(room_id) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 3638e219f2..b0ff5e1ead 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -16,7 +16,7 @@ import hashlib import hmac import logging from http import HTTPStatus -from typing import Tuple +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -33,10 +33,13 @@ from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, - historical_admin_path_patterns, ) +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import JsonDict, UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) _GET_PUSHERS_ALLOWED_KEYS = { @@ -52,7 +55,7 @@ _GET_PUSHERS_ALLOWED_KEYS = { class UsersRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") def __init__(self, hs): self.hs = hs @@ -335,7 +338,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = historical_admin_path_patterns("/register") + PATTERNS = admin_patterns("/register") NONCE_TIMEOUT = 60 def __init__(self, hs): @@ -458,7 +461,14 @@ class UserRegisterServlet(RestServlet): class WhoisRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)") + path_regex = "/whois/(?P<user_id>[^/]*)$" + PATTERNS = ( + admin_patterns(path_regex) + + + # URL for spec reason + # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid + client_patterns("/admin" + path_regex, v1=True) + ) def __init__(self, hs): self.hs = hs @@ -482,7 +492,7 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -513,7 +523,7 @@ class DeactivateAccountRestServlet(RestServlet): class AccountValidityRenewServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/account_validity/validity$") + PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs): """ @@ -556,9 +566,7 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = historical_admin_path_patterns( - "/reset_password/(?P<target_user_id>[^/]*)" - ) + PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() @@ -600,7 +608,7 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.hs = hs @@ -828,3 +836,52 @@ class UserMediaRestServlet(RestServlet): ret["next_token"] = start + len(media) return 200, ret + + +class UserTokenRestServlet(RestServlet): + """An admin API for logging in as a user. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/login + {} + + 200 OK + { + "access_token": "<some_token>" + } + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + auth_user = requester.user + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be logged in as") + + body = parse_json_object_from_request(request, allow_empty_body=True) + + valid_until_ms = body.get("valid_until_ms") + if valid_until_ms and not isinstance(valid_until_ms, int): + raise SynapseError(400, "'valid_until_ms' parameter must be an int") + + if auth_user.to_string() == user_id: + raise SynapseError(400, "Cannot use admin API to login as self") + + token = await self.auth_handler.get_access_token_for_user_id( + user_id=auth_user.to_string(), + device_id=None, + valid_until_ms=valid_until_ms, + puppets_user_id=user_id, + ) + + return 200, {"access_token": token} diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 94452fcbf5..d7ae148214 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.handlers.auth import ( - convert_client_dict_legacy_fields_to_identifier, - login_id_phone_to_thirdparty, -) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) @@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet): rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, ) - self._failed_attempts_ratelimiter = Ratelimiter( - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - ) def on_GET(self, request: SynapseRequest): flows = [] @@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - def _get_qualified_user_id(self, identifier): - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - return identifier["user"] - else: - return UserID(identifier["user"], self.hs.hostname).to_string() - async def _do_appservice_login( self, login_submission: JsonDict, appservice: ApplicationService ): - logger.info( - "Got appservice login request with identifier: %r", - login_submission.get("identifier"), - ) + identifier = login_submission.get("identifier") + logger.info("Got appservice login request with identifier: %r", identifier) + + if not isinstance(identifier, dict): + raise SynapseError( + 400, "Invalid identifier in login submission", Codes.INVALID_PARAM + ) + + # this login flow only supports identifiers of type "m.id.user". + if identifier.get("type") != "m.id.user": + raise SynapseError( + 400, "Unknown login identifier type", Codes.INVALID_PARAM + ) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - qualified_user_id = self._get_qualified_user_id(identifier) + user = identifier.get("user") + if not isinstance(user, str): + raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) + + if user.startswith("@"): + qualified_user_id = user + else: + qualified_user_id = UserID(user, self.hs.hostname).to_string() if not appservice.is_interested_in_user(qualified_user_id): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) @@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet): login_submission.get("address"), login_submission.get("user"), ) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - - # convert phone type identifiers to generic threepids - if identifier["type"] == "m.id.phone": - identifier = login_id_phone_to_thirdparty(identifier) - - # convert threepid identifiers to user IDs - if identifier["type"] == "m.id.thirdparty": - address = identifier.get("address") - medium = identifier.get("medium") - - if medium is None or address is None: - raise SynapseError(400, "Invalid thirdparty identifier") - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - # We also apply account rate limiting using the 3PID as a key, as - # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) - - # Check for login providers that support 3pid login types - ( - canonical_user_id, - callback_3pid, - ) = await self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) - if canonical_user_id: - # Authentication through password provider and 3pid succeeded - - result = await self._complete_login( - canonical_user_id, login_submission, callback_3pid - ) - return result - - # No password providers were able to handle this 3pid - # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( - medium, address - ) - if not user_id: - logger.warning( - "unknown 3pid identifier medium %s, address %r", medium, address - ) - # We mark that we've failed to log in here, as - # `check_password_provider_3pid` might have returned `None` due - # to an incorrect password, rather than the account not - # existing. - # - # If it returned None but the 3PID was bound then we won't hit - # this code path, which is fine as then the per-user ratelimit - # will kick in below. - self._failed_attempts_ratelimiter.can_do_action((medium, address)) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - identifier = {"type": "m.id.user", "user": user_id} - - # by this point, the identifier should be an m.id.user: if it's anything - # else, we haven't understood it. - qualified_user_id = self._get_qualified_user_id(identifier) - - # Check if we've hit the failed ratelimit (but don't update it) - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + canonical_user_id, callback = await self.auth_handler.validate_login( + login_submission, ratelimit=True ) - - try: - canonical_user_id, callback = await self.auth_handler.validate_login( - identifier["user"], login_submission - ) - except LoginError: - # The user has failed to log in, so we need to update the rate - # limiter. Using `can_do_action` avoids us raising a ratelimit - # exception and masking the LoginError. The actual ratelimiting - # should have happened above. - self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) - raise - result = await self._complete_login( canonical_user_id, login_submission, callback ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 25d3cc6148..93c06afe27 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -18,7 +18,7 @@ import logging import re -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership @@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, from synapse.util import json_decoder from synapse.util.stringutils import random_string -MYPY = False -if MYPY: +if TYPE_CHECKING: import synapse.server logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a54e1011f7..eebee44a44 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ea68114026..a89ae6ddf9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 2b84eb89c0..8e52e4cca4 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet): ) with context: sync_result = await self.sync_handler.wait_for_sync_for_user( + requester, sync_config, since_token=since_token, timeout=timeout, diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index c16280f668..d8e8e48c1c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -66,7 +66,7 @@ class LocalKey(Resource): def __init__(self, hs): self.config = hs.config - self.clock = hs.clock + self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) Resource.__init__(self) diff --git a/synapse/server.py b/synapse/server.py index 21a232bbd9..b017e3489f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -27,7 +27,8 @@ import logging import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast -import twisted +import twisted.internet.base +import twisted.internet.tcp from twisted.mail.smtp import sendmail from twisted.web.iweb import IPolicyForHTTPS @@ -89,6 +90,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler @@ -145,7 +147,8 @@ def cache_in_self(builder: T) -> T: "@cache_in_self can only be used on functions starting with `get_`" ) - depname = builder.__name__[len("get_") :] + # get_attr -> _attr + depname = builder.__name__[len("get") :] building = [False] @@ -233,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta): self._instance_id = random_string(5) self._instance_name = config.worker_name or "master" - self.clock = Clock(reactor) - self.distributor = Distributor() - - self.registration_ratelimiter = Ratelimiter( - clock=self.clock, - rate_hz=config.rc_registration.per_second, - burst_count=config.rc_registration.burst_count, - ) - self.version_string = version_string self.datastores = None # type: Optional[Databases] @@ -299,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta): def is_mine_id(self, string: str) -> bool: return string.split(":", 1)[1] == self.hostname + @cache_in_self def get_clock(self) -> Clock: - return self.clock + return Clock(self._reactor) def get_datastore(self) -> DataStore: if not self.datastores: @@ -317,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta): def get_config(self) -> HomeServerConfig: return self.config + @cache_in_self def get_distributor(self) -> Distributor: - return self.distributor + return Distributor() + @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: - return self.registration_ratelimiter + return Ratelimiter( + clock=self.get_clock(), + rate_hz=self.config.rc_registration.per_second, + burst_count=self.config.rc_registration.burst_count, + ) @cache_in_self def get_federation_client(self) -> FederationClient: @@ -391,6 +392,10 @@ class HomeServer(metaclass=abc.ABCMeta): return FollowerTypingHandler(self) @cache_in_self + def get_sso_handler(self) -> SsoHandler: + return SsoHandler(self) + + @cache_in_self def get_sync_handler(self) -> SyncHandler: return SyncHandler(self) @@ -681,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: - return FederationRateLimiter(self.clock, config=self.config.rc_federation) + return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation) @cache_in_self def get_module_api(self) -> ModuleApi: diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index d464c75c03..100dbd5e2c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -39,6 +39,7 @@ class ServerNoticesManager: self._room_member_handler = hs.get_room_member_handler() self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id + self._server_name = hs.hostname self._notifier = hs.get_notifier() self.server_notices_mxid = self._config.server_notices_mxid @@ -72,7 +73,9 @@ class ServerNoticesManager: await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid - requester = create_requester(system_mxid) + requester = create_requester( + system_mxid, authenticated_entity=self._server_name + ) logger.info("Sending server notice to %s", user_id) @@ -145,7 +148,9 @@ class ServerNoticesManager: "avatar_url": self._config.server_notices_mxid_avatar_url, } - requester = create_requester(self.server_notices_mxid) + requester = create_requester( + self.server_notices_mxid, authenticated_entity=self._server_name + ) info, _ = await self._room_creation_handler.create_room( requester, config={ @@ -174,7 +179,9 @@ class ServerNoticesManager: user_id: The ID of the user to invite. room_id: The ID of the room to invite the user to. """ - requester = create_requester(self.server_notices_mxid) + requester = create_requester( + self.server_notices_mxid, authenticated_entity=self._server_name + ) # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index ecfc6717b3..5d668aadb2 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): for table in ( "event_auth", "event_edges", + "event_json", "event_push_actions_staging", "event_reference_hashes", "event_relations", @@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "destination_rooms", "event_backward_extremities", "event_forward_extremities", - "event_json", "event_push_actions", "event_search", "events", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ca7917c989..1e7949a323 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -278,7 +278,8 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None ) -> Dict[str, JsonDict]: - """Get receipts for all rooms between two stream_ids. + """Get receipts for all rooms between two stream_ids, up + to a limit of the latest 100 read receipts. Args: to_key: Max stream id to fetch receipts upto. @@ -294,12 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): sql = """ SELECT * FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id DESC + LIMIT 100 """ txn.execute(sql, [from_key, to_key]) else: sql = """ SELECT * FROM receipts_linearized WHERE stream_id <= ? + ORDER BY stream_id DESC + LIMIT 100 """ txn.execute(sql, [to_key]) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e5d07ce72a..fedb8a6c26 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1110,6 +1110,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): token: str, device_id: Optional[str], valid_until_ms: Optional[int], + puppets_user_id: Optional[str] = None, ) -> int: """Adds an access token for the given user. @@ -1133,6 +1134,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "token": token, "device_id": device_id, "valid_until_ms": valid_until_ms, + "puppets_user_id": puppets_user_id, }, desc="add_access_token_to_user", ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index dc0c4b5499..6b89db15c9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1240,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + async def maybe_store_room_on_outlier_membership( + self, room_id: str, room_version: RoomVersion + ): """ - When we receive an invite over federation, store the version of the room if we - don't already know the room version. + When we receive an invite or any other event over federation that may relate to a room + we are not in, store the version of the room if we don't already know the room version. """ await self.db_pool.simple_upsert( - desc="maybe_store_room_on_invite", + desc="maybe_store_room_on_outlier_membership", table="rooms", keyvalues={"room_id": room_id}, values={}, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 01d9dbb36f..dcdaf09682 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -350,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results + async def get_local_current_membership_for_user_in_room( + self, user_id: str, room_id: str + ) -> Tuple[Optional[str], Optional[str]]: + """Retrieve the current local membership state and event ID for a user in a room. + + Args: + user_id: The ID of the user. + room_id: The ID of the room. + + Returns: + A tuple of (membership_type, event_id). Both will be None if a + room_id/user_id pair is not found. + """ + # Paranoia check. + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_local_current_membership_for_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + results_dict = await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ) + if not results_dict: + return None, None + + return results_dict.get("membership"), results_dict.get("event_id") + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( self, user_id: str diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres index b64926e9c9..3275ae2b20 100644 --- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres @@ -20,14 +20,14 @@ */ -- add new index that includes method to local media -INSERT INTO background_updates (update_name, progress_json) VALUES - ('local_media_repository_thumbnails_method_idx', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5807, 'local_media_repository_thumbnails_method_idx', '{}'); -- add new index that includes method to remote media -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); -- drop old index -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql index cade5dcca8..fd733adf13 100644 --- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql +++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql @@ -28,5 +28,5 @@ -- functionality as the old one. This effectively restarts the background job -- from the beginning, without running it twice in a row, supporting both -- upgrade usecases. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_process_rooms_2', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5812, 'populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql index a2842687f1..e1a35be831 100644 --- a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql +++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql @@ -1,2 +1,2 @@ -INSERT INTO background_updates (update_name, progress_json) VALUES - ('users_have_local_media', '{}'); \ No newline at end of file +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5822, 'users_have_local_media', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql index 61c558db77..75c3915a94 100644 --- a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql +++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql @@ -13,5 +13,5 @@ * limitations under the License. */ -INSERT INTO background_updates (update_name, progress_json) VALUES - ('e2e_cross_signing_keys_idx', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5823, 'e2e_cross_signing_keys_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql new file mode 100644 index 0000000000..8a39d54aed --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this index is essentially redundant. The only time it was ever used was when purging +-- rooms - and Synapse 1.24 will change that. + +DROP INDEX IF EXISTS event_json_room_id; diff --git a/synapse/types.py b/synapse/types.py index 66bb5bac8d..3ab6bdbe06 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -317,14 +317,14 @@ mxid_localpart_allowed_characters = set( ) -def contains_invalid_mxid_characters(localpart): +def contains_invalid_mxid_characters(localpart: str) -> bool: """Check for characters not allowed in an mxid or groupid localpart Args: - localpart (basestring): the localpart to be checked + localpart: the localpart to be checked Returns: - bool: True if there are any naughty characters + True if there are any naughty characters """ return any(c not in mxid_localpart_allowed_characters for c in localpart) |