From 6e613a10d072c32e72d6b97b2d178bb840769f3e Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Wed, 18 Aug 2021 13:13:35 +0100 Subject: Display an error page during failure of fallback UIA. (#10561) --- synapse/rest/client/auth.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) (limited to 'synapse/rest/client/auth.py') diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 6ea1b50a62..73284e48ec 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError +from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string @@ -95,29 +95,32 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} - success = await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, request.getClientIP() - ) - - if success: - html = self.success_template.render() - else: + try: + await self.auth_handler.add_oob_auth( + LoginType.RECAPTCHA, authdict, request.getClientIP() + ) + except LoginError as e: + # Authentication failed, let user try again html = self.recaptcha_template.render( session=session, myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.recaptcha_public_key, + error=e.msg, ) + else: + # No LoginError was raised, so authentication was successful + html = self.success_template.render() + elif stagetype == LoginType.TERMS: authdict = {"session": session} - success = await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, request.getClientIP() - ) - - if success: - html = self.success_template.render() - else: + try: + await self.auth_handler.add_oob_auth( + LoginType.TERMS, authdict, request.getClientIP() + ) + except LoginError as e: + # Authentication failed, let user try again html = self.terms_template.render( session=session, terms_url="%s_matrix/consent?v=%s" @@ -127,10 +130,16 @@ class AuthRestServlet(RestServlet): ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), + error=e.msg, ) + else: + # No LoginError was raised, so authentication was successful + html = self.success_template.render() + elif stagetype == LoginType.SSO: # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + else: raise SynapseError(404, "Unknown auth stage type") -- cgit 1.5.1 From 947dbbdfd1e0029da66f956d277b7c089928e1e7 Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Sat, 21 Aug 2021 22:14:43 +0100 Subject: Implement MSC3231: Token authenticated registration (#10142) Signed-off-by: Callum Brown This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). --- changelog.d/10142.feature | 1 + docs/SUMMARY.md | 1 + docs/sample_config.yaml | 15 + .../admin_api/registration_tokens.md | 295 +++++++++ docs/workers.md | 1 + synapse/api/constants.py | 1 + synapse/app/generic_worker.py | 6 +- synapse/config/ratelimiting.py | 11 + synapse/config/registration.py | 15 + synapse/handlers/ui_auth/__init__.py | 5 + synapse/handlers/ui_auth/checkers.py | 65 ++ synapse/res/templates/registration_token.html | 23 + synapse/rest/admin/__init__.py | 8 + synapse/rest/admin/registration_tokens.py | 321 ++++++++++ synapse/rest/client/auth.py | 24 + synapse/rest/client/register.py | 72 +++ synapse/storage/databases/main/registration.py | 316 +++++++++ synapse/storage/databases/main/ui_auth.py | 43 ++ .../main/delta/63/01create_registration_tokens.sql | 23 + tests/rest/admin/test_registration_tokens.py | 710 +++++++++++++++++++++ tests/rest/client/test_register.py | 434 +++++++++++++ 21 files changed, 2389 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10142.feature create mode 100644 docs/usage/administration/admin_api/registration_tokens.md create mode 100644 synapse/res/templates/registration_token.html create mode 100644 synapse/rest/admin/registration_tokens.py create mode 100644 synapse/storage/schema/main/delta/63/01create_registration_tokens.sql create mode 100644 tests/rest/admin/test_registration_tokens.py (limited to 'synapse/rest/client/auth.py') diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature new file mode 100644 index 0000000000..5353f6269d --- /dev/null +++ b/changelog.d/10142.feature @@ -0,0 +1 @@ +Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 634bc833ab..4fcd2b7852 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -53,6 +53,7 @@ - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - [Register Users](admin_api/register_api.md) + - [Registration Tokens](usage/administration/admin_api/registration_tokens.md) - [Manipulate Room Membership](admin_api/room_membership.md) - [Rooms](admin_api/rooms.md) - [Server Notices](admin_api/server_notices.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2b0c453242..935841dbfa 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # is using # - one for registration that ratelimits registration requests based on the # client's IP address. +# - one for checking the validity of registration tokens that ratelimits +# requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # per_second: 0.17 # burst_count: 3 # +#rc_registration_token_validity: +# per_second: 0.1 +# burst_count: 5 +# #rc_login: # address: # per_second: 0.17 @@ -1169,6 +1175,15 @@ url_preview_accept_language: # #enable_3pid_lookup: true +# Require users to submit a token during registration. +# Tokens can be managed using the admin API: +# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html +# Note that `enable_registration` must be set to `true`. +# Disabling this option will not delete any tokens previously generated. +# Defaults to false. Uncomment the following to require tokens: +# +#registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md new file mode 100644 index 0000000000..828c0277d6 --- /dev/null +++ b/docs/usage/administration/admin_api/registration_tokens.md @@ -0,0 +1,295 @@ +# Registration Tokens + +This API allows you to manage tokens which can be used to authenticate +registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md). +To use it, you will need to enable the `registration_requires_token` config +option, and authenticate by providing an `access_token` for a server admin: +see [Admin API](../../usage/administration/admin_api). +Note that this API is still experimental; not all clients may support it yet. + + +## Registration token objects + +Most endpoints make use of JSON objects that contain details about tokens. +These objects have the following fields: +- `token`: The token which can be used to authenticate registration. +- `uses_allowed`: The number of times the token can be used to complete a + registration before it becomes invalid. +- `pending`: The number of pending uses the token has. When someone uses + the token to authenticate themselves, the pending counter is incremented + so that the token is not used more than the permitted number of times. + When the person completes registration the pending counter is decremented, + and the completed counter is incremented. +- `completed`: The number of times the token has been used to successfully + complete a registration. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + To convert this into a human-readable form you can remove the milliseconds + and use the `date` command. For example, `date -d '@1625394937'`. + + +## List all tokens + +Lists all tokens and details about them. If the request is successful, the top +level JSON object will have a `registration_tokens` key which is an array of +registration token objects. + +``` +GET /_synapse/admin/v1/registration_tokens +``` + +Optional query parameters: +- `valid`: `true` or `false`. If `true`, only valid tokens are returned. + If `false`, only tokens that have expired or have had all uses exhausted are + returned. If omitted, all tokens are returned regardless of validity. + +Example: + +``` +GET /_synapse/admin/v1/registration_tokens +``` +``` +200 OK + +{ + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "pqrs", + "uses_allowed": 2, + "pending": 1, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC + } + ] +} +``` + +Example using the `valid` query parameter: + +``` +GET /_synapse/admin/v1/registration_tokens?valid=false +``` +``` +200 OK + +{ + "registration_tokens": [ + { + "token": "pqrs", + "uses_allowed": 2, + "pending": 1, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC + } + ] +} +``` + + +## Get one token + +Get details about a single token. If the request is successful, the response +body will be a registration token object. + +``` +GET /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to return details of. + +Example: + +``` +GET /_synapse/admin/v1/registration_tokens/abcd +``` +``` +200 OK + +{ + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null +} +``` + + +## Create token + +Create a new registration token. If the request is successful, the newly created +token will be returned as a registration token object in the response body. + +``` +POST /_synapse/admin/v1/registration_tokens/new +``` + +The request body must be a JSON object and can contain the following fields: +- `token`: The registration token. A string of no more than 64 characters that + consists only of characters matched by the regex `[A-Za-z0-9-_]`. + Default: randomly generated. +- `uses_allowed`: The integer number of times the token can be used to complete + a registration before it becomes invalid. + Default: `null` (unlimited uses). +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + You could use, for example, `date '+%s000' -d 'tomorrow'`. + Default: `null` (token does not expire). +- `length`: The length of the token randomly generated if `token` is not + specified. Must be between 1 and 64 inclusive. Default: `16`. + +If a field is omitted the default is used. + +Example using defaults: + +``` +POST /_synapse/admin/v1/registration_tokens/new + +{} +``` +``` +200 OK + +{ + "token": "0M-9jbkf2t_Tgiw1", + "uses_allowed": null, + "pending": 0, + "completed": 0, + "expiry_time": null +} +``` + +Example specifying some fields: + +``` +POST /_synapse/admin/v1/registration_tokens/new + +{ + "token": "defg", + "uses_allowed": 1 +} +``` +``` +200 OK + +{ + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null +} +``` + + +## Update token + +Update the number of allowed uses or expiry time of a token. If the request is +successful, the updated token will be returned as a registration token object +in the response body. + +``` +PUT /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to update. + +The request body must be a JSON object and can contain the following fields: +- `uses_allowed`: The integer number of times the token can be used to complete + a registration before it becomes invalid. By setting `uses_allowed` to `0` + the token can be easily made invalid without deleting it. + If `null` the token will have an unlimited number of uses. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + If `null` the token will not expire. + +If a field is omitted its value is not modified. + +Example: + +``` +PUT /_synapse/admin/v1/registration_tokens/defg + +{ + "expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC +} +``` +``` +200 OK + +{ + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": 4781243146000 +} +``` + + +## Delete token + +Delete a registration token. If the request is successful, the response body +will be an empty JSON object. + +``` +DELETE /_synapse/admin/v1/registration_tokens/ +``` + +Path parameters: +- `token`: The registration token to delete. + +Example: + +``` +DELETE /_synapse/admin/v1/registration_tokens/wxyz +``` +``` +200 OK + +{} +``` + + +## Errors + +If a request fails a "standard error response" will be returned as defined in +the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards). + +For example, if the token specified in a path parameter does not exist a +`404 Not Found` error will be returned. + +``` +GET /_synapse/admin/v1/registration_tokens/1234 +``` +``` +404 Not Found + +{ + "errcode": "M_NOT_FOUND", + "error": "No such registration token: 1234" +} +``` diff --git a/docs/workers.md b/docs/workers.md index 2e63f03452..3121241894 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -236,6 +236,7 @@ expressions: # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ ^/_matrix/client/(r0|unstable)/register$ + ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$ # Event sending requests ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e0e24fddac..829061c870 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -79,6 +79,7 @@ class LoginType: TERMS = "m.login.terms" SSO = "m.login.sso" DUMMY = "m.login.dummy" + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 845e6a8220..fd2626dbe1 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -95,7 +95,10 @@ from synapse.rest.client.profile import ( ProfileRestServlet, ) from synapse.rest.client.push_rule import PushRuleRestServlet -from synapse.rest.client.register import RegisterRestServlet +from synapse.rest.client.register import ( + RegisterRestServlet, + RegistrationTokenValidityRestServlet, +) from synapse.rest.client.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.client.voip import VoipRestServlet @@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) + RegistrationTokenValidityRestServlet(self).register(resource) login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 7a8d5851c4..f856327bd8 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -79,6 +79,11 @@ class RatelimitConfig(Config): self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) + self.rc_registration_token_validity = RateLimitConfig( + config.get("rc_registration_token_validity", {}), + defaults={"per_second": 0.1, "burst_count": 5}, + ) + rc_login_config = config.get("rc_login", {}) self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) @@ -143,6 +148,8 @@ class RatelimitConfig(Config): # is using # - one for registration that ratelimits registration requests based on the # client's IP address. + # - one for checking the validity of registration tokens that ratelimits + # requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -171,6 +178,10 @@ class RatelimitConfig(Config): # per_second: 0.17 # burst_count: 3 # + #rc_registration_token_validity: + # per_second: 0.1 + # burst_count: 5 + # #rc_login: # address: # per_second: 0.17 diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 0ad919b139..7cffdacfa5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -33,6 +33,9 @@ class RegistrationConfig(Config): self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) + self.registration_requires_token = config.get( + "registration_requires_token", False + ) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -140,6 +143,9 @@ class RegistrationConfig(Config): "mechanism by removing the `access_token_lifetime` option." ) + # The fallback template used for authenticating using a registration token + self.registration_token_template = self.read_template("registration_token.html") + # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") @@ -199,6 +205,15 @@ class RegistrationConfig(Config): # #enable_3pid_lookup: true + # Require users to submit a token during registration. + # Tokens can be managed using the admin API: + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html + # Note that `enable_registration` must be set to `true`. + # Disabling this option will not delete any tokens previously generated. + # Defaults to false. Uncomment the following to require tokens: + # + #registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 4c3b669fae..13b0c61d2e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -34,3 +34,8 @@ class UIAuthSessionDataConstants: # used by validate_user_via_ui_auth to store the mxid of the user we are validating # for. REQUEST_USER_ID = "request_user_id" + + # used during registration to store the registration token used (if required) so that: + # - we can prevent a token being used twice by one session + # - we can 'use up' the token after registration has successfully completed + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 270541cc76..d3828dec6b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): return await self._check_threepid("msisdn", authdict) +class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.REGISTRATION_TOKEN + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.hs = hs + self._enabled = bool(hs.config.registration_requires_token) + self.store = hs.get_datastore() + + def is_enabled(self) -> bool: + return self._enabled + + async def check_auth(self, authdict: dict, clientip: str) -> Any: + if "token" not in authdict: + raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM) + if not isinstance(authdict["token"], str): + raise LoginError( + 400, "Registration token must be a string", Codes.INVALID_PARAM + ) + if "session" not in authdict: + raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM) + + # Get these here to avoid cyclic dependencies + from synapse.handlers.ui_auth import UIAuthSessionDataConstants + + auth_handler = self.hs.get_auth_handler() + + session = authdict["session"] + token = authdict["token"] + + # If the LoginType.REGISTRATION_TOKEN stage has already been completed, + # return early to avoid incrementing `pending` again. + stored_token = await auth_handler.get_session_data( + session, UIAuthSessionDataConstants.REGISTRATION_TOKEN + ) + if stored_token: + if token != stored_token: + raise LoginError( + 400, "Registration token has changed", Codes.INVALID_PARAM + ) + else: + return token + + if await self.store.registration_token_is_valid(token): + # Increment pending counter, so that if token has limited uses it + # can't be used up by someone else in the meantime. + await self.store.set_registration_token_pending(token) + # Store the token in the UIA session, so that once registration + # is complete `completed` can be incremented. + await auth_handler.set_session_data( + session, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + token, + ) + # The token will be stored as the result of the authentication stage + # in ui_auth_sessions_credentials. This allows the pending counter + # for tokens to be decremented when expired sessions are deleted. + return token + else: + raise LoginError( + 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED + ) + + INTERACTIVE_AUTH_CHECKERS = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, EmailIdentityAuthChecker, MsisdnAuthChecker, + RegistrationTokenAuthChecker, ] """A list of UserInteractiveAuthChecker classes""" diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html new file mode 100644 index 0000000000..4577ce1702 --- /dev/null +++ b/synapse/res/templates/registration_token.html @@ -0,0 +1,23 @@ + + +Authentication + + + + +
+
+ {% if error is defined %} +

Error: {{ error }}

+ {% endif %} +

+ Please enter a registration token. +

+ + + +
+
+ + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 7f3051aef1..6e1c8736e1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import ( ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo +from synapse.rest.admin.registration_tokens import ( + ListRegistrationTokensRestServlet, + NewRegistrationTokenRestServlet, + RegistrationTokenRestServlet, +) from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, ForwardExtremitiesRestServlet, @@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) + ListRegistrationTokensRestServlet(hs).register(http_server) + NewRegistrationTokenRestServlet(hs).register(http_server) + RegistrationTokenRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py new file mode 100644 index 0000000000..5a1c929d85 --- /dev/null +++ b/synapse/rest/admin/registration_tokens.py @@ -0,0 +1,321 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import string +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_boolean, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListRegistrationTokensRestServlet(RestServlet): + """List registration tokens. + + To list all tokens: + + GET /_synapse/admin/v1/registration_tokens + + 200 OK + + { + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 + } + ] + } + + The optional query parameter `valid` can be used to filter the response. + If it is `true`, only valid tokens are returned. If it is `false`, only + tokens that have expired or have had all uses exhausted are returned. + If it is omitted, all tokens are returned regardless of validity. + """ + + PATTERNS = admin_patterns("/registration_tokens$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + valid = parse_boolean(request, "valid") + token_list = await self.store.get_registration_tokens(valid) + return 200, {"registration_tokens": token_list} + + +class NewRegistrationTokenRestServlet(RestServlet): + """Create a new registration token. + + For example, to create a token specifying some fields: + + POST /_synapse/admin/v1/registration_tokens/new + + { + "token": "defg", + "uses_allowed": 1 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null + } + + Defaults are used for any fields not specified. + """ + + PATTERNS = admin_patterns("/registration_tokens/new$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + # A string of all the characters allowed to be in a registration_token + self.allowed_chars = string.ascii_letters + string.digits + "-_" + self.allowed_chars_set = set(self.allowed_chars) + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + + if "token" in body: + token = body["token"] + if not isinstance(token, str): + raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + if not (0 < len(token) <= 64): + raise SynapseError( + 400, + "token must not be empty and must not be longer than 64 characters", + Codes.INVALID_PARAM, + ) + if not set(token).issubset(self.allowed_chars_set): + raise SynapseError( + 400, + "token must consist only of characters matched by the regex [A-Za-z0-9-_]", + Codes.INVALID_PARAM, + ) + + else: + # Get length of token to generate (default is 16) + length = body.get("length", 16) + if not isinstance(length, int): + raise SynapseError( + 400, "length must be an integer", Codes.INVALID_PARAM + ) + if not (0 < length <= 64): + raise SynapseError( + 400, + "length must be greater than zero and not greater than 64", + Codes.INVALID_PARAM, + ) + + # Generate token + token = await self.store.generate_registration_token( + length, self.allowed_chars + ) + + uses_allowed = body.get("uses_allowed", None) + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + + expiry_time = body.get("expiry_time", None) + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + + created = await self.store.create_registration_token( + token, uses_allowed, expiry_time + ) + if not created: + raise SynapseError( + 400, f"Token already exists: {token}", Codes.INVALID_PARAM + ) + + resp = { + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + } + return 200, resp + + +class RegistrationTokenRestServlet(RestServlet): + """Retrieve, update, or delete the given token. + + For example, + + to retrieve a token: + + GET /_synapse/admin/v1/registration_tokens/abcd + + 200 OK + + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + } + + + to update a token: + + PUT /_synapse/admin/v1/registration_tokens/defg + + { + "uses_allowed": 5, + "expiry_time": 4781243146000 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 5, + "pending": 0, + "completed": 0, + "expiry_time": 4781243146000 + } + + + to delete a token: + + DELETE /_synapse/admin/v1/registration_tokens/wxyz + + 200 OK + + {} + """ + + PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: + """Retrieve a registration token.""" + await assert_requester_is_admin(self.auth, request) + token_info = await self.store.get_one_registration_token(token) + + # If no result return a 404 + if token_info is None: + raise NotFoundError(f"No such registration token: {token}") + + return 200, token_info + + async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]: + """Update a registration token.""" + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + new_attributes = {} + + # Only add uses_allowed to new_attributes if it is present and valid + if "uses_allowed" in body: + uses_allowed = body["uses_allowed"] + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + new_attributes["uses_allowed"] = uses_allowed + + if "expiry_time" in body: + expiry_time = body["expiry_time"] + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + new_attributes["expiry_time"] = expiry_time + + if len(new_attributes) == 0: + # Nothing to update, get token info to return + token_info = await self.store.get_one_registration_token(token) + else: + token_info = await self.store.update_registration_token( + token, new_attributes + ) + + # If no result return a 404 + if token_info is None: + raise NotFoundError(f"No such registration token: {token}") + + return 200, token_info + + async def on_DELETE( + self, request: SynapseRequest, token: str + ) -> Tuple[int, JsonDict]: + """Delete a registration token.""" + await assert_requester_is_admin(self.auth, request) + + if await self.store.delete_registration_token(token): + return 200, {} + + raise NotFoundError(f"No such registration token: {token}") diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 73284e48ec..91800c0278 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.recaptcha_template self.terms_template = hs.config.terms_template + self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template async def on_GET(self, request, stagetype): @@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet): # re-authenticate with their SSO provider. html = await self.auth_handler.start_sso_ui_auth(request, session) + elif stagetype == LoginType.REGISTRATION_TOKEN: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + ) + else: raise SynapseError(404, "Unknown auth stage type") @@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet): # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + elif stagetype == LoginType.REGISTRATION_TOKEN: + token = parse_string(request, "token", required=True) + authdict = {"session": session, "token": token} + + try: + await self.auth_handler.add_oob_auth( + LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP() + ) + except LoginError as e: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + error=e.msg, + ) + else: + html = self.success_template.render() + else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 58b8e8f261..2781a0ea96 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -28,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError from synapse.config.captcha import CaptchaConfig from synapse.config.consent import ConsentConfig @@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet): return 200, {"available": True} +class RegistrationTokenValidityRestServlet(RestServlet): + """Check the validity of a registration token. + + Example: + + GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd + + 200 OK + + { + "valid": true + } + """ + + PATTERNS = client_patterns( + f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity", + releases=(), + unstable=True, + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.ratelimiter = Ratelimiter( + store=self.store, + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, + burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + ) + + async def on_GET(self, request): + await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) + + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + + token = parse_string(request, "token", required=True) + valid = await self.store.registration_token_is_valid(token) + + return 200, {"valid": valid} + + class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") @@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet): ) if registered: + # Check if a token was used to authenticate registration + registration_token = await self.auth_handler.get_session_data( + session_id, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + ) + if registration_token: + # Increment the `completed` counter for the token + await self.store.use_registration_token(registration_token) + # Indicate that the token has been successfully used so that + # pending is not decremented again when expiring old UIA sessions. + await self.store.mark_ui_auth_stage_complete( + session_id, + LoginType.REGISTRATION_TOKEN, + True, + ) + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, @@ -868,6 +934,11 @@ def _calculate_registration_flows( for flow in flows: flow.insert(0, LoginType.RECAPTCHA) + # Prepend registration token to all flows if we're requiring a token + if config.registration_requires_token: + for flow in flows: + flow.insert(0, LoginType.REGISTRATION_TOKEN) + return flows @@ -876,4 +947,5 @@ def register_servlets(hs, http_server): MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationSubmitTokenServlet(hs).register(http_server) + RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 469dd53e0c..a6517962f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + async def registration_token_is_valid(self, token: str) -> bool: + """Checks if a token can be used to authenticate a registration. + + Args: + token: The registration token to be checked + Returns: + True if the token is valid, False otherwise. + """ + res = await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + ) + + # Check if the token exists + if res is None: + return False + + # Check if the token has expired + now = self._clock.time_msec() + if res["expiry_time"] and res["expiry_time"] < now: + return False + + # Check if the token has been used up + if ( + res["uses_allowed"] + and res["pending"] + res["completed"] >= res["uses_allowed"] + ): + return False + + # Otherwise, the token is valid + return True + + async def set_registration_token_pending(self, token: str) -> None: + """Increment the pending registrations counter for a token. + + Args: + token: The registration token pending use + """ + + def _set_registration_token_pending_txn(txn): + pending = self.db_pool.simple_select_one_onecol_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": pending + 1}, + ) + + return await self.db_pool.runInteraction( + "set_registration_token_pending", _set_registration_token_pending_txn + ) + + async def use_registration_token(self, token: str) -> None: + """Complete a use of the given registration token. + + The `pending` counter will be decremented, and the `completed` + counter will be incremented. + + Args: + token: The registration token to be 'used' + """ + + def _use_registration_token_txn(txn): + # Normally, res is Optional[Dict[str, Any]]. + # Override type because the return type is only optional if + # allow_none is True, and we don't want mypy throwing errors + # about None not being indexable. + res: Dict[str, Any] = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) # type: ignore + + # Decrement pending and increment completed + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={ + "completed": res["completed"] + 1, + "pending": res["pending"] - 1, + }, + ) + + return await self.db_pool.runInteraction( + "use_registration_token", _use_registration_token_txn + ) + + async def get_registration_tokens( + self, valid: Optional[bool] = None + ) -> List[Dict[str, Any]]: + """List all registration tokens. Used by the admin API. + + Args: + valid: If True, only valid tokens are returned. + If False, only invalid tokens are returned. + Default is None: return all tokens regardless of validity. + + Returns: + A list of dicts, each containing details of a token. + """ + + def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + if valid is None: + # Return all tokens regardless of validity + txn.execute("SELECT * FROM registration_tokens") + + elif valid: + # Select valid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "(uses_allowed > pending + completed OR uses_allowed IS NULL) " + "AND (expiry_time > ? OR expiry_time IS NULL)" + ) + txn.execute(sql, [now]) + + else: + # Select invalid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "uses_allowed <= pending + completed OR expiry_time <= ?" + ) + txn.execute(sql, [now]) + + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "select_registration_tokens", + select_registration_tokens_txn, + self._clock.time_msec(), + valid, + ) + + async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get info about the given registration token. Used by the admin API. + + Args: + token: The token to retrieve information about. + + Returns: + A dict, or None if token doesn't exist. + """ + return await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + desc="get_one_registration_token", + ) + + async def generate_registration_token( + self, length: int, chars: str + ) -> Optional[str]: + """Generate a random registration token. Used by the admin API. + + Args: + length: The length of the token to generate. + chars: A string of the characters allowed in the generated token. + + Returns: + The generated token. + + Raises: + SynapseError if a unique registration token could still not be + generated after a few tries. + """ + # Make a few attempts at generating a unique token of the required + # length before failing. + for _i in range(3): + # Generate token + token = "".join(random.choices(chars, k=length)) + + # Check if the token already exists + existing_token = await self.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="token", + allow_none=True, + desc="check_if_registration_token_exists", + ) + + if existing_token is None: + # The generated token doesn't exist yet, return it + return token + + raise SynapseError( + 500, + "Unable to generate a unique registration token. Try again with a greater length", + Codes.UNKNOWN, + ) + + async def create_registration_token( + self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + ) -> bool: + """Create a new registration token. Used by the admin API. + + Args: + token: The token to create. + uses_allowed: The number of times the token can be used to complete + a registration before it becomes invalid. A value of None indicates + unlimited uses. + expiry_time: The latest time the token is valid. Given as the + number of milliseconds since 1970-01-01 00:00:00 UTC. A value of + None indicates that the token does not expire. + + Returns: + Whether the row was inserted or not. + """ + + def _create_registration_token_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["token"], + allow_none=True, + ) + + if row is not None: + # Token already exists + return False + + self.db_pool.simple_insert_txn( + txn, + "registration_tokens", + values={ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + ) + + return True + + return await self.db_pool.runInteraction( + "create_registration_token", _create_registration_token_txn + ) + + async def update_registration_token( + self, token: str, updatevalues: Dict[str, Optional[int]] + ) -> Optional[Dict[str, Any]]: + """Update a registration token. Used by the admin API. + + Args: + token: The token to update. + updatevalues: A dict with the fields to update. E.g.: + `{"uses_allowed": 3}` to update just uses_allowed, or + `{"uses_allowed": 3, "expiry_time": None}` to update both. + This is passed straight to simple_update_one. + + Returns: + A dict with all info about the token, or None if token doesn't exist. + """ + + def _update_registration_token_txn(txn): + try: + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues=updatevalues, + ) + except StoreError: + # Update failed because token does not exist + return None + + # Get all info about the token so it can be sent in the response + return self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=[ + "token", + "uses_allowed", + "pending", + "completed", + "expiry_time", + ], + allow_none=True, + ) + + return await self.db_pool.runInteraction( + "update_registration_token", _update_registration_token_txn + ) + + async def delete_registration_token(self, token: str) -> bool: + """Delete a registration token. Used by the admin API. + + Args: + token: The token to delete. + + Returns: + Whether the token was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + desc="delete_registration_token", + ) + except StoreError: + # Deletion failed because token does not exist + return False + + return True + @cached() async def mark_access_token_as_used(self, token_id: int) -> None: """ diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 38bfdf5dad..4d6bbc94c7 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import attr +from synapse.api.constants import LoginType from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction @@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore): keyvalues={}, ) + # If a registration token was used, decrement the pending counter + # before deleting the session. + rows = self.db_pool.simple_select_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, + retcols=["result"], + ) + + # Get the tokens used and how much pending needs to be decremented by. + token_counts: Dict[str, int] = {} + for r in rows: + # If registration was successfully completed, the result of the + # registration token stage for that session will be True. + # If a token was used to authenticate, but registration was + # never completed, the result will be the token used. + token = db_to_json(r["result"]) + if isinstance(token, str): + token_counts[token] = token_counts.get(token, 0) + 1 + + # Update the `pending` counters. + if len(token_counts) > 0: + token_rows = self.db_pool.simple_select_many_txn( + txn, + table="registration_tokens", + column="token", + iterable=list(token_counts.keys()), + keyvalues={}, + retcols=["token", "pending"], + ) + for token_row in token_rows: + token = token_row["token"] + new_pending = token_row["pending"] - token_counts[token] + self.db_pool.simple_update_one_txn( + txn, + table="registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": new_pending}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql new file mode 100644 index 0000000000..ee6cf958f4 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql @@ -0,0 +1,23 @@ +/* Copyright 2021 Callum Brown + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS registration_tokens( + token TEXT NOT NULL, -- The token that can be used for authentication. + uses_allowed INT, -- The total number of times this token can be used. NULL if no limit. + pending INT NOT NULL, -- The number of in progress registrations using this token. + completed INT NOT NULL, -- The number of times this token has been used to complete a registration. + expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire. + UNIQUE (token) +); diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py new file mode 100644 index 0000000000..4927321e5a --- /dev/null +++ b/tests/rest/admin/test_registration_tokens.py @@ -0,0 +1,710 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login + +from tests import unittest + + +class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/registration_tokens" + + def _new_token(self, **kwargs): + """Helper function to create a token.""" + token = kwargs.get( + "token", + "".join(random.choices(string.ascii_letters, k=8)), + ) + self.get_success( + self.store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": kwargs.get("uses_allowed", None), + "pending": kwargs.get("pending", 0), + "completed": kwargs.get("completed", 0), + "expiry_time": kwargs.get("expiry_time", None), + }, + ) + ) + return token + + # CREATION + + def test_create_no_auth(self): + """Try to create a token without authentication.""" + channel = self.make_request("POST", self.url + "/new", {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_create_requester_not_admin(self): + """Try to create a token while not an admin.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_create_using_defaults(self): + """Create a token using all the defaults.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_specifying_fields(self): + """Create a token specifying the value of all fields.""" + data = { + "token": "abcd", + "uses_allowed": 1, + "expiry_time": self.clock.time_msec() + 1000000, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_with_null_value(self): + """Create a token specifying unlimited uses and no expiry.""" + data = { + "uses_allowed": None, + "expiry_time": None, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_token_too_long(self): + """Check token longer than 64 chars is invalid.""" + data = {"token": "a" * 65} + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_invalid_chars(self): + """Check you can't create token with invalid characters.""" + data = { + "token": "abc/def", + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_already_exists(self): + """Check you can't create token that already exists.""" + data = { + "token": "abcd", + } + + channel1 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + + channel2 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_unable_to_generate_token(self): + """Check right error is raised when server can't generate unique token.""" + # Create all possible single character tokens + tokens = [] + for c in string.ascii_letters + string.digits + "-_": + tokens.append( + { + "token": c, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + } + ) + self.get_success( + self.store.db_pool.simple_insert_many( + "registration_tokens", + tokens, + "create_all_registration_tokens", + ) + ) + + # Check creating a single character token fails with a 500 status code + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + + def test_create_uses_allowed(self): + """Check you can only create a token with good values for uses_allowed.""" + # Should work with 0 (token is invalid from the start) + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + + # Should fail with negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_expiry_time(self): + """Check you can't create a token with an invalid expiry_time.""" + # Should fail with a time in the past + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() - 10000}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() + 1000000.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_length(self): + """Check you can only generate a token with a valid length.""" + # Should work with 64 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 64}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 64) + + # Should fail with 0 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"length": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a float + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 8.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with 65 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 65}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # UPDATING + + def test_update_no_auth(self): + """Try to update a token without authentication.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_update_requester_not_admin(self): + """Try to update a token while not an admin.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_update_non_existent(self): + """Try to update a token that doesn't exist.""" + channel = self.make_request( + "PUT", + self.url + "/1234", + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_update_uses_allowed(self): + """Test updating just uses_allowed.""" + # Create new token using default values + token = self._new_token() + + # Should succeed with 1 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with 0 (makes token invalid) + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + + # Should fail with a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_expiry_time(self): + """Test updating just expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + # Should succeed with a time in the future + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should succeed with null + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": None}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertIsNone(channel.json_body["uses_allowed"]) + + # Should fail with a time in the past + past_time = self.clock.time_msec() - 10000 + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": past_time}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail a float + channel = self.make_request( + "PUT", + self.url + "/" + token, + {"expiry_time": new_expiry_time + 0.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_update_both(self): + """Test updating both uses_allowed and expiry_time.""" + # Create new token using default values + token = self._new_token() + new_expiry_time = self.clock.time_msec() + 1000000 + + data = { + "uses_allowed": 1, + "expiry_time": new_expiry_time, + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + + def test_update_invalid_type(self): + """Test using invalid types doesn't work.""" + # Create new token using default values + token = self._new_token() + + data = { + "uses_allowed": False, + "expiry_time": "1626430124000", + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # DELETING + + def test_delete_no_auth(self): + """Try to delete a token without authentication.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_delete_requester_not_admin(self): + """Try to delete a token while not an admin.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_non_existent(self): + """Try to delete a token that doesn't exist.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_delete(self): + """Test deleting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "DELETE", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # GETTING ONE + + def test_get_no_auth(self): + """Try to get a token without authentication.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_get_requester_not_admin(self): + """Try to get a token while not an admin.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_get_non_existent(self): + """Try to get a token that doesn't exist.""" + channel = self.make_request( + "GET", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_get(self): + """Test getting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], token) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + # LISTING + + def test_list_no_auth(self): + """Try to list tokens without authentication.""" + channel = self.make_request("GET", self.url, {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_list_requester_not_admin(self): + """Try to list tokens while not an admin.""" + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_list_all(self): + """Test listing all tokens.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 1) + token_info = channel.json_body["registration_tokens"][0] + self.assertEqual(token_info["token"], token) + self.assertIsNone(token_info["uses_allowed"]) + self.assertIsNone(token_info["expiry_time"]) + self.assertEqual(token_info["pending"], 0) + self.assertEqual(token_info["completed"], 0) + + def test_list_invalid_query_parameter(self): + """Test with `valid` query parameter not `true` or `false`.""" + channel = self.make_request( + "GET", + self.url + "?valid=x", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + def _test_list_query_parameter(self, valid: str): + """Helper used to test both valid=true and valid=false.""" + # Create 2 valid and 2 invalid tokens. + now = self.hs.get_clock().time_msec() + # Create always valid token + valid1 = self._new_token() + # Create token that hasn't been used up + valid2 = self._new_token(uses_allowed=1) + # Create token that has expired + invalid1 = self._new_token(expiry_time=now - 10000) + # Create token that has been used up but hasn't expired + invalid2 = self._new_token( + uses_allowed=2, + pending=1, + completed=1, + expiry_time=now + 1000000, + ) + + if valid == "true": + tokens = [valid1, valid2] + else: + tokens = [invalid1, invalid2] + + channel = self.make_request( + "GET", + self.url + "?valid=" + valid, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 2) + token_info_1 = channel.json_body["registration_tokens"][0] + token_info_2 = channel.json_body["registration_tokens"][1] + self.assertIn(token_info_1["token"], tokens) + self.assertIn(token_info_2["token"], tokens) + + def test_list_valid(self): + """Test listing just valid tokens.""" + self._test_list_query_parameter(valid="true") + + def test_list_invalid(self): + """Test listing just invalid tokens.""" + self._test_list_query_parameter(valid="false") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index fecda037a5..9f3ab2c985 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.storage._base import db_to_json from tests import unittest from tests.unittest import override_config @@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"registration_requires_token": True}) + def test_POST_registration_requires_token(self): + username = "kermit" + device_id = "frogfone" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params = { + "username": username, + "password": "monkey", + "device_id": device_id, + } + + # Request without auth to get flows and session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + # Synapse adds a dummy stage to differentiate flows where otherwise one + # flow would be a subset of another flow. + self.assertCountEqual( + [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]], + (f["stages"] for f in flows), + ) + session = channel.json_body["session"] + + # Do the registration token stage and check it has completed + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + self.assertEquals(channel.result["code"], b"401", channel.result) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + # Do the m.login.dummy stage and check registration was successful + params["auth"] = { + "type": LoginType.DUMMY, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + det_data = { + "user_id": f"@{username}:{self.hs.hostname}", + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + # Check the `completed` counter has been incremented and pending is 0 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["completed"], 1) + self.assertEquals(res["pending"], 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_invalid(self): + params = { + "username": "kermit", + "password": "monkey", + } + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Test with token param missing (invalid) + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with non-string (invalid) + params["auth"]["token"] = 1234 + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with unknown token (invalid) + params["auth"]["token"] = "1234" + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_limit_uses(self): + token = "abcd" + store = self.hs.get_datastore() + # Create token that can be used once + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + # Do 2 requests without auth to get two session IDs + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with session1 and check `pending` is 1 + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + # Repeat request to make sure pending isn't increased again + self.make_request(b"POST", self.url, json.dumps(params1)) + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 1) + + # Check auth fails when using token with session2 + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + # Check pending=0 and completed=1 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["pending"], 0) + self.assertEquals(res["completed"], 1) + + # Check auth still fails when using token with session2 + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_expiry(self): + token = "abcd" + now = self.hs.get_clock().time_msec() + store = self.hs.get_datastore() + # Create token that expired yesterday + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": now - 24 * 60 * 60 * 1000, + }, + ) + ) + params = {"username": "kermit", "password": "monkey"} + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Check authentication fails with expired token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Update token so it expires tomorrow + self.get_success( + store.db_pool.simple_update_one( + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000}, + ) + ) + + # Check authentication succeeds + channel = self.make_request(b"POST", self.url, json.dumps(params)) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry(self): + """Test `pending` is decremented when an uncompleted session expires.""" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do 2 requests without auth to get two session IDs + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with both sessions + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + self.make_request(b"POST", self.url, json.dumps(params2)) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + + # Check `result` of registration token stage for session1 is `True` + result1 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session1, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertTrue(db_to_json(result1)) + + # Check `result` for session2 is the token used + result2 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session2, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertEquals(db_to_json(result2), token) + + # Delete both sessions (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + + # Check pending is now 0 + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry_deleted_token(self): + """Test session expiry doesn't break when the token is deleted. + + 1. Start but don't complete UIA with a registration token + 2. Delete the token from the database + 3. Expire the session + """ + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do request without auth to get a session ID + params = {"username": "kermit", "password": "monkey"} + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Use token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + self.make_request(b"POST", self.url, json.dumps(params)) + + # Delete token + self.get_success( + store.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + ) + ) + + # Delete session (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) @@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) + + +class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): + servlets = [register.register_servlets] + url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + + def default_config(self): + config = super().default_config() + config["registration_requires_token"] = True + return config + + def test_GET_token_valid(self): + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], True) + + def test_GET_token_invalid(self): + token = "1234" + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], False) + + @override_config( + {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} + ) + def test_GET_ratelimiting(self): + token = "1234" + + for i in range(0, 6): + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) -- cgit 1.5.1 From 1aa0dad02187c3b972187f5952cfbce336b0ca5c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Aug 2021 07:53:52 -0400 Subject: Additional type hints for REST servlets (part 2). (#10674) Applies the changes from #10665 to additional modules. --- changelog.d/10674.misc | 1 + synapse/handlers/presence.py | 5 +++ synapse/rest/client/auth.py | 11 ++++--- synapse/rest/client/devices.py | 48 +++++++++++++++------------- synapse/rest/client/events.py | 38 ++++++++++++++--------- synapse/rest/client/filter.py | 26 +++++++++++----- synapse/rest/client/groups.py | 3 +- synapse/rest/client/initial_sync.py | 16 +++++++--- synapse/rest/client/keys.py | 57 +++++++++++++--------------------- synapse/rest/client/knock.py | 3 +- synapse/rest/client/login.py | 21 ++++++------- synapse/rest/client/logout.py | 17 +++++++--- synapse/rest/client/notifications.py | 13 ++++++-- synapse/rest/client/openid.py | 16 +++++++--- synapse/rest/client/password_policy.py | 18 ++++++----- synapse/rest/client/presence.py | 24 +++++++++----- synapse/rest/client/profile.py | 37 ++++++++++++++++------ 17 files changed, 216 insertions(+), 138 deletions(-) create mode 100644 changelog.d/10674.misc (limited to 'synapse/rest/client/auth.py') diff --git a/changelog.d/10674.misc b/changelog.d/10674.misc new file mode 100644 index 0000000000..39a37b90b1 --- /dev/null +++ b/changelog.d/10674.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 7ca14e1d84..4418d63df7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC): # otherwise would not do). await self.set_state(UserID.from_string(user_id), state, force_notify=True) + async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: + raise NotImplementedError( + "Attempting to check presence on a non-presence worker." + ) + class _NullContextManager(ContextManager[None]): """A context manager which does nothing.""" diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 91800c0278..df8cc4ac7a 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -15,11 +15,14 @@ import logging from typing import TYPE_CHECKING +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import respond_with_html +from synapse.http.server import HttpServer, respond_with_html from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest from ._base import client_patterns @@ -49,7 +52,7 @@ class AuthRestServlet(RestServlet): self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template - async def on_GET(self, request, stagetype): + async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -88,7 +91,7 @@ class AuthRestServlet(RestServlet): respond_with_html(request, 200, html) return None - async def on_POST(self, request, stagetype): + async def on_POST(self, request: Request, stagetype: str) -> None: session = parse_string(request, "session") if not session: @@ -172,5 +175,5 @@ class AuthRestServlet(RestServlet): return None -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8b9674db06..25bc3c8f47 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,34 +14,36 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api import errors +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DevicesRestServlet(RestServlet): PATTERNS = client_patterns("/devices$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user( requester.user.to_string() @@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: @@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet): class DeviceRestServlet(RestServlet): PATTERNS = client_patterns("/devices/(?P[^/]*)$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - async def on_GET(self, request, device_id): + async def on_GET( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) device = await self.device_handler.get_device( requester.user.to_string(), device_id @@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet): return 200, device @interactive_auth_handler - async def on_DELETE(self, request, device_id): + async def on_DELETE( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: @@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet): await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - async def on_PUT(self, request, device_id): + async def on_PUT( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) @@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet): PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_GET(self, request: SynapseRequest): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) dehydrated_device = await self.device_handler.get_dehydrated_device( requester.user.to_string() @@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet): else: raise errors.NotFoundError("No dehydrated device available") - async def on_PUT(self, request: SynapseRequest): + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: submission = parse_json_object_from_request(request) requester = await self.auth.get_user_by_req(request) @@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet): "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_POST(self, request: SynapseRequest): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) submission = parse_json_object_from_request(request) @@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet): return (200, result) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 52bb579cfd..13b72a045a 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -14,11 +14,18 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest - room_id = None + args: Dict[bytes, List[bytes]] = request.args # type: ignore if is_guest: - if b"room_id" not in request.args: + if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") - if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode("ascii") + room_id = parse_string(request, "room_id") pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if b"timeout" in request.args: + if b"timeout" in args: try: - timeout = int(request.args[b"timeout"][0]) + timeout = int(args[b"timeout"][0]) except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") - as_client_event = b"raw" not in request.args + as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), @@ -70,25 +76,27 @@ class EventStreamRestServlet(RestServlet): class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request, event_id): + async def on_GET( + self, request: SynapseRequest, event_id: str + ) -> Tuple[int, Union[str, JsonDict]]: requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event + result = await self._event_serializer.serialize_event(event, time_now) + return 200, result else: return 404, "Event not found." -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EventStreamRestServlet(hs).register(http_server) EventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 411667a9c8..6ed60c7418 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -13,26 +13,34 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID from ._base import client_patterns, set_timeline_upper_limit +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() - async def on_GET(self, request, user_id, filter_id): + async def on_GET( + self, request: SynapseRequest, user_id: str, filter_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet): raise AuthError(403, "Can only get filters for local users") try: - filter_id = int(filter_id) + filter_id_int = int(filter_id) except Exception: raise SynapseError(400, "Invalid filter_id") try: filter_collection = await self.filtering.get_user_filter( - user_localpart=target_user.localpart, filter_id=filter_id + user_localpart=target_user.localpart, filter_id=filter_id_int ) except StoreError as e: if e.code != 404: @@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet): return 200, {"filter_id": str(filter_id)} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: GetFilterRestServlet(hs).register(http_server) CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index 6285680c00..c3667ff8aa 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -26,6 +26,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, SynapseError from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet): return 200, result -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 12ba0e91db..49b1037b28 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -12,25 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Dict, List, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer # TODO: Needs unit testing class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - as_client_event = b"raw" not in request.args + args: Dict[bytes, List[bytes]] = request.args # type: ignore + as_client_event = b"raw" not in args pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( @@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet): return 200, content -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: InitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 012491f597..7281b2ee29 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,20 +15,25 @@ # limitations under the License. import logging -from typing import Any +from typing import TYPE_CHECKING, Any, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.types import StreamToken +from synapse.types import JsonDict, StreamToken from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -60,18 +65,16 @@ class KeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") - async def on_POST(self, request, device_id): + async def on_POST( + self, request: SynapseRequest, device_id: Optional[str] + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -149,16 +152,12 @@ class KeyQueryServlet(RestServlet): PATTERNS = client_patterns("/keys/query$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() device_id = requester.device_id @@ -195,17 +194,13 @@ class KeyChangesServlet(RestServlet): PATTERNS = client_patterns("/keys/changes$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from", required=True) @@ -245,12 +240,12 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -269,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -281,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -329,16 +320,12 @@ class SignaturesUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/signatures/upload$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -349,7 +336,7 @@ class SignaturesUploadServlet(RestServlet): return 200, result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 7d1bc40658..68fb08d0ba 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -19,6 +19,7 @@ from twisted.web.server import Request from synapse.api.constants import Membership from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 11d07776b2..4be502a77b 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -1,4 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing_extensions import TypedDict @@ -110,7 +110,7 @@ class LoginRestServlet(RestServlet): # counters are initialised for the auth_provider_ids. _load_sso_handlers(hs) - def on_GET(self, request: SynapseRequest): + def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) @@ -157,7 +157,7 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} - async def on_POST(self, request: SynapseRequest): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) if self._msc2918_enabled: @@ -217,7 +217,7 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, appservice: ApplicationService, should_issue_refresh_token: bool = False, - ): + ) -> LoginResponse: identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -467,10 +467,7 @@ class RefreshTokenServlet(RestServlet): self._clock = hs.get_clock() self.access_token_lifetime = hs.config.access_token_lifetime - async def on_POST( - self, - request: SynapseRequest, - ): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) assert_params_in_dict(refresh_submission, ["refresh_token"]) @@ -570,7 +567,7 @@ class SsoRedirectServlet(RestServlet): class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._cas_handler = hs.get_cas_handler() @@ -592,7 +589,7 @@ class CasTicketServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) if hs.config.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) @@ -601,7 +598,7 @@ def register_servlets(hs, http_server): CasTicketServlet(hs).register(http_server) -def _load_sso_handlers(hs: "HomeServer"): +def _load_sso_handlers(hs: "HomeServer") -> None: """Ensure that the SSO handlers are loaded, if they are enabled by configuration. This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 6055cac2bd..193a6951b9 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -13,9 +13,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -23,13 +30,13 @@ logger = logging.getLogger(__name__) class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) if requester.device_id is None: @@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet): class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() @@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LogoutRestServlet(hs).register(http_server) LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 0ede643c2d..d1d8a984c6 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -13,26 +13,33 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet): return 200, {"notifications": returned_push_actions, "next_token": next_token} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index e8d2673819..4dda6dce4b 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() self.server_name = hs.config.server_name - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py index a83927aee6..6d64efb165 100644 --- a/synapse/rest/client/password_policy.py +++ b/synapse/rest/client/password_policy.py @@ -13,28 +13,32 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class PasswordPolicyServlet(RestServlet): PATTERNS = client_patterns("/password_policy$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled - def on_GET(self, request): + def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.enabled or not self.policy: return (200, {}) @@ -53,5 +57,5 @@ class PasswordPolicyServlet(RestServlet): return (200, policy) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index 6c27e5faf9..94dd4fe2f4 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -15,12 +15,18 @@ """ This module contains REST servlets to do with presence: /presence/ """ import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,7 +34,7 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() @@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet): self._use_presence = hs.config.server.use_presence - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) @@ -53,13 +61,15 @@ class PresenceStatusRestServlet(RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state( + result = format_user_presence_state( state, self.clock.time_msec(), include_user_id=False ) - return 200, state + return 200, result - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) @@ -91,5 +101,5 @@ class PresenceStatusRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 5463ed2c4f..d0f20de569 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -14,22 +14,31 @@ """ This module contains REST servlets to do with profile: /profile/ """ +from typing import TYPE_CHECKING, Tuple + from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) @@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet): class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) @@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet): class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet): return 200, ret -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ProfileDisplaynameRestServlet(hs).register(http_server) ProfileAvatarURLRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server) -- cgit 1.5.1