From d0a43d431ec3ab7603827d26924b3708a5bc33b2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 12 Jun 2020 14:12:04 -0400 Subject: Fix a typo when comparing the URI & method during UI Auth. (#7689) --- synapse/handlers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 119678e67b..b01124fe42 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -297,7 +297,7 @@ class AuthHandler(BaseHandler): # Convert the URI and method to strings. uri = request.uri.decode("utf-8") - method = request.uri.decode("utf-8") + method = request.method.decode("utf-8") # If there's no session ID, create a new session. if not sid: -- cgit 1.5.1 From ea26e9a98b0541fc886a1cb826a38352b7599dbe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Jul 2020 09:10:23 -0400 Subject: Ensure that HTML pages served from Synapse include headers to avoid embedding. --- synapse/app/homeserver.py | 3 +- synapse/handlers/auth.py | 30 +++------- synapse/handlers/oidc_handler.py | 13 ++-- synapse/http/server.py | 76 +++++++++++++++++++++--- synapse/rest/client/v1/pusher.py | 10 +--- synapse/rest/client/v2_alpha/account.py | 16 +++-- synapse/rest/client/v2_alpha/account_validity.py | 11 +--- synapse/rest/client/v2_alpha/auth.py | 18 +----- synapse/rest/client/v2_alpha/register.py | 10 ++-- synapse/rest/consent/consent_resource.py | 10 +--- 10 files changed, 103 insertions(+), 94 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8454d74858..41994dc14b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -56,6 +56,7 @@ from synapse.http.server import ( OptionsResource, RootOptionsRedirectResource, RootRedirect, + StaticResource, ) from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext @@ -228,7 +229,7 @@ class SynapseHomeServer(HomeServer): if name in ["static", "client"]: resources.update( { - STATIC_PREFIX: File( + STATIC_PREFIX: StaticResource( os.path.join(os.path.dirname(synapse.__file__), "static") ) } diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 119678e67b..bb3b43d5ae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -38,7 +38,7 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.http.server import finish_request +from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process @@ -1055,13 +1055,8 @@ class AuthHandler(BaseHandler): ) # Render the HTML and return. - html_bytes = self._sso_auth_success_template.encode("utf-8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) + html = self._sso_auth_success_template + respond_with_html(request, 200, html) async def complete_sso_login( self, @@ -1081,13 +1076,7 @@ class AuthHandler(BaseHandler): # flow. deactivated = await self.store.get_user_deactivated_status(registered_user_id) if deactivated: - html_bytes = self._sso_account_deactivated_template.encode("utf-8") - - request.setResponseCode(403) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - request.write(html_bytes) - finish_request(request) + respond_with_html(request, 403, self._sso_account_deactivated_template) return self._complete_sso_login(registered_user_id, request, client_redirect_url) @@ -1128,17 +1117,12 @@ class AuthHandler(BaseHandler): # URL we redirect users to. redirect_url_no_params = client_redirect_url.split("?")[0] - html_bytes = self._sso_redirect_confirm_template.render( + html = self._sso_redirect_confirm_template.render( display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, - ).encode("utf-8") - - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - request.write(html_bytes) - finish_request(request) + ) + respond_with_html(request, 200, html) @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 9c08eb5399..87f0c5e197 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -35,7 +35,7 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.push.mailer import load_jinja2_templates @@ -144,15 +144,10 @@ class OidcHandler: access_denied. error_description: A human-readable description of the error. """ - html_bytes = self._error_template.render( + html = self._error_template.render( error=error, error_description=error_description - ).encode("utf-8") - - request.setResponseCode(400) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) - request.write(html_bytes) - finish_request(request) + ) + respond_with_html(request, 400, html) def _validate_metadata(self): """Verifies the provider metadata. diff --git a/synapse/http/server.py b/synapse/http/server.py index 2487a72171..2331a2a4b0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -30,7 +30,7 @@ from twisted.internet import defer from twisted.python import failure from twisted.web import resource from twisted.web.server import NOT_DONE_YET, Request -from twisted.web.static import NoRangeStaticProducer +from twisted.web.static import File, NoRangeStaticProducer from twisted.web.util import redirectTo import synapse.events @@ -202,12 +202,7 @@ def return_html_error( else: body = error_template.render(code=code, msg=msg) - body_bytes = body.encode("utf-8") - request.setResponseCode(code) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),)) - request.write(body_bytes) - finish_request(request) + respond_with_html(request, code, body) def wrap_async_request_handler(h): @@ -420,6 +415,18 @@ class DirectServeResource(resource.Resource): return NOT_DONE_YET +class StaticResource(File): + """ + A resource that represents a plain non-interpreted file or directory. + + Differs from the File resource by adding clickjacking protection. + """ + + def render_GET(self, request: Request): + set_clickjacking_protection_headers(request) + return super().render_GET(request) + + def _options_handler(request): """Request handler for OPTIONS requests @@ -530,7 +537,7 @@ def respond_with_json_bytes( code (int): The HTTP response code. json_bytes (bytes): The json bytes to use as the response body. send_cors (bool): Whether to send Cross-Origin Resource Sharing headers - http://www.w3.org/TR/cors/ + https://fetch.spec.whatwg.org/#http-cors-protocol Returns: twisted.web.server.NOT_DONE_YET""" @@ -568,6 +575,59 @@ def set_cors_headers(request): ) +def respond_with_html(request: Request, code: int, html: str): + """ + Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. + """ + respond_with_html_bytes(request, code, html.encode("utf-8")) + + +def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): + """ + Sends HTML (encoded as UTF-8 bytes) as the response to the given request. + + Note that this adds clickjacking protection headers and finishes the request. + + Args: + request: The http request to respond to. + code: The HTTP response code. + html_bytes: The HTML bytes to use as the response body. + """ + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + request.setResponseCode(code) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + # Ensure this content cannot be embedded. + set_clickjacking_protection_headers(request) + + request.write(html_bytes) + finish_request(request) + + +def set_clickjacking_protection_headers(request: Request): + """ + Set headers to guard against clickjacking of embedded content. + + This sets the X-Frame-Options and Content-Security-Policy headers which instructs + browsers to not allow the HTML of the response to be embedded onto another + page. + + Args: + request: The http request to add the headers to. + """ + request.setHeader(b"X-Frame-Options", b"DENY") + request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") + + def finish_request(request): """ Finish writing the response to the request. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 550a2f1b44..5f65cb7d83 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -16,7 +16,7 @@ import logging from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html_bytes from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -177,13 +177,9 @@ class PushersRemoveRestServlet(RestServlet): self.notifier.on_new_replication_data() - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + respond_with_html_bytes( + request, 200, PushersRemoveRestServlet.SUCCESS_HTML, ) - request.write(PushersRemoveRestServlet.SUCCESS_HTML) - finish_request(request) return None def on_OPTIONS(self, _): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1dc4a3247f..b58a77826f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -21,7 +21,7 @@ from six.moves import http_client from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour -from synapse.http.server import finish_request +from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -199,16 +199,15 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Otherwise show the success template html = self.config.email_password_reset_template_success_html - request.setResponseCode(200) + status_code = 200 except ThreepidValidationError as e: - request.setResponseCode(e.code) + status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} html = self.failure_email_template.render(**template_vars) - request.write(html.encode("utf-8")) - finish_request(request) + respond_with_html(request, status_code, html) class PasswordRestServlet(RestServlet): @@ -571,16 +570,15 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Otherwise show the success template html = self.config.email_add_threepid_template_success_html_content - request.setResponseCode(200) + status_code = 200 except ThreepidValidationError as e: - request.setResponseCode(e.code) + status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} html = self.failure_email_template.render(**template_vars) - request.write(html.encode("utf-8")) - finish_request(request) + respond_with_html(request, status_code, html) class AddThreepidMsisdnSubmitTokenServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2f10fa64e2..d06336ceea 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -16,7 +16,7 @@ import logging from synapse.api.errors import AuthError, SynapseError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet from ._base import client_patterns @@ -26,9 +26,6 @@ logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - SUCCESS_HTML = ( - b"Your account has been successfully renewed." - ) def __init__(self, hs): """ @@ -59,11 +56,7 @@ class AccountValidityRenewServlet(RestServlet): status_code = 404 response = self.failure_html - request.setResponseCode(status_code) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(response),)) - request.write(response.encode("utf8")) - finish_request(request) + respond_with_html(request, status_code, response) class AccountValiditySendMailServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 75590ebaeb..8e585e9153 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,7 +18,7 @@ import logging from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string from ._base import client_patterns @@ -200,13 +200,7 @@ class AuthRestServlet(RestServlet): raise SynapseError(404, "Unknown auth stage type") # Render the HTML and return. - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) + respond_with_html(request, 200, html) return None async def on_POST(self, request, stagetype): @@ -263,13 +257,7 @@ class AuthRestServlet(RestServlet): raise SynapseError(404, "Unknown auth stage type") # Render the HTML and return. - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) + respond_with_html(request, 200, html) return None def on_OPTIONS(self, _): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b9ffe86b2a..c8d2de7b54 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -38,7 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler -from synapse.http.server import finish_request +from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -306,17 +306,15 @@ class RegistrationSubmitTokenServlet(RestServlet): # Otherwise show the success template html = self.config.email_registration_template_success_html_content - - request.setResponseCode(200) + status_code = 200 except ThreepidValidationError as e: - request.setResponseCode(e.code) + status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} html = self.failure_email_template.render(**template_vars) - request.write(html.encode("utf-8")) - finish_request(request) + respond_with_html(request, status_code, html) class UsernameAvailabilityRestServlet(RestServlet): diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 1ddf9997ff..4a20282d1b 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -29,7 +29,7 @@ from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError from synapse.http.server import ( DirectServeResource, - finish_request, + respond_with_html, wrap_html_request_handler, ) from synapse.http.servlet import parse_string @@ -197,12 +197,8 @@ class ConsentResource(DirectServeResource): template_html = self._jinja_env.get_template( path.join(TEMPLATE_LANGUAGE, template_name) ) - html_bytes = template_html.render(**template_args).encode("utf8") - - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) - request.write(html_bytes) - finish_request(request) + html = template_html.render(**template_args) + respond_with_html(request, 200, html) def _check_hash(self, userid, userhmac): """ -- cgit 1.5.1 From 21a212f8e50343e9b55944fa75ece7911fd2cb70 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 3 Jul 2020 15:03:13 +0200 Subject: Fix inconsistent handling of upper and lower cases of email addresses. (#7021) fixes #7016 --- changelog.d/7021.bugfix | 1 + synapse/handlers/auth.py | 5 +- synapse/rest/client/v1/login.py | 12 +- synapse/rest/client/v2_alpha/account.py | 40 +++++-- synapse/rest/client/v2_alpha/register.py | 22 +++- synapse/util/threepids.py | 23 ++++ tests/rest/client/v2_alpha/test_account.py | 175 ++++++++++++++++++++++++----- tests/util/test_threepids.py | 49 ++++++++ 8 files changed, 279 insertions(+), 48 deletions(-) create mode 100644 changelog.d/7021.bugfix create mode 100644 tests/util/test_threepids.py (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7021.bugfix b/changelog.d/7021.bugfix new file mode 100644 index 0000000000..140fe37b2d --- /dev/null +++ b/changelog.d/7021.bugfix @@ -0,0 +1 @@ +Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c3f86e7414..d713a06bf9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -45,6 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID +from synapse.util.threepids import canonicalise_email from ._base import BaseHandler @@ -928,7 +929,7 @@ class AuthHandler(BaseHandler): # for the presence of an email address during password reset was # case sensitive). if medium == "email": - address = address.lower() + address = canonicalise_email(address) await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() @@ -956,7 +957,7 @@ class AuthHandler(BaseHandler): # 'Canonicalise' email addresses as per above if medium == "email": - address = address.lower() + address = canonicalise_email(address) identity_handler = self.hs.get_handlers().identity_handler result = await identity_handler.try_unbind_threepid( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index bf0f9bd077..f6eef7afee 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -28,6 +28,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) @@ -206,11 +207,14 @@ class LoginRestServlet(RestServlet): if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) if medium == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - address = address.lower() + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 182a308eef..3767a809a4 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -30,7 +30,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed +from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -83,7 +83,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. This allows the user to reset his password without having to + # know the exact spelling (eg. upper and lower case) of address in the database. + # Stored in the database "foo@bar.com" + # User requests with "FOO@bar.com" would raise a Not Found error + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -94,6 +102,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # The email will be sent to the stored address. + # This avoids a potential account hijack by requesting a password reset to + # an email address which is controlled by the attacker but which, after + # canonicalisation, matches the one in our database. existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -274,10 +286,13 @@ class PasswordRestServlet(RestServlet): if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") if threepid["medium"] == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid["address"] = threepid["address"].lower() + try: + threepid["address"] = canonicalise_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) # if using email, we must know about the email they're authing with! threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] @@ -392,7 +407,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. + # This ensures that the validation email is sent to the canonicalised address + # as it will later be entered into the database. + # Otherwise the email will be sent to "FOO@bar.com" and stored as + # "foo@bar.com" in database. + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -403,9 +427,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = await self.store.get_user_id_by_threepid( - "email", body["email"] - ) + existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: if self.config.request_token_inhibit_3pid_errors: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 56a451c42f..370742ce59 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -47,7 +47,7 @@ from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed +from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -116,7 +116,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): client_secret = body["client_secret"] assert_valid_client_secret(client_secret) - email = body["email"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/v2_alpha/account.py) + try: + email = canonicalise_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -128,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", body["email"] + "email", email ) if existing_user_id is not None: @@ -552,6 +559,15 @@ class RegisterRestServlet(RestServlet): if login_type in auth_result: medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/v2_alpha/account.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) existing_user_id = await self.store.get_user_id_by_threepid( medium, address diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 3ec1dfb0c2..43c2e0ac23 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -48,3 +48,26 @@ def check_3pid_allowed(hs, medium, address): return True return False + + +def canonicalise_email(address: str) -> str: + """'Canonicalise' email address + Case folding of local part of email address and lowercase domain part + See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265 + + Args: + address: email address to be canonicalised + Returns: + The canonical form of the email address + Raises: + ValueError if the address could not be parsed. + """ + + address = address.strip() + + parts = address.split("@") + if len(parts) != 2: + logger.debug("Couldn't parse email address %s", address) + raise ValueError("Unable to parse email address") + + return parts[0].casefold() + "@" + parts[1].lower() diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 3ab611f618..152a5182fa 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + def test_basic_password_reset_canonicalise_email(self): + """Test basic password reset flow + Request password reset with different spelling + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email_profile = "test@example.com" + email_passwort_reset = "TEST@EXAMPLE.COM" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email_profile, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email_passwort_reset, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + def test_cant_reset_password_without_clicking_link(self): """Test that we do actually need to click the link in the email """ @@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_email(self): - """Test adding an email to profile - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) + def test_add_valid_email(self): + self.get_success(self._add_email(self.email, self.email)) - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() + def test_add_valid_email_second_time(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) - self._validate_token(link) + def test_add_valid_email_second_time_canonicalise(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, + def test_add_email_no_at(self): + self.get_success( + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def test_add_email_two_at(self): + self.get_success( + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + def test_add_email_bad_format(self): + self.get_success( + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + def test_add_email_domain_to_lower(self): + self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + + def test_add_email_domain_with_umlaut(self): + self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + + def test_add_email_address_casefold(self): + self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + + def test_address_trim(self): + self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed @@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] + def _request_token_invalid_email( + self, email, expected_errcode, expected_error, client_secret="foobar", + ): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(expected_errcode, channel.json_body["errcode"]) + self.assertEqual(expected_error, channel.json_body["error"]) + def _validate_token(self, link): # Remove the host path = link.replace("https://example.com", "") @@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): assert match, "Could not find link in email" return match.group(0) + + def _add_email(self, request_email, expected_email): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(request_email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py new file mode 100644 index 0000000000..5513724d87 --- /dev/null +++ b/tests/util/test_threepids.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.threepids import canonicalise_email + +from tests.unittest import HomeserverTestCase + + +class CanonicaliseEmailTests(HomeserverTestCase): + def test_no_at(self): + with self.assertRaises(ValueError): + canonicalise_email("address-without-at.bar") + + def test_two_at(self): + with self.assertRaises(ValueError): + canonicalise_email("foo@foo@test.bar") + + def test_bad_format(self): + with self.assertRaises(ValueError): + canonicalise_email("user@bad.example.net@good.example.com") + + def test_valid_format(self): + self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") + + def test_domain_to_lower(self): + self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") + + def test_domain_with_umlaut(self): + self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") + + def test_address_casefold(self): + self.assertEqual( + canonicalise_email("Strauß@Example.com"), "strauss@example.com" + ) + + def test_address_trim(self): + self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") -- cgit 1.5.1 From 62b1ce85398f52e7d6137e77083294d0c90af459 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sun, 5 Jul 2020 16:32:02 +0100 Subject: isort 5 compatibility (#7786) The CI appears to use the latest version of isort, which is a problem when isort gets a major version bump. Rather than try to pin the version, I've done the necessary to make isort5 happy with synapse. --- changelog.d/7786.misc | 1 + scripts-dev/check_signature.py | 2 +- scripts-dev/lint.sh | 2 +- setup.cfg | 1 - synapse/api/auth.py | 3 +-- synapse/config/__main__.py | 1 + synapse/config/emailconfig.py | 3 +-- synapse/handlers/auth.py | 3 +-- synapse/handlers/cas_handler.py | 3 +-- synapse/logging/opentracing.py | 4 ++-- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 4 ++-- synapse/replication/tcp/streams/events.py | 2 -- synapse/rest/media/v1/thumbnailer.py | 3 +-- synapse/secrets.py | 3 +-- synapse/storage/data_stores/main/events.py | 3 +-- synapse/storage/data_stores/main/ui_auth.py | 2 +- synapse/storage/types.py | 2 -- synapse/types.py | 2 +- tests/handlers/test_e2e_keys.py | 4 +--- tests/rest/media/v1/test_media_storage.py | 4 +--- tests/test_utils/event_injection.py | 2 -- tox.ini | 4 ++-- 23 files changed, 22 insertions(+), 38 deletions(-) create mode 100644 changelog.d/7786.misc (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7786.misc b/changelog.d/7786.misc new file mode 100644 index 0000000000..27af2681dc --- /dev/null +++ b/changelog.d/7786.misc @@ -0,0 +1 @@ +Update linting scripts and codebase to be compatible with `isort` v5. diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index ecda103cf7..6755bc5282 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -2,9 +2,9 @@ import argparse import json import logging import sys -import urllib2 import dns.resolver +import urllib2 from signedjson.key import decode_verify_key_bytes, write_signing_keys from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 6f1ba22931..66b0568858 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -15,7 +15,7 @@ else fi echo "Linting these locations: $files" -isort -y -rc $files +isort $files python3 -m black $files ./scripts-dev/config-lint.sh flake8 $files diff --git a/setup.cfg b/setup.cfg index f2bca272e1..a32278ea8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,6 @@ ignore=W503,W504,E203,E731,E501 [isort] line_length = 88 -not_skip = __init__.py sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY known_first_party = synapse diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 06ba6604f3..cb22508f4d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Optional @@ -22,7 +21,6 @@ from netaddr import IPAddress from twisted.internet import defer from twisted.web.server import Request -import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth from synapse.api.auth_blocking import AuthBlocking @@ -35,6 +33,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index fca35b008c..65043d5b5b 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -16,6 +16,7 @@ from synapse.config._base import ConfigError if __name__ == "__main__": import sys + from synapse.config.homeserver import HomeServerConfig action = sys.argv[1] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index ca61214454..df08bcd1bc 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import print_function # This file can't be called email.py because if it is, we cannot: @@ -145,8 +144,8 @@ class EmailConfig(Config): or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL ): # make sure we can import the required deps - import jinja2 import bleach + import jinja2 # prevent unused warnings jinja2 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d713a06bf9..a162392e4c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import time import unicodedata @@ -24,7 +23,6 @@ import attr import bcrypt # type: ignore[import] import pymacaroons -import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( AuthError, @@ -45,6 +43,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID +from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email from ._base import BaseHandler diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 76f213723a..d79ffefdb5 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import urllib -import xml.etree.ElementTree as ET from typing import Dict, Optional, Tuple +from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 1676771ef0..c6c0e623c1 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -164,7 +164,6 @@ Gotchas than one caller? Will all of those calling functions have be in a context with an active span? """ - import contextlib import inspect import logging @@ -180,8 +179,8 @@ from twisted.internet import defer from synapse.config import ConfigError if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.http.site import SynapseRequest + from synapse.server import HomeServer # Helper class @@ -227,6 +226,7 @@ except ImportError: tags = _DummyTagNames try: from jaeger_client import Config as JaegerConfig + from synapse.logging.scopecontextmanager import LogContextScopeManager except ImportError: JaegerConfig = None # type: ignore diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index df29732f51..4985e40b1f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e6a2e2598b..55b3b79008 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar @@ -149,10 +148,11 @@ class ReplicationCommandHandler: using TCP. """ if hs.config.redis.redis_enabled: + import txredisapi + from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) - import txredisapi logger.info( "Connecting to redis (host=%r port=%r)", diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f370390331..bdddb62ad6 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import heapq from collections import Iterable from typing import List, Tuple, Type @@ -22,7 +21,6 @@ import attr from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance - """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index c234ea7421..7126997134 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from io import BytesIO -import PIL.Image as Image +from PIL import Image as Image logger = logging.getLogger(__name__) diff --git a/synapse/secrets.py b/synapse/secrets.py index 0b327a0f82..5f43f81eb0 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -19,7 +19,6 @@ Injectable secrets module for Synapse. See https://docs.python.org/3/library/secrets.html#module-secrets for the API used in Python 3.6, and the API emulated in Python 2.7. """ - import sys # secrets is available since python 3.6 @@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6): else: - import os import binascii + import os class Secrets(object): def token_bytes(self, nbytes=32): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index cfd24d2f06..b7bf3fbd9d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from collections import OrderedDict, namedtuple @@ -48,8 +47,8 @@ from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: - from synapse.storage.data_stores.main import DataStore from synapse.server import HomeServer + from synapse.storage.data_stores.main import DataStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index ec2f38c373..4c044b1a15 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union import attr -import synapse.util.stringutils as stringutils from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.types import JsonDict +from synapse.util import stringutils as stringutils @attr.s diff --git a/synapse/storage/types.py b/synapse/storage/types.py index daff81c5ee..2d2b560e74 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,12 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Iterable, Iterator, List, Tuple from typing_extensions import Protocol - """ Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ diff --git a/synapse/types.py b/synapse/types.py index acf60baddc..238b938064 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container + from typing import Container, Iterable, Sized T_co = TypeVar("T_co", covariant=True) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6c1dc72bd1..1acf287ca4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,11 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import mock -import signedjson.key as key -import signedjson.sign as sign +from signedjson import key as key, sign as sign from twisted.internet import defer diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2ed9312d56..66fa5978b2 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import os import shutil import tempfile @@ -25,8 +23,8 @@ from urllib import parse from mock import Mock import attr -import PIL.Image as Image from parameterized import parameterized_class +from PIL import Image as Image from twisted.internet.defer import Deferred diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 431e9f8e5e..43297b530c 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional, Tuple import synapse.server @@ -25,7 +24,6 @@ from synapse.types import Collection from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. """ diff --git a/tox.ini b/tox.ini index ab6557f15e..1c042cb227 100644 --- a/tox.ini +++ b/tox.ini @@ -131,8 +131,8 @@ commands = [testenv:check_isort] skip_install = True -deps = isort -commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts" +deps = isort==5.0.3 +commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts" [testenv:check-newsfragment] skip_install = True -- cgit 1.5.1 From 83434df3812650f53c60e91fb23c2079db0fb5b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Jul 2020 15:45:39 -0400 Subject: Update the auth providers to be async. (#7935) --- changelog.d/7935.misc | 1 + docs/password_auth_providers.md | 187 ++++++++++++++++++----------------- synapse/handlers/auth.py | 7 +- synapse/handlers/ui_auth/checkers.py | 35 ++++--- 4 files changed, 118 insertions(+), 112 deletions(-) create mode 100644 changelog.d/7935.misc (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7935.misc b/changelog.d/7935.misc new file mode 100644 index 0000000000..3771f99bf2 --- /dev/null +++ b/changelog.d/7935.misc @@ -0,0 +1 @@ +Convert the auth providers to be async/await. diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index 5d9ae67041..fef1d47e85 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -19,102 +19,103 @@ password auth provider module implementations: Password auth provider classes must provide the following methods: -*class* `SomeProvider.parse_config`(*config*) +* `parse_config(config)` + This method is passed the `config` object for this module from the + homeserver configuration file. -> This method is passed the `config` object for this module from the -> homeserver configuration file. -> -> It should perform any appropriate sanity checks on the provided -> configuration, and return an object which is then passed into -> `__init__`. + It should perform any appropriate sanity checks on the provided + configuration, and return an object which is then passed into -*class* `SomeProvider`(*config*, *account_handler*) + This method should have the `@staticmethod` decoration. -> The constructor is passed the config object returned by -> `parse_config`, and a `synapse.module_api.ModuleApi` object which -> allows the password provider to check if accounts exist and/or create -> new ones. +* `__init__(self, config, account_handler)` + + The constructor is passed the config object returned by + `parse_config`, and a `synapse.module_api.ModuleApi` object which + allows the password provider to check if accounts exist and/or create + new ones. ## Optional methods -Password auth provider classes may optionally provide the following -methods. - -*class* `SomeProvider.get_db_schema_files`() - -> This method, if implemented, should return an Iterable of -> `(name, stream)` pairs of database schema files. Each file is applied -> in turn at initialisation, and a record is then made in the database -> so that it is not re-applied on the next start. - -`someprovider.get_supported_login_types`() - -> This method, if implemented, should return a `dict` mapping from a -> login type identifier (such as `m.login.password`) to an iterable -> giving the fields which must be provided by the user in the submission -> to the `/login` api. These fields are passed in the `login_dict` -> dictionary to `check_auth`. -> -> For example, if a password auth provider wants to implement a custom -> login type of `com.example.custom_login`, where the client is expected -> to pass the fields `secret1` and `secret2`, the provider should -> implement this method and return the following dict: -> -> {"com.example.custom_login": ("secret1", "secret2")} - -`someprovider.check_auth`(*username*, *login_type*, *login_dict*) - -> This method is the one that does the real work. If implemented, it -> will be called for each login attempt where the login type matches one -> of the keys returned by `get_supported_login_types`. -> -> It is passed the (possibly UNqualified) `user` provided by the client, -> the login type, and a dictionary of login secrets passed by the -> client. -> -> The method should return a Twisted `Deferred` object, which resolves -> to the canonical `@localpart:domain` user id if authentication is -> successful, and `None` if not. -> -> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in -> which case the second field is a callback which will be called with -> the result from the `/login` call (including `access_token`, -> `device_id`, etc.) - -`someprovider.check_3pid_auth`(*medium*, *address*, *password*) - -> This method, if implemented, is called when a user attempts to -> register or log in with a third party identifier, such as email. It is -> passed the medium (ex. "email"), an address (ex. -> "") and the user's password. -> -> The method should return a Twisted `Deferred` object, which resolves -> to a `str` containing the user's (canonical) User ID if -> authentication was successful, and `None` if not. -> -> As with `check_auth`, the `Deferred` may alternatively resolve to a -> `(user_id, callback)` tuple. - -`someprovider.check_password`(*user_id*, *password*) - -> This method provides a simpler interface than -> `get_supported_login_types` and `check_auth` for password auth -> providers that just want to provide a mechanism for validating -> `m.login.password` logins. -> -> Iif implemented, it will be called to check logins with an -> `m.login.password` login type. It is passed a qualified -> `@localpart:domain` user id, and the password provided by the user. -> -> The method should return a Twisted `Deferred` object, which resolves -> to `True` if authentication is successful, and `False` if not. - -`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*) - -> This method, if implemented, is called when a user logs out. It is -> passed the qualified user ID, the ID of the deactivated device (if -> any: access tokens are occasionally created without an associated -> device ID), and the (now deactivated) access token. -> -> It may return a Twisted `Deferred` object; the logout request will -> wait for the deferred to complete but the result is ignored. +Password auth provider classes may optionally provide the following methods: + +* `get_db_schema_files(self)` + + This method, if implemented, should return an Iterable of + `(name, stream)` pairs of database schema files. Each file is applied + in turn at initialisation, and a record is then made in the database + so that it is not re-applied on the next start. + +* `get_supported_login_types(self)` + + This method, if implemented, should return a `dict` mapping from a + login type identifier (such as `m.login.password`) to an iterable + giving the fields which must be provided by the user in the submission + to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). + These fields are passed in the `login_dict` dictionary to `check_auth`. + + For example, if a password auth provider wants to implement a custom + login type of `com.example.custom_login`, where the client is expected + to pass the fields `secret1` and `secret2`, the provider should + implement this method and return the following dict: + + ```python + {"com.example.custom_login": ("secret1", "secret2")} + ``` + +* `check_auth(self, username, login_type, login_dict)` + + This method does the real work. If implemented, it + will be called for each login attempt where the login type matches one + of the keys returned by `get_supported_login_types`. + + It is passed the (possibly unqualified) `user` field provided by the client, + the login type, and a dictionary of login secrets passed by the + client. + + The method should return an `Awaitable` object, which resolves + to the canonical `@localpart:domain` user ID if authentication is + successful, and `None` if not. + + Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in + which case the second field is a callback which will be called with + the result from the `/login` call (including `access_token`, + `device_id`, etc.) + +* `check_3pid_auth(self, medium, address, password)` + + This method, if implemented, is called when a user attempts to + register or log in with a third party identifier, such as email. It is + passed the medium (ex. "email"), an address (ex. + "") and the user's password. + + The method should return an `Awaitable` object, which resolves + to a `str` containing the user's (canonical) User id if + authentication was successful, and `None` if not. + + As with `check_auth`, the `Awaitable` may alternatively resolve to a + `(user_id, callback)` tuple. + +* `check_password(self, user_id, password)` + + This method provides a simpler interface than + `get_supported_login_types` and `check_auth` for password auth + providers that just want to provide a mechanism for validating + `m.login.password` logins. + + If implemented, it will be called to check logins with an + `m.login.password` login type. It is passed a qualified + `@localpart:domain` user id, and the password provided by the user. + + The method should return an `Awaitable` object, which resolves + to `True` if authentication is successful, and `False` if not. + +* `on_logged_out(self, user_id, device_id, access_token)` + + This method, if implemented, is called when a user logs out. It is + passed the qualified user ID, the ID of the deactivated device (if + any: access tokens are occasionally created without an associated + device ID), and the (now deactivated) access token. + + It may return an `Awaitable` object; the logout request will + wait for the `Awaitable` to complete, but the result is ignored. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a162392e4c..c7d921c21a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import time import unicodedata @@ -863,11 +864,15 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - await provider.on_logged_out( + # This might return an awaitable, if it does block the log out + # until it completes. + result = provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, ) + if inspect.isawaitable(result): + await result # delete pushers associated with this access token if user_info["token_id"] is not None: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index a140e9391e..a011e9fe29 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -14,10 +14,10 @@ # limitations under the License. import logging +from typing import Any from canonicaljson import json -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType @@ -33,25 +33,25 @@ class UserInteractiveAuthChecker: def __init__(self, hs): pass - def is_enabled(self): + def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: - bool: True if this login type is enabled. + True if this login type is enabled. """ - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step Args: - authdict (dict): authentication dictionary from the client - clientip (str): The IP address of the client. + authdict: authentication dictionary from the client + clientip: The IP address of the client. Raises: SynapseError if authentication failed Returns: - Deferred: the result of authentication (to pass back to the client?) + The result of authentication (to pass back to the client?) """ raise NotImplementedError() @@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class TermsAuthChecker(UserInteractiveAuthChecker): @@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return True - def check_auth(self, authdict, clientip): - return defer.succeed(True) + async def check_auth(self, authdict, clientip): + return True class RecaptchaAuthChecker(UserInteractiveAuthChecker): @@ -89,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def is_enabled(self): return self._enabled - @defer.inlineCallbacks - def check_auth(self, authdict, clientip): + async def check_auth(self, authdict, clientip): try: user_response = authdict["response"] except KeyError: @@ -107,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): # TODO: get this from the homeserver rather than creating a new one for # each request try: - resp_body = yield self._http_client.post_urlencoded_get_json( + resp_body = await self._http_client.post_urlencoded_get_json( self._url, args={ "secret": self._secret, @@ -219,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec ThreepidBehaviour.LOCAL, ) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("email", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("email", authdict) class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): @@ -233,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): def is_enabled(self): return bool(self.hs.config.account_threepid_delegate_msisdn) - def check_auth(self, authdict, clientip): - return defer.ensureDeferred(self._check_threepid("msisdn", authdict)) + async def check_auth(self, authdict, clientip): + return await self._check_threepid("msisdn", authdict) INTERACTIVE_AUTH_CHECKERS = [ -- cgit 1.5.1 From 66f24449dd614b23ea4c572d8d613efeb129e4a2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Aug 2020 08:09:55 -0400 Subject: Improve performance of the register endpoint (#8009) --- changelog.d/8009.misc | 1 + synapse/api/errors.py | 4 +- synapse/handlers/auth.py | 19 +++-- synapse/rest/client/v2_alpha/account.py | 86 +++++++++++++++------- synapse/rest/client/v2_alpha/register.py | 108 ++++++++++++++++++---------- tests/rest/client/v2_alpha/test_register.py | 2 +- 6 files changed, 146 insertions(+), 74 deletions(-) create mode 100644 changelog.d/8009.misc (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/8009.misc b/changelog.d/8009.misc new file mode 100644 index 0000000000..3d58a11313 --- /dev/null +++ b/changelog.d/8009.misc @@ -0,0 +1 @@ +Improve the performance of the register endpoint. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b3bab1aa52..6e40630ab6 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -238,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception): (This indicates we should return a 401 with 'result' as the body) Attributes: + session_id: The ID of the ongoing interactive auth session. result: the server response to the request, which should be passed back to the client """ - def __init__(self, result: "JsonDict"): + def __init__(self, session_id: str, result: "JsonDict"): super(InteractiveAuthIncompleteError, self).__init__( "Interactive auth not yet complete" ) + self.session_id = session_id self.result = result diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c7d921c21a..c24e7bafe0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -162,7 +162,7 @@ class AuthHandler(BaseHandler): request_body: Dict[str, Any], clientip: str, description: str, - ) -> dict: + ) -> Tuple[dict, str]: """ Checks that the user is who they claim to be, via a UI auth. @@ -183,9 +183,14 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - The parameters for this request (which may + A tuple of (params, session_id). + + 'params' contains the parameters for this request (which may have been given only in a previous call). + 'session_id' is the ID of this session, either passed in by the + client or assigned by this call + Raises: InteractiveAuthIncompleteError if the client has not yet completed any of the permitted login flows @@ -207,7 +212,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = await self.check_auth( + result, params, session_id = await self.check_ui_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -230,7 +235,7 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") - return params + return params, session_id def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -240,7 +245,7 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - async def check_auth( + async def check_ui_auth( self, flows: List[List[str]], request: SynapseRequest, @@ -363,7 +368,7 @@ class AuthHandler(BaseHandler): if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session.session_id) + session.session_id, self._auth_dict_for_flows(flows, session.session_id) ) # check auth type currently being presented @@ -410,7 +415,7 @@ class AuthHandler(BaseHandler): ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError(ret) + raise InteractiveAuthIncompleteError(session.session_id, ret) async def add_oob_auth( self, stagetype: str, authdict: Dict[str, Any], clientip: str diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3767a809a4..fead85074b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,7 +18,12 @@ import logging from http import HTTPStatus from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( @@ -239,18 +244,12 @@ class PasswordRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the new password provided to us. - if "new_password" in body: - new_password = body.pop("new_password") + new_password = body.pop("new_password", None) + if new_password is not None: if not isinstance(new_password, str) or len(new_password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(new_password) - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "new_password_hash" in body: - raise SynapseError(400, "Unexpected property: new_password_hash") - body["new_password_hash"] = await self.auth_handler.hash(new_password) - # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -263,23 +262,49 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - params = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + params, session_id = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise user_id = requester.user.to_string() else: requester = None - result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] @@ -304,12 +329,21 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password_hash"]) - new_password_hash = params["new_password_hash"] + # If we have a password in this request, prefer it. Otherwise, there + # must be a password hash from an earlier request. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + else: + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + logout_devices = params.get("logout_devices", True) await self._set_password_handler.set_password( - user_id, new_password_hash, logout_devices, requester + user_id, password_hash, logout_devices, requester ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 370742ce59..a4c079196d 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -24,6 +24,7 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, + InteractiveAuthIncompleteError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -387,6 +388,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -412,20 +414,8 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the username/password provided to us. - if "password" in body: - password = body.pop("password") - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "password_hash" in body: - raise SynapseError(400, "Unexpected property: password_hash") - body["password_hash"] = await self.auth_handler.hash(password) - + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. desired_username = None if "username" in body: if not isinstance(body["username"], str) or len(body["username"]) > 512: @@ -459,22 +449,35 @@ class RegisterRestServlet(RestServlet): ) return 200, result # we throw for non 200 responses - # for regular registration, downcase the provided username before - # attempting to register it. This should mean - # that people who try to register with upper-case in their usernames - # don't get a nasty surprise. (Note that we treat username - # case-insenstively in login, so they are free to carry on imagining - # that their username is CrAzYh4cKeR if that keeps them happy) - if desired_username is not None: - desired_username = desired_username.lower() - # == Normal User Registration == (everyone else) - if not self.hs.config.enable_registration: + if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled") + # For regular registration, convert the provided username to lowercase + # before attempting to register it. This should mean that people who try + # to register with upper-case in their usernames don't get a nasty surprise. + # + # Note that we treat usernames case-insensitively in login, so they are + # free to carry on imagining that their username is CrAzYh4cKeR if that + # keeps them happy. + if desired_username is not None: + desired_username = desired_username.lower() + + # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password_hash" not in body: + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + password = body.pop("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out @@ -484,6 +487,7 @@ class RegisterRestServlet(RestServlet): session_id = self.auth_handler.get_session_id(body) registered_user_id = None + password_hash = None if session_id: # if we get a registered user id out of here, it means we previously # registered a user for this session, so we could just return the @@ -492,7 +496,12 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + # Ensure that the username is valid. if desired_username is not None: await self.registration_handler.check_username( desired_username, @@ -500,20 +509,38 @@ class RegisterRestServlet(RestServlet): assigned_user_id=registered_user_id, ) - auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", - ) + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. + try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise # Check that we're not trying to register a denied 3pid. # # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. - if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: @@ -535,12 +562,15 @@ class RegisterRestServlet(RestServlet): # don't re-register the threepids registered = False else: - # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password_hash"]) + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -582,7 +612,7 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password_hash=new_password_hash, + password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, @@ -595,8 +625,8 @@ class RegisterRestServlet(RestServlet): ): await self.store.upsert_monthly_active_user(registered_user_id) - # remember that we've now registered that user account, and with - # what user ID (since the user may not have specified) + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 7deaf5b24a..53a43038f0 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) + @override_config({"enable_registration": False}) def test_POST_disabled_registration(self): - self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) -- cgit 1.5.1 From e04e465b4d2c66acb8885c31736c7b7bb4e7be52 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 17 Aug 2020 17:05:00 +0100 Subject: Use the default templates when a custom template file cannot be found (#8037) Fixes https://github.com/matrix-org/synapse/issues/6583 --- changelog.d/8037.feature | 1 + docs/sample_config.yaml | 4 +- synapse/config/_base.py | 100 ++++++++++++++++++++- synapse/config/emailconfig.py | 145 ++++++++++++++----------------- synapse/config/saml2_config.py | 14 +-- synapse/config/sso.py | 37 ++++---- synapse/handlers/account_validity.py | 20 +---- synapse/handlers/auth.py | 12 ++- synapse/handlers/oidc_handler.py | 5 +- synapse/push/mailer.py | 72 +-------------- synapse/push/pusher.py | 31 ++----- synapse/python_dependencies.py | 2 - synapse/rest/client/v2_alpha/account.py | 44 +++------- synapse/rest/client/v2_alpha/register.py | 31 ++----- tests/config/test_base.py | 82 +++++++++++++++++ 15 files changed, 310 insertions(+), 290 deletions(-) create mode 100644 changelog.d/8037.feature create mode 100644 tests/config/test_base.py (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/8037.feature b/changelog.d/8037.feature new file mode 100644 index 0000000000..2e5127477d --- /dev/null +++ b/changelog.d/8037.feature @@ -0,0 +1 @@ +Use the default template file when its equivalent is not found in a custom template directory. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9235b89fb1..f168853f67 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2002,9 +2002,7 @@ email: # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index fd137853b1..1417487427 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,12 +18,16 @@ import argparse import errno import os +import time +import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, List, MutableMapping, Optional +from typing import Any, Callable, List, MutableMapping, Optional import attr +import jinja2 +import pkg_resources import yaml @@ -100,6 +104,11 @@ class Config(object): def __init__(self, root_config=None): self.root = root_config + # Get the path to the default Synapse template directory + self.default_template_dir = pkg_resources.resource_filename( + "synapse", "res/templates" + ) + def __getattr__(self, item: str) -> Any: """ Try and fetch a configuration option that does not exist on this class. @@ -184,6 +193,95 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() + def read_templates( + self, filenames: List[str], custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Load a list of template files from disk using the given variables. + + This function will attempt to load the given templates from the default Synapse + template directory. If `custom_template_directory` is supplied, that directory + is tried first. + + Files read are treated as Jinja templates. These templates are not rendered yet. + + Args: + filenames: A list of template filenames to read. + + custom_template_directory: A directory to try to look for the templates + before using the default Synapse template directory instead. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A list of jinja2 templates. + """ + templates = [] + search_directories = [self.default_template_dir] + + # The loader will first look in the custom template directory (if specified) for the + # given filename. If it doesn't find it, it will use the default template dir instead + if custom_template_directory: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.insert(0, custom_template_directory) + + loader = jinja2.FileSystemLoader(search_directories) + env = jinja2.Environment(loader=loader, autoescape=True) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) + + for filename in filenames: + # Load the template + template = env.get_template(filename) + templates.append(template) + + return templates + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + class RootConfig(object): """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index a63acbdc63..7a796996c0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -23,7 +23,6 @@ from enum import Enum from typing import Optional import attr -import pkg_resources from ._base import Config, ConfigError @@ -98,21 +97,18 @@ class EmailConfig(Config): if parsed[1] == "": raise RuntimeError("Invalid notif_from address") + # A user-configurable template directory template_dir = email_config.get("template_dir") - # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxiliary templates like mail.css we will need). - # (Note that loading as package_resources with jinja.PackageLoader doesn't - # work for the same reason.) - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - self.email_template_dir = os.path.abspath(template_dir) + if isinstance(template_dir, str): + # We need an absolute path, because we change directory after starting (and + # we don't yet know what auxiliary templates like mail.css we will need). + template_dir = os.path.abspath(template_dir) + elif template_dir is not None: + # If template_dir is something other than a str or None, warn the user + raise ConfigError("Config option email.template_dir must be type str") self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_config = config.get("account_validity") or {} - account_validity_renewal_enabled = account_validity_config.get("renew_at") - self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email # is not defined @@ -166,19 +162,6 @@ class EmailConfig(Config): email_config.get("validation_token_lifetime", "1h") ) - if ( - self.email_enable_notifs - or account_validity_renewal_enabled - or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): - # make sure we can import the required deps - import bleach - import jinja2 - - # prevent unused warnings - jinja2 - bleach - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: missing = [] if not self.email_notif_from: @@ -196,49 +179,49 @@ class EmailConfig(Config): # These email templates have placeholders in them, and thus must be # parsed using a templating engine during a request - self.email_password_reset_template_html = email_config.get( + password_reset_template_html = email_config.get( "password_reset_template_html", "password_reset.html" ) - self.email_password_reset_template_text = email_config.get( + password_reset_template_text = email_config.get( "password_reset_template_text", "password_reset.txt" ) - self.email_registration_template_html = email_config.get( + registration_template_html = email_config.get( "registration_template_html", "registration.html" ) - self.email_registration_template_text = email_config.get( + registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) - self.email_add_threepid_template_html = email_config.get( + add_threepid_template_html = email_config.get( "add_threepid_template_html", "add_threepid.html" ) - self.email_add_threepid_template_text = email_config.get( + add_threepid_template_text = email_config.get( "add_threepid_template_text", "add_threepid.txt" ) - self.email_password_reset_template_failure_html = email_config.get( + password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) - self.email_registration_template_failure_html = email_config.get( + registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) - self.email_add_threepid_template_failure_html = email_config.get( + add_threepid_template_failure_html = email_config.get( "add_threepid_template_failure_html", "add_threepid_failure.html" ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup - email_password_reset_template_success_html = email_config.get( + password_reset_template_success_html = email_config.get( "password_reset_template_success_html", "password_reset_success.html" ) - email_registration_template_success_html = email_config.get( + registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) - email_add_threepid_template_success_html = email_config.get( + add_threepid_template_success_html = email_config.get( "add_threepid_template_success_html", "add_threepid_success.html" ) - # Check templates exist - for f in [ + # Read all templates from disk + ( self.email_password_reset_template_html, self.email_password_reset_template_text, self.email_registration_template_html, @@ -248,32 +231,36 @@ class EmailConfig(Config): self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, - email_password_reset_template_success_html, - email_registration_template_success_html, - email_add_threepid_template_success_html, - ]: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p,)) - - # Retrieve content of web templates - filepath = os.path.join( - self.email_template_dir, email_password_reset_template_success_html + password_reset_template_success_html_template, + registration_template_success_html_template, + add_threepid_template_success_html_template, + ) = self.read_templates( + [ + password_reset_template_html, + password_reset_template_text, + registration_template_html, + registration_template_text, + add_threepid_template_html, + add_threepid_template_text, + password_reset_template_failure_html, + registration_template_failure_html, + add_threepid_template_failure_html, + password_reset_template_success_html, + registration_template_success_html, + add_threepid_template_success_html, + ], + template_dir, ) - self.email_password_reset_template_success_html = self.read_file( - filepath, "email.password_reset_template_success_html" - ) - filepath = os.path.join( - self.email_template_dir, email_registration_template_success_html - ) - self.email_registration_template_success_html_content = self.read_file( - filepath, "email.registration_template_success_html" + + # Render templates that do not contain any placeholders + self.email_password_reset_template_success_html_content = ( + password_reset_template_success_html_template.render() ) - filepath = os.path.join( - self.email_template_dir, email_add_threepid_template_success_html + self.email_registration_template_success_html_content = ( + registration_template_success_html_template.render() ) - self.email_add_threepid_template_success_html_content = self.read_file( - filepath, "email.add_threepid_template_success_html" + self.email_add_threepid_template_success_html_content = ( + add_threepid_template_success_html_template.render() ) if self.email_enable_notifs: @@ -290,17 +277,19 @@ class EmailConfig(Config): % (", ".join(missing),) ) - self.email_notif_template_html = email_config.get( + notif_template_html = email_config.get( "notif_template_html", "notif_mail.html" ) - self.email_notif_template_text = email_config.get( + notif_template_text = email_config.get( "notif_template_text", "notif_mail.txt" ) - for f in self.email_notif_template_text, self.email_notif_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.email_notif_template_html, + self.email_notif_template_text, + ) = self.read_templates( + [notif_template_html, notif_template_text], template_dir, + ) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True @@ -309,18 +298,20 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if account_validity_renewal_enabled: - self.email_expiry_template_html = email_config.get( + if self.account_validity.renew_by_email_enabled: + expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) - self.email_expiry_template_text = email_config.get( + expiry_template_text = email_config.get( "expiry_template_text", "notice_expiry.txt" ) - for f in self.email_expiry_template_text, self.email_expiry_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.account_validity_template_html, + self.account_validity_template_text, + ) = self.read_templates( + [expiry_template_html, expiry_template_text], template_dir, + ) subjects_config = email_config.get("subjects", {}) subjects = {} @@ -400,9 +391,7 @@ class EmailConfig(Config): # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 9277b5f342..036f8c0e90 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -18,8 +18,6 @@ import logging from typing import Any, List import attr -import jinja2 -import pkg_resources from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -171,15 +169,9 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) - template_dir = saml2_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - - loader = jinja2.FileSystemLoader(template_dir) - # enable auto-escape here, to having to remember to escape manually in the - # template - env = jinja2.Environment(loader=loader, autoescape=True) - self.saml2_error_html_template = env.get_template("saml_error.html") + self.saml2_error_html_template = self.read_templates( + ["saml_error.html"], saml2_config.get("template_dir") + ) def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 73b7296399..4427676167 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any, Dict -import pkg_resources - from ._base import Config @@ -29,22 +26,32 @@ class SSOConfig(Config): def read_config(self, config, **kwargs): sso_config = config.get("sso") or {} # type: Dict[str, Any] - # Pick a template directory in order of: - # * The sso-specific template_dir - # * /path/to/synapse/install/res/templates + # The sso-specific template_dir template_dir = sso_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_template_dir = template_dir - self.sso_account_deactivated_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), - "sso_account_deactivated_template", + # Read templates from disk + ( + self.sso_redirect_confirm_template, + self.sso_auth_confirm_template, + self.sso_error_template, + sso_account_deactivated_template, + sso_auth_success_template, + ) = self.read_templates( + [ + "sso_redirect_confirm.html", + "sso_auth_confirm.html", + "sso_error.html", + "sso_account_deactivated.html", + "sso_auth_success.html", + ], + template_dir, ) - self.sso_auth_success_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_auth_success.html"), - "sso_auth_success_template", + + # These templates have no placeholders, so render them here + self.sso_account_deactivated_template = ( + sso_account_deactivated_template.render() ) + self.sso_auth_success_template = sso_auth_success_template.render() self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 590135d19c..b865bf5b48 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils -try: - from synapse.push.mailer import load_jinja2_templates -except ImportError: - load_jinja2_templates = None - logger = logging.getLogger(__name__) @@ -47,9 +42,11 @@ class AccountValidityHandler(object): if ( self._account_validity.enabled and self._account_validity.renew_by_email_enabled - and load_jinja2_templates ): # Don't do email-specific configuration if renewal by email is disabled. + self._template_html = self.config.account_validity_template_html + self._template_text = self.config.account_validity_template_text + try: app_name = self.hs.config.email_app_name @@ -65,17 +62,6 @@ class AccountValidityHandler(object): self._raw_from = email.utils.parseaddr(self._from_string)[1] - self._template_html, self._template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_expiry_template_html, - self.config.email_expiry_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) - # Check the renewal emails to send and send them every 30min. def send_emails(): # run as a background process to make sure that the database transactions diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c24e7bafe0..68d6870e40 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email @@ -132,18 +131,17 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_redirect_confirm.html"], - )[0] + self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_auth_confirm.html"], - )[0] + self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index fa5ee5de8f..87d28a7ae9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -38,7 +38,6 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.push.mailer import load_jinja2_templates from synapse.types import UserID, map_username_to_mxid_localpart if TYPE_CHECKING: @@ -123,9 +122,7 @@ class OidcHandler: self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_error.html"] - )[0] + self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index af117fddf9..c38e037281 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -16,8 +16,7 @@ import email.mime.multipart import email.utils import logging -import time -import urllib +import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar @@ -640,72 +639,3 @@ def string_ordinal_total(s): for c in s: tot += ord(c) return tot - - -def format_ts_filter(value, format): - return time.strftime(format, time.localtime(value / 1000)) - - -def load_jinja2_templates( - template_dir, - template_filenames, - apply_format_ts_filter=False, - apply_mxc_to_http_filter=False, - public_baseurl=None, -): - """Loads and returns one or more jinja2 templates and applies optional filters - - Args: - template_dir (str): The directory where templates are stored - template_filenames (list[str]): A list of template filenames - apply_format_ts_filter (bool): Whether to apply a template filter that formats - timestamps - apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts - mxc urls to http urls - public_baseurl (str|None): The public baseurl of the server. Required for - apply_mxc_to_http_filter to be enabled - - Returns: - A list of jinja2 templates corresponding to the given list of filenames, - with order preserved - """ - logger.info( - "loading email templates %s from '%s'", template_filenames, template_dir - ) - loader = jinja2.FileSystemLoader(template_dir) - env = jinja2.Environment(loader=loader) - - if apply_format_ts_filter: - env.filters["format_ts"] = format_ts_filter - - if apply_mxc_to_http_filter and public_baseurl: - env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl) - - templates = [] - for template_filename in template_filenames: - template = env.get_template(template_filename) - templates.append(template) - - return templates - - -def _create_mxc_to_http_filter(public_baseurl): - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if "#" in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - serverAndMediaId, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8ad0bf5936..f626797133 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -15,22 +15,13 @@ import logging +from synapse.push.emailpusher import EmailPusher +from synapse.push.mailer import Mailer + from .httppusher import HttpPusher logger = logging.getLogger(__name__) -# We try importing this if we can (it will fail if we don't -# have the optional email dependencies installed). We don't -# yet have the config to know if we need the email pusher, -# but importing this after daemonizing seems to fail -# (even though a simple test of importing from a daemonized -# process works fine) -try: - from synapse.push.emailpusher import EmailPusher - from synapse.push.mailer import Mailer, load_jinja2_templates -except Exception: - pass - class PusherFactory(object): def __init__(self, hs): @@ -43,16 +34,8 @@ class PusherFactory(object): if hs.config.email_enable_notifs: self.mailers = {} # app_name -> Mailer - self.notif_template_html, self.notif_template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_notif_template_html, - self.config.email_notif_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) + self._notif_template_html = hs.config.email_notif_template_html + self._notif_template_text = hs.config.email_notif_template_text self.pusher_types["email"] = self._create_email_pusher @@ -73,8 +56,8 @@ class PusherFactory(object): mailer = Mailer( hs=self.hs, app_name=app_name, - template_html=self.notif_template_html, - template_text=self.notif_template_text, + template_html=self._notif_template_html, + template_text=self._notif_template_text, ) self.mailers[app_name] = mailer return EmailPusher(self.hs, pusherdict, mailer) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e5f22fb858..3250d41dde 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 - "resources.consent": ["Jinja2>=2.9"], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index fead85074b..203e76b9f2 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -32,7 +32,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import canonicalise_email, check_3pid_allowed @@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_password_reset_template_html, - self.config.email_password_reset_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, ) async def on_POST(self, request): @@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], + self._failure_email_template = ( + self.config.email_password_reset_template_failure_html ) async def on_GET(self, request, medium): @@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_password_reset_template_success_html + html = self.config.email_password_reset_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) @@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_add_threepid_template_html, - self.config.email_add_threepid_template_text, - ], - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, ) async def on_POST(self, request): @@ -578,9 +559,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_add_threepid_template_failure_html], + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html ) async def on_GET(self, request): @@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f808175698..7290fd0756 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -44,7 +44,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - from synapse.push.mailer import Mailer, load_jinja2_templates - - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_registration_template_html, - self.config.email_registration_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, ) async def on_POST(self, request): @@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], + self._failure_email_template = ( + self.config.email_registration_template_failure_html ) async def on_GET(self, request, medium): @@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/tests/config/test_base.py b/tests/config/test_base.py new file mode 100644 index 0000000000..42ee5f56d9 --- /dev/null +++ b/tests/config/test_base.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import tempfile + +from synapse.config import ConfigError +from synapse.util.stringutils import random_string + +from tests import unittest + + +class BaseConfigTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.hs = hs + + def test_loading_missing_templates(self): + # Use a temporary directory that exists on the system, but that isn't likely to + # contain template files + with tempfile.TemporaryDirectory() as tmp_dir: + # Attempt to load an HTML template from our custom template directory + template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + + # If no errors, we should've gotten the default template instead + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"error_description": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_custom_templates(self): + # Use a temporary directory that exists on the system + with tempfile.TemporaryDirectory() as tmp_dir: + # Create a temporary bogus template file + with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template: + # Get temporary file's filename + template_filename = os.path.basename(tmp_template.name) + + # Write a custom HTML template + contents = b"{{ test_variable }}" + tmp_template.write(contents) + tmp_template.flush() + + # Attempt to load the template from our custom template directory + template = ( + self.hs.config.read_templates([template_filename], tmp_dir) + )[0] + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"test_variable": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_template_from_nonexistent_custom_directory(self): + with self.assertRaises(ConfigError): + self.hs.config.read_templates( + ["some_filename.html"], "a_nonexistent_directory" + ) -- cgit 1.5.1 From 3f91638da6ea0aeaf789ddc8ca1e624a11b7ebb2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:42:58 -0400 Subject: Allow denying or shadow banning registrations via the spam checker (#8034) --- changelog.d/8034.feature | 1 + synapse/events/spamcheck.py | 35 ++++++++++++++- synapse/handlers/auth.py | 8 ++++ synapse/handlers/cas_handler.py | 11 ++++- synapse/handlers/oidc_handler.py | 21 +++++++-- synapse/handlers/register.py | 26 ++++++++++- synapse/handlers/saml_handler.py | 18 +++++++- synapse/rest/client/v2_alpha/register.py | 5 +++ synapse/spam_checker_api/__init__.py | 11 +++++ .../main/schema/delta/58/07persist_ui_auth_ips.sql | 25 +++++++++++ synapse/storage/databases/main/ui_auth.py | 39 +++++++++++++++- tests/handlers/test_oidc.py | 18 ++++++-- tests/handlers/test_register.py | 52 +++++++++++++++++++++- tests/handlers/test_user_directory.py | 6 +-- 14 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8034.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8034.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..a7cddac974 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,9 +15,10 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. + + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d6870e40..654f58ddae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -364,6 +364,14 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( session.session_id, self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 786e608fa2..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -35,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -210,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index dd3703cbd2..c5bd2fea68 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -689,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -828,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -843,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -899,7 +912,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ccd96e4626..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) +from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, - shadow_banned=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. - shadow_banned (bool): Shadow-ban the created user. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. Returns: str: user_id Raises: @@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index c1fcb98454..b426199aa6 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -54,6 +54,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -133,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -147,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -155,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. @@ -291,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7290fd0756..be0e680ac5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 7f63f1bfa0..9be92e2565 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,6 +26,16 @@ if MYPY: logger = logging.getLogger(__name__) +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + class SpamCheckerApi(object): """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 6281a41a3d..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore): txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1