From 394673055db4df49bfd58c2f6118834a6d928563 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 23 Jun 2021 15:57:41 +0100 Subject: Re-introduce "Leave out optional keys from /sync" change (#10214) Required some fixes due to merge conflicts with #6739, but nothing too hairy. The first commit is the same as the original (after merge conflict resolution) then two more for compatibility with the latest sync code. --- tests/rest/client/v2_alpha/test_sync.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 012910f136..cdca3a3e23 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -41,35 +41,7 @@ class FilterTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) - self.assertTrue( - { - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - }.issubset(set(channel.json_body.keys())) - ) - - def test_sync_presence_disabled(self): - """ - When presence is disabled, the key does not appear in /sync. - """ - self.hs.config.use_presence = False - - channel = self.make_request("GET", "/sync") - - self.assertEqual(channel.code, 200) - self.assertTrue( - { - "next_batch", - "rooms", - "account_data", - "to_device", - "device_lists", - }.issubset(set(channel.json_body.keys())) - ) + self.assertIn("next_batch", channel.json_body) class SyncFilterTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From bd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 24 Jun 2021 15:33:20 +0200 Subject: MSC2918 Refresh tokens implementation (#9450) This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech --- changelog.d/9450.feature | 1 + scripts/synapse_port_db | 4 +- synapse/api/auth.py | 5 + synapse/config/registration.py | 21 ++ synapse/handlers/auth.py | 132 ++++++++++++- synapse/handlers/register.py | 52 ++++- synapse/module_api/__init__.py | 2 +- synapse/replication/http/login.py | 13 +- synapse/rest/client/v1/login.py | 171 +++++++++++++--- synapse/rest/client/v2_alpha/register.py | 88 +++++++-- synapse/storage/databases/main/registration.py | 207 ++++++++++++++++++- .../schema/main/delta/59/14refresh_tokens.sql | 34 ++++ tests/api/test_auth.py | 1 + tests/handlers/test_device.py | 2 +- tests/rest/client/v2_alpha/test_auth.py | 220 ++++++++++++++++++++- 15 files changed, 892 insertions(+), 61 deletions(-) create mode 100644 changelog.d/9450.feature create mode 100644 synapse/storage/schema/main/delta/59/14refresh_tokens.sql (limited to 'tests/rest') diff --git a/changelog.d/9450.feature b/changelog.d/9450.feature new file mode 100644 index 0000000000..455936a41d --- /dev/null +++ b/changelog.d/9450.feature @@ -0,0 +1 @@ +Implement refresh tokens as specified by [MSC2918](https://github.com/matrix-org/matrix-doc/pull/2918). diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 86eb76cbca..2bbaf5557d 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -93,6 +93,7 @@ BOOLEAN_COLUMNS = { "local_media_repository": ["safe_from_quarantine"], "users": ["shadow_banned"], "e2e_fallback_keys_json": ["used"], + "access_tokens": ["used"], } @@ -307,7 +308,8 @@ class Porter(object): information_schema.table_constraints AS tc INNER JOIN information_schema.constraint_column_usage AS ccu USING (table_schema, constraint_name) - WHERE tc.constraint_type = 'FOREIGN KEY'; + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name != ccu.table_name; """ txn.execute(sql) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index edf1b918eb..29cf257633 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -245,6 +245,11 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) + # Mark the token as used. This is used to invalidate old refresh + # tokens after some time. + if not user_info.token_used and token_id is not None: + await self.store.mark_access_token_as_used(token_id) + requester = create_requester( user_info.user_id, token_id, diff --git a/synapse/config/registration.py b/synapse/config/registration.py index d9dc55a0c3..0ad919b139 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -119,6 +119,27 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime + # The `access_token_lifetime` applies for tokens that can be renewed + # using a refresh token, as per MSC2918. If it is `None`, the refresh + # token mechanism is disabled. + # + # Since it is incompatible with the `session_lifetime` mechanism, it is set to + # `None` by default if a `session_lifetime` is set. + access_token_lifetime = config.get( + "access_token_lifetime", "5m" if session_lifetime is None else None + ) + if access_token_lifetime is not None: + access_token_lifetime = self.parse_duration(access_token_lifetime) + self.access_token_lifetime = access_token_lifetime + + if session_lifetime is not None and access_token_lifetime is not None: + raise ConfigError( + "The refresh token mechanism is incompatible with the " + "`session_lifetime` option. Consider disabling the " + "`session_lifetime` option or disabling the refresh token " + "mechanism by removing the `access_token_lifetime` option." + ) + # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1971e373ed..e2ac595a62 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -30,6 +30,7 @@ from typing import ( Optional, Tuple, Union, + cast, ) import attr @@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: + from synapse.rest.client.v1.login import LoginResponse from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -777,6 +779,108 @@ class AuthHandler(BaseHandler): "params": params, } + async def refresh_token( + self, + refresh_token: str, + valid_until_ms: Optional[int], + ) -> Tuple[str, str]: + """ + Consumes a refresh token and generate both a new access token and a new refresh token from it. + + The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. + + Args: + refresh_token: The token to consume. + valid_until_ms: The expiration timestamp of the new access token. + + Returns: + A tuple containing the new access token and refresh token + """ + + # Verify the token signature first before looking up the token + if not self._verify_refresh_token(refresh_token): + raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) + + existing_token = await self.store.lookup_refresh_token(refresh_token) + if existing_token is None: + raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) + + if ( + existing_token.has_next_access_token_been_used + or existing_token.has_next_refresh_token_been_refreshed + ): + raise SynapseError( + 403, "refresh token isn't valid anymore", Codes.FORBIDDEN + ) + + ( + new_refresh_token, + new_refresh_token_id, + ) = await self.get_refresh_token_for_user_id( + user_id=existing_token.user_id, device_id=existing_token.device_id + ) + access_token = await self.get_access_token_for_user_id( + user_id=existing_token.user_id, + device_id=existing_token.device_id, + valid_until_ms=valid_until_ms, + refresh_token_id=new_refresh_token_id, + ) + await self.store.replace_refresh_token( + existing_token.token_id, new_refresh_token_id + ) + return access_token, new_refresh_token + + def _verify_refresh_token(self, token: str) -> bool: + """ + Verifies the shape of a refresh token. + + Args: + token: The refresh token to verify + + Returns: + Whether the token has the right shape + """ + parts = token.split("_", maxsplit=4) + if len(parts) != 4: + return False + + type, localpart, rand, crc = parts + + # Refresh tokens are prefixed by "syr_", let's check that + if type != "syr": + return False + + # Check the CRC + base = f"{type}_{localpart}_{rand}" + expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + if crc != expected_crc: + return False + + return True + + async def get_refresh_token_for_user_id( + self, + user_id: str, + device_id: str, + ) -> Tuple[str, int]: + """ + Creates a new refresh token for the user with the given user ID. + + Args: + user_id: canonical user ID + device_id: the device ID to associate with the token. + + Returns: + The newly created refresh token and its ID in the database + """ + refresh_token = self.generate_refresh_token(UserID.from_string(user_id)) + refresh_token_id = await self.store.add_refresh_token_to_user( + user_id=user_id, + token=refresh_token, + device_id=device_id, + ) + return refresh_token, refresh_token_id + async def get_access_token_for_user_id( self, user_id: str, @@ -784,6 +888,7 @@ class AuthHandler(BaseHandler): valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, is_appservice_ghost: bool = False, + refresh_token_id: Optional[int] = None, ) -> str: """ Creates a new access token for the user with the given user ID. @@ -801,6 +906,8 @@ class AuthHandler(BaseHandler): valid_until_ms: when the token is valid until. None for no expiry. is_appservice_ghost: Whether the user is an application ghost user + refresh_token_id: the refresh token ID that will be associated with + this access token. Returns: The access token for the user's session. Raises: @@ -836,6 +943,7 @@ class AuthHandler(BaseHandler): device_id=device_id, valid_until_ms=valid_until_ms, puppets_user_id=puppets_user_id, + refresh_token_id=refresh_token_id, ) # the device *should* have been registered before we got here; however, @@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler): self, login_submission: Dict[str, Any], ratelimit: bool = False, - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate auth types which don't @@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler): self, username: str, login_submission: Dict[str, Any], - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Helper for validate_login Handles login, once we've mapped 3pids onto userids @@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler): async def check_password_provider_3pid( self, medium: str, address: str, password: str - ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler): crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" + def generate_refresh_token(self, for_user: UserID) -> str: + """Generates an opaque string, for use as a refresh token""" + + # we use the following format for refresh tokens: + # syr___ + + b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8")) + random_string = stringutils.random_string(20) + base = f"syr_{b64local}_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + async def validate_short_term_login_token( self, login_token: str ) -> LoginTokenAttributes: @@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) - async def _sso_login_callback(self, login_result: JsonDict) -> None: + async def _sso_login_callback(self, login_result: "LoginResponse") -> None: """ A login callback which might add additional attributes to the login response. @@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler): extra_attributes = self._extra_attributes.get(login_result["user_id"]) if extra_attributes: - login_result.update(extra_attributes.extra_attributes) + login_result_dict = cast(Dict[str, Any], login_result) + login_result_dict.update(extra_attributes.extra_attributes) def _expire_sso_extra_attributes(self) -> None: """ diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4b4b579741..26ef016179 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,9 +15,10 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from prometheus_client import Counter +from typing_extensions import TypedDict from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -54,6 +55,16 @@ login_counter = Counter( ["guest", "auth_provider"], ) +LoginDict = TypedDict( + "LoginDict", + { + "device_id": str, + "access_token": str, + "valid_until_ms": Optional[int], + "refresh_token": Optional[str], + }, +) + class RegistrationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): @@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.session_lifetime + self.access_token_lifetime = hs.config.access_token_lifetime async def check_username( self, @@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler): is_guest: bool = False, is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, - ) -> Tuple[str, str]: + should_issue_refresh_token: bool = False, + ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler): is_guest: Whether this is a guest account auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: Whether it should also issue a refresh token Returns: - Tuple of device ID and access token + Tuple of device ID, access token, access token expiration time and refresh token """ res = await self._register_device_client( user_id=user_id, @@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler): initial_display_name=initial_display_name, is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) login_counter.labels( @@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() - return res["device_id"], res["access_token"] + return ( + res["device_id"], + res["access_token"], + res["valid_until_ms"], + res["refresh_token"], + ) async def register_device_inner( self, @@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler): initial_display_name: Optional[str], is_guest: bool = False, is_appservice_ghost: bool = False, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginDict: """Helper for register_device Does the bits that need doing on the main process. Not for use outside this @@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime + refresh_token = None + refresh_token_id = None + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) @@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler): assert valid_until_ms is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: + if should_issue_refresh_token: + ( + refresh_token, + refresh_token_id, + ) = await self._auth_handler.get_refresh_token_for_user_id( + user_id, + device_id=registered_device_id, + ) + valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + access_token = await self._auth_handler.get_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, is_appservice_ghost=is_appservice_ghost, + refresh_token_id=refresh_token_id, ) - return {"device_id": registered_device_id, "access_token": access_token} + return { + "device_id": registered_device_id, + "access_token": access_token, + "valid_until_ms": valid_until_ms, + "refresh_token": refresh_token, + } async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 58b255eb1b..721c45abac 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -168,7 +168,7 @@ class ModuleApi: "Using deprecated ModuleApi.register which creates a dummy user device." ) user_id = yield self.register_user(localpart, displayname, emails or []) - _, access_token = yield self.register_device(user_id) + _, access_token, _, _ = yield self.register_device(user_id) return user_id, access_token def register_user( diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index c2e8c00293..550bd5c95f 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -36,20 +36,29 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( - user_id, device_id, initial_display_name, is_guest, is_appservice_ghost + user_id, + device_id, + initial_display_name, + is_guest, + is_appservice_ghost, + should_issue_refresh_token, ): """ Args: + user_id (int) device_id (str|None): Device ID to use, if None a new one is generated. initial_display_name (str|None) is_guest (bool) + is_appservice_ghost (bool) + should_issue_refresh_token (bool) """ return { "device_id": device_id, "initial_display_name": initial_display_name, "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, + "should_issue_refresh_token": should_issue_refresh_token, } async def _handle_request(self, request, user_id): @@ -59,6 +68,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] + should_issue_refresh_token = content["should_issue_refresh_token"] res = await self.registration_handler.register_device_inner( user_id, @@ -66,6 +76,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name, is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) return 200, res diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f6be5f1020..cbcb60fe31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,7 +14,9 @@ import logging import re -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from typing_extensions import TypedDict from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -25,6 +27,8 @@ from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, + assert_params_in_dict, + parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -40,6 +44,21 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +LoginResponse = TypedDict( + "LoginResponse", + { + "user_id": str, + "access_token": str, + "home_server": str, + "expires_in_ms": Optional[int], + "refresh_token": Optional[str], + "device_id": str, + "well_known": Optional[Dict[str, Any]], + }, + total=False, +) + + class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self._sso_handler = hs.get_sso_handler() @@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest): login_submission = parse_json_object_from_request(request) + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False + try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) @@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet): None, request.getClientIP() ) - result = await self._do_appservice_login(login_submission, appservice) + result = await self._do_appservice_login( + login_submission, + appservice, + should_issue_refresh_token=should_issue_refresh_token, + ) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_jwt_login(login_submission) + result = await self._do_jwt_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_token_login(login_submission) + result = await self._do_token_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) else: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_other_login(login_submission) + result = await self._do_other_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet): return 200, result async def _do_appservice_login( - self, login_submission: JsonDict, appservice: ApplicationService + self, + login_submission: JsonDict, + appservice: ApplicationService, + should_issue_refresh_token: bool = False, ): identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) return await self._complete_login( - qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + qualified_user_id, + login_submission, + ratelimit=appservice.is_rate_limited(), + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_other_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """Handle non-token/saml/jwt logins Args: login_submission: + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: HTTP response @@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet): login_submission, ratelimit=True ) result = await self._complete_login( - canonical_user_id, login_submission, callback + canonical_user_id, + login_submission, + callback, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, + callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, auth_provider_id: Optional[str] = None, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on all successful logins. @@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet): ratelimit: Whether to ratelimit the login request. auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: result: Dictionary of account information after successful login. @@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - result = { - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - } + result = LoginResponse( + user_id=user_id, + access_token=access_token, + home_server=self.hs.hostname, + device_id=device_id, + ) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token if callback is not None: await callback(result) return result - async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_token_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """ Handle the final stage of SSO login. Args: - login_submission: The JSON request body. + login_submission: The JSON request body. + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: The body of the JSON response. @@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet): login_submission, self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_jwt_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: token = login_submission.get("token", None) if token is None: raise LoginError( @@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( - user_id, login_submission, create_non_existent_users=True + user_id, + login_submission, + create_non_existent_users=True, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp( return e +class RefreshTokenServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) + + def __init__(self, hs: "HomeServer"): + self._auth_handler = hs.get_auth_handler() + self._clock = hs.get_clock() + self.access_token_lifetime = hs.config.access_token_lifetime + + async def on_POST( + self, + request: SynapseRequest, + ): + refresh_submission = parse_json_object_from_request(request) + + assert_params_in_dict(refresh_submission, ["refresh_token"]) + token = refresh_submission["refresh_token"] + if not isinstance(token, str): + raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) + + valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in_ms": expires_in_ms, + }, + ) + + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ re.compile( @@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + if hs.config.access_token_lifetime is not None: + RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a30a5df1b1..4d31584acd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -41,11 +41,13 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_json_object_from_request, parse_string, ) from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -399,6 +401,7 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.enable_registration + self._msc2918_enabled = hs.config.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -424,6 +427,15 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) + if self._msc2918_enabled: + # Check if this registration should also issue a refresh token, as + # per MSC2918 + should_issue_refresh_token = parse_boolean( + request, name="org.matrix.msc2918.refresh_token", default=False + ) + else: + should_issue_refresh_token = False + # Pull out the provided username and do basic sanity checks early since # the auth layer will store these in sessions. desired_username = None @@ -462,7 +474,10 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Desired Username is missing or not a string") result = await self._do_appservice_registration( - desired_username, access_token, body + desired_username, + access_token, + body, + should_issue_refresh_token=should_issue_refresh_token, ) return 200, result @@ -665,7 +680,9 @@ class RegisterRestServlet(RestServlet): registered = True return_dict = await self._create_registration_details( - registered_user_id, params + registered_user_id, + params, + should_issue_refresh_token=should_issue_refresh_token, ) if registered: @@ -677,7 +694,9 @@ class RegisterRestServlet(RestServlet): return 200, return_dict - async def _do_appservice_registration(self, username, as_token, body): + async def _do_appservice_registration( + self, username, as_token, body, should_issue_refresh_token: bool = False + ): user_id = await self.registration_handler.appservice_register( username, as_token ) @@ -685,19 +704,27 @@ class RegisterRestServlet(RestServlet): user_id, body, is_appservice_ghost=True, + should_issue_refresh_token=should_issue_refresh_token, ) async def _create_registration_details( - self, user_id, params, is_appservice_ghost=False + self, + user_id: str, + params: JsonDict, + is_appservice_ghost: bool = False, + should_issue_refresh_token: bool = False, ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. Args: - (str) user_id: full canonical @user:id - (object) params: registration parameters, from which we pull - device_id, initial_device_name and inhibit_login + user_id: full canonical @user:id + params: registration parameters, from which we pull device_id, + initial_device_name and inhibit_login + is_appservice_ghost + should_issue_refresh_token: True if this registration should issue + a refresh token alongside the access token. Returns: dictionary for response from /register """ @@ -705,15 +732,29 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) result.update({"access_token": access_token, "device_id": device_id}) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + return result async def _do_guest_registration(self, params, address=None): @@ -727,19 +768,30 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) - return ( - 200, - { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - }, - ) + result = { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + } + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return 200, result def _calculate_registration_flows( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e5c5cf8ff0..e31c5864ac 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -53,6 +53,9 @@ class TokenLookupResult: valid_until_ms: The timestamp the token expires, if any. token_owner: The "owner" of the token. This is either the same as the user, or a server admin who is logged in as the user. + token_used: True if this token was used at least once in a request. + This field can be out of date since `get_user_by_access_token` is + cached. """ user_id = attr.ib(type=str) @@ -62,6 +65,7 @@ class TokenLookupResult: device_id = attr.ib(type=Optional[str], default=None) valid_until_ms = attr.ib(type=Optional[int], default=None) token_owner = attr.ib(type=str) + token_used = attr.ib(type=bool, default=False) # Make the token owner default to the user ID, which is the common case. @token_owner.default @@ -69,6 +73,29 @@ class TokenLookupResult: return self.user_id +@attr.s(frozen=True, slots=True) +class RefreshTokenLookupResult: + """Result of looking up a refresh token.""" + + user_id = attr.ib(type=str) + """The user this token belongs to.""" + + device_id = attr.ib(type=str) + """The device associated with this refresh token.""" + + token_id = attr.ib(type=int) + """The ID of this refresh token.""" + + next_token_id = attr.ib(type=Optional[int]) + """The ID of the refresh token which replaced this one.""" + + has_next_refresh_token_been_refreshed = attr.ib(type=bool) + """True if the next refresh token was used for another refresh.""" + + has_next_access_token_been_used = attr.ib(type=bool) + """True if the next access token was already used at least once.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): access_tokens.id as token_id, access_tokens.device_id, access_tokens.valid_until_ms, - access_tokens.user_id as token_owner + access_tokens.user_id as token_owner, + access_tokens.used as token_used FROM users INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? @@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) + if rows: - return TokenLookupResult(**rows[0]) + row = rows[0] + + # This field is nullable, ensure it comes out as a boolean + if row["token_used"] is None: + row["token_used"] = False + + return TokenLookupResult(**row) return None @@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + @cached() + async def mark_access_token_as_used(self, token_id: int) -> None: + """ + Mark the access token as used, which invalidates the refresh token used + to obtain it. + + Because get_user_by_access_token is cached, this function might be + called multiple times for the same token, effectively doing unnecessary + SQL updates. Because updating the `used` field only goes one way (from + False to True) it is safe to cache this function as well to avoid this + issue. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"used": True}, + desc="mark_access_token_as_used", + ) + + async def lookup_refresh_token( + self, token: str + ) -> Optional[RefreshTokenLookupResult]: + """Lookup a refresh token with hints about its validity.""" + + def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + txn.execute( + """ + SELECT + rt.id token_id, + rt.user_id, + rt.device_id, + rt.next_token_id, + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used + FROM refresh_tokens rt + LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id + LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id + WHERE rt.token = ? + """, + (token,), + ) + row = txn.fetchone() + + if row is None: + return None + + return RefreshTokenLookupResult( + token_id=row[0], + user_id=row[1], + device_id=row[2], + next_token_id=row[3], + has_next_refresh_token_been_refreshed=row[4], + # This column is nullable, ensure it's a boolean + has_next_access_token_been_used=(row[5] or False), + ) + + return await self.db_pool.runInteraction( + "lookup_refresh_token", _lookup_refresh_token_txn + ) + + async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: + """ + Set the successor of a refresh token, removing the existing successor + if any. + + Args: + token_id: ID of the refresh token to update. + next_token_id: ID of its successor. + """ + + def _replace_refresh_token_txn(txn) -> None: + # First check if there was an existing refresh token + old_next_token_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "refresh_tokens", + {"id": token_id}, + "next_token_id", + allow_none=True, + ) + + self.db_pool.simple_update_one_txn( + txn, + "refresh_tokens", + {"id": token_id}, + {"next_token_id": next_token_id}, + ) + + # Delete the old "next" token if it exists. This should cascade and + # delete the associated access_token + if old_next_token_id is not None: + self.db_pool.simple_delete_one_txn( + txn, + "refresh_tokens", + {"id": old_next_token_id}, + ) + + await self.db_pool.runInteraction( + "replace_refresh_token", _replace_refresh_token_txn + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") async def add_access_token_to_user( self, @@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + refresh_token_id: Optional[int] = None, ) -> int: """Adds an access token for the given user. Args: user_id: The user ID. token: The new access token to add. - device_id: ID of the device to associate with the access token + device_id: ID of the device to associate with the access token. valid_until_ms: when the token is valid until. None for no expiry. + puppets_user_id + refresh_token_id: ID of the refresh token generated alongside this + access token. Raises: StoreError if there was a problem adding this. Returns: @@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, "last_validated": now, + "refresh_token_id": refresh_token_id, + "used": False, }, desc="add_access_token_to_user", ) return next_id + async def add_refresh_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + ) -> int: + """Adds a refresh token for the given user. + + Args: + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the refresh token. + Raises: + StoreError if there was a problem adding this. + Returns: + The token ID + """ + next_id = self._refresh_tokens_id_gen.get_next() + + await self.db_pool.simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "device_id": device_id, + "token": token, + "next_token_id": None, + }, + desc="add_refresh_token_to_user", + ) + + return next_id + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" @@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ - Invalidate access tokens belonging to a user + Invalidate access and refresh tokens belonging to a user Args: user_id: ID of user the tokens belong to @@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] # type: List[Union[str, int]] + # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat + # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where + # clause and values before we handle that. This seems to be only used in the "set password" handler. + refresh_where_clause = where_clause + refresh_values = values.copy() if except_token_id: + # TODO: support that for refresh tokens where_clause += " AND id != ?" values.append(except_token_id) @@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + txn.execute( + "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause, + refresh_values, + ) + return tokens_and_devices return await self.db_pool.runInteraction("user_delete_access_tokens", f) @@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) + async def delete_refresh_token(self, refresh_token: str) -> None: + def f(txn): + self.db_pool.simple_delete_one_txn( + txn, table="refresh_tokens", keyvalues={"token": refresh_token} + ) + + await self.db_pool.runInteraction("delete_refresh_token", f) + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're diff --git a/synapse/storage/schema/main/delta/59/14refresh_tokens.sql b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql new file mode 100644 index 0000000000..9a6bce1e3e --- /dev/null +++ b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql @@ -0,0 +1,34 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Holds MSC2918 refresh tokens +CREATE TABLE refresh_tokens ( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + token TEXT NOT NULL, + -- When consumed, a new refresh token is generated, which is tracked by + -- this foreign key + next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE, + UNIQUE(token) +); + +-- Add a reference to the refresh token generated alongside each access token +ALTER TABLE "access_tokens" + ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE; + +-- Add a flag whether the token was already used or not +ALTER TABLE "access_tokens" + ADD COLUMN used BOOLEAN; diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 1b0a815757..f76fea4f66 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -58,6 +58,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id=self.test_user, token_id=5, device_id="device" ) self.store.get_user_by_access_token = simple_async_mock(user_info) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 84c38b295d..3ac48e5e95 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -257,7 +257,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) # Create a new login for the user and dehydrated the device - device_id, access_token = self.get_success( + device_id, access_token, _expiration_time, _refresh_token = self.get_success( self.registration.register_device( user_id=user_id, device_id=None, diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 485e3650c3..6b90f838b6 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -20,7 +20,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.rest.client.v2_alpha import account, auth, devices, register from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID @@ -498,3 +498,221 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} ) + + +class RefreshAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + + def test_login_issue_refresh_token(self): + """ + A login response should include a refresh_token only if asked. + """ + # Test login + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + + login_without_refresh = self.make_request( + "POST", "/_matrix/client/r0/login", body + ) + self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertNotIn("refresh_token", login_without_refresh.json_body) + + login_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertIn("refresh_token", login_with_refresh.json_body) + self.assertIn("expires_in_ms", login_with_refresh.json_body) + + def test_register_issue_refresh_token(self): + """ + A register response should include a refresh_token only if asked. + """ + register_without_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": "test2", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual( + register_without_refresh.code, 200, register_without_refresh.result + ) + self.assertNotIn("refresh_token", register_without_refresh.json_body) + + register_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", + { + "username": "test3", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertIn("refresh_token", register_with_refresh.json_body) + self.assertIn("expires_in_ms", register_with_refresh.json_body) + + def test_token_refresh(self): + """ + A refresh token can be used to issue a new access token. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertIn("access_token", refresh_response.json_body) + self.assertIn("refresh_token", refresh_response.json_body) + self.assertIn("expires_in_ms", refresh_response.json_body) + + # The access and refresh tokens should be different from the original ones after refresh + self.assertNotEqual( + login_response.json_body["access_token"], + refresh_response.json_body["access_token"], + ) + self.assertNotEqual( + login_response.json_body["refresh_token"], + refresh_response.json_body["refresh_token"], + ) + + @override_config({"access_token_lifetime": "1m"}) + def test_refresh_token_expiration(self): + """ + The access token should have some time as specified in the config. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + self.assertApproximates( + login_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertApproximates( + refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + def test_refresh_token_invalidation(self): + """Refresh tokens are invalidated after first use of the next token. + + A refresh token is considered invalid if: + - it was already used at least once + - and either + - the next access token was used + - the next refresh token was used + + The chain of tokens goes like this: + + login -|-> first_refresh -> third_refresh (fails) + |-> second_refresh -> fifth_refresh + |-> fourth_refresh (fails) + """ + + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + # This first refresh should work properly + first_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + first_refresh_response.code, 200, first_refresh_response.result + ) + + # This one as well, since the token in the first one was never used + second_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + second_refresh_response.code, 200, second_refresh_response.result + ) + + # This one should not, since the token from the first refresh is not valid anymore + third_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": first_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + third_refresh_response.code, 401, third_refresh_response.result + ) + + # The associated access token should also be invalid + whoami_response = self.make_request( + "GET", + "/_matrix/client/r0/account/whoami", + access_token=first_refresh_response.json_body["access_token"], + ) + self.assertEqual(whoami_response.code, 401, whoami_response.result) + + # But all other tokens should work (they will expire after some time) + for access_token in [ + second_refresh_response.json_body["access_token"], + login_response.json_body["access_token"], + ]: + whoami_response = self.make_request( + "GET", "/_matrix/client/r0/account/whoami", access_token=access_token + ) + self.assertEqual(whoami_response.code, 200, whoami_response.result) + + # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail + fourth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fourth_refresh_response.code, 403, fourth_refresh_response.result + ) + + # But refreshing from the last valid refresh token still works + fifth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": second_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fifth_refresh_response.code, 200, fifth_refresh_response.result + ) -- cgit 1.5.1 From f55836929d3c64f3f8d883d8f3643a88b6c9cbca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Jun 2021 12:00:04 -0400 Subject: Do not recurse into non-spaces in the spaces summary. (#10256) Previously m.child.room events in non-space rooms would be treated as part of the room graph, but this is no longer supported. --- changelog.d/10256.misc | 1 + synapse/api/constants.py | 6 +++++ synapse/handlers/space_summary.py | 11 +++++++-- tests/handlers/test_space_summary.py | 48 +++++++++++++++++++----------------- tests/rest/client/v1/utils.py | 3 ++- 5 files changed, 43 insertions(+), 26 deletions(-) create mode 100644 changelog.d/10256.misc (limited to 'tests/rest') diff --git a/changelog.d/10256.misc b/changelog.d/10256.misc new file mode 100644 index 0000000000..adef12fcb9 --- /dev/null +++ b/changelog.d/10256.misc @@ -0,0 +1 @@ +Improve the performance of the spaces summary endpoint by only recursing into spaces (and not rooms in general). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 414e4c019a..8363c2bb0f 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -201,6 +201,12 @@ class EventContentFields: ) +class RoomTypes: + """Understood values of the room_type field of m.room.create events.""" + + SPACE = "m.space" + + class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" DEFAULT = MEGOLM_V1_AES_SHA2 diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 17fc47ce16..266f369883 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -25,6 +25,7 @@ from synapse.api.constants import ( EventTypes, HistoryVisibility, Membership, + RoomTypes, ) from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 @@ -318,7 +319,8 @@ class SpaceSummaryHandler: Returns: A tuple of: - An iterable of a single value of the room. + The room information, if the room should be returned to the + user. None, otherwise. An iterable of the sorted children events. This may be limited to a maximum size or may include all children. @@ -328,7 +330,11 @@ class SpaceSummaryHandler: room_entry = await self._build_room_entry(room_id) - # look for child rooms/spaces. + # If the room is not a space, return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE: + return room_entry, () + + # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) if suggested_only: @@ -348,6 +354,7 @@ class SpaceSummaryHandler: event_format=format_event_for_client_v2, ) ) + return room_entry, events_result async def _summarize_remote_room( diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 131d362ccc..9771d3fb3b 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,6 +14,7 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock +from synapse.api.constants import EventContentFields, RoomTypes from synapse.api.errors import AuthError from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin @@ -97,9 +98,21 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.hs = hs self.handler = self.hs.get_space_summary_handler() + # Create a user. self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") + # Create a space and a child room. + self.space = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + self.room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, self.room, self.token) + def _add_child(self, space_id: str, room_id: str, token: str) -> None: """Add a child room to a space.""" self.helper.send_state( @@ -128,43 +141,32 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_simple_space(self): """Test a simple space with a single room.""" - space = self.helper.create_room_as(self.user, tok=self.token) - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(space, room, self.token) - - result = self.get_success(self.handler.get_space_summary(self.user, space)) + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. - self._assert_rooms(result, [space, room]) - self._assert_events(result, [(space, room)]) + self._assert_rooms(result, [self.space, self.room]) + self._assert_events(result, [(self.space, self.room)]) def test_visibility(self): """A user not in a space cannot inspect it.""" - space = self.helper.create_room_as(self.user, tok=self.token) - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(space, room, self.token) - user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") # The user cannot see the space. - self.get_failure(self.handler.get_space_summary(user2, space), AuthError) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) # Joining the room causes it to be visible. - self.helper.join(space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, space)) + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) # The result should only have the space, but includes the link to the room. - self._assert_rooms(result, [space]) - self._assert_events(result, [(space, room)]) + self._assert_rooms(result, [self.space]) + self._assert_events(result, [(self.space, self.room)]) def test_world_readable(self): """A world-readable room is visible to everyone.""" - space = self.helper.create_room_as(self.user, tok=self.token) - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(space, room, self.token) self.helper.send_state( - space, + self.space, event_type="m.room.history_visibility", body={"history_visibility": "world_readable"}, tok=self.token, @@ -173,6 +175,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): user2 = self.register_user("user2", "pass") # The space should be visible, as well as the link to the room. - result = self.get_success(self.handler.get_space_summary(user2, space)) - self._assert_rooms(result, [space]) - self._assert_events(result, [(space, room)]) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, [self.space]) + self._assert_events(result, [(self.space, self.room)]) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index ed55a640af..69798e95c3 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -52,6 +52,7 @@ class RestHelper: room_version: str = None, tok: str = None, expect_code: int = 200, + extra_content: Optional[Dict] = None, ) -> str: """ Create a room. @@ -72,7 +73,7 @@ class RestHelper: temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" - content = {} + content = extra_content or {} if not is_public: content["visibility"] = "private" if room_version: -- cgit 1.5.1 From 6c02cca95f8136010062b6af0fa36a2906a96a6b Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 1 Jul 2021 11:26:24 +0200 Subject: Add SSO `external_ids` to Query User Account admin API (#10261) Related to #10251 --- changelog.d/10261.feature | 1 + docs/admin_api/user_admin_api.md | 12 ++- synapse/handlers/admin.py | 7 ++ tests/rest/admin/test_user.py | 224 ++++++++++++++++++++++++--------------- 4 files changed, 159 insertions(+), 85 deletions(-) create mode 100644 changelog.d/10261.feature (limited to 'tests/rest') diff --git a/changelog.d/10261.feature b/changelog.d/10261.feature new file mode 100644 index 0000000000..cd55cecbd5 --- /dev/null +++ b/changelog.d/10261.feature @@ -0,0 +1 @@ +Add SSO `external_ids` to the Query User Account admin API. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index ef1e735e33..4a65d0c3bc 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -36,7 +36,17 @@ It returns a JSON body like the following: "creation_ts": 1560432506, "appservice_id": null, "consent_server_notice_sent": null, - "consent_version": null + "consent_version": null, + "external_ids": [ + { + "auth_provider": "", + "external_id": "" + }, + { + "auth_provider": "", + "external_id": "" + } + ] } ``` diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f72ded038e..d75a8b15c3 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -62,9 +62,16 @@ class AdminHandler(BaseHandler): if ret: profile = await self.store.get_profileinfo(user.localpart) threepids = await self.store.user_get_threepids(user.to_string()) + external_ids = [ + ({"auth_provider": auth_provider, "external_id": external_id}) + for auth_provider, external_id in await self.store.get_external_ids_by_user( + user.to_string() + ) + ] ret["displayname"] = profile.display_name ret["avatar_url"] = profile.avatar_url ret["threepids"] = threepids + ret["external_ids"] = external_ids return ret async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index d599a4c984..a34d051734 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1150,7 +1150,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1160,7 +1160,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): @@ -1177,6 +1177,58 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) + def test_get_user(self): + """ + Test a simple get of a user. + """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) + self._check_fields(channel.json_body) + + def test_get_user_with_sso(self): + """ + Test get a user with SSO details. + """ + self.get_success( + self.store.record_user_external_id( + "auth_provider1", "external_id1", self.other_user + ) + ) + self.get_success( + self.store.record_user_external_id( + "auth_provider2", "external_id2", self.other_user + ) + ) + + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][1]["external_id"] + ) + self.assertEqual( + "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + ) + self._check_fields(channel.json_body) + def test_create_server_admin(self): """ Check that a new admin user is created successfully. @@ -1184,30 +1236,29 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user (server admin) - body = json.dumps( - { - "password": "abc123", - "admin": True, - "displayname": "Bob's name", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": "mxc://fibble/wibble", - } - ) + body = { + "password": "abc123", + "admin": True, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1216,7 +1267,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1225,6 +1276,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["is_guest"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) def test_create_user(self): """ @@ -1233,30 +1285,29 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "admin": False, - "displayname": "Bob's name", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": "mxc://fibble/wibble", - } - ) + body = { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1265,7 +1316,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1275,6 +1326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["deactivated"]) self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) + self._check_fields(channel.json_body) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1311,16 +1363,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123", "admin": False}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "admin": False}, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1350,17 +1400,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123", "admin": False}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "admin": False}, ) # Admin user is not blocked by mau anymore - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1382,21 +1430,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - ) + body = { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1426,21 +1472,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps( - { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - ) + body = { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=body, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1457,16 +1501,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Change password - body = json.dumps({"password": "hahaha"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "hahaha"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) + self._check_fields(channel.json_body) def test_set_displayname(self): """ @@ -1474,16 +1517,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Modify user - body = json.dumps({"displayname": "foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "foobar"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1494,7 +1535,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1504,18 +1545,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Delete old and add new threepid to user - body = json.dumps( - {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} - ) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1527,7 +1564,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1552,7 +1589,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1567,7 +1604,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1583,7 +1620,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1610,7 +1647,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -1626,7 +1663,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -1650,7 +1687,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -1659,7 +1696,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -1681,7 +1718,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -1691,7 +1728,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1713,7 +1750,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -1723,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -1742,7 +1779,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -1753,7 +1790,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -1772,7 +1809,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -1783,7 +1820,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -1796,7 +1833,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -1805,7 +1842,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -1830,7 +1867,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -1838,6 +1875,25 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIsNone(self.get_success(d)) self._is_erased(user_id, True) + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + + Args: + content: Content dictionary to check + """ + self.assertIn("displayname", content) + self.assertIn("threepids", content) + self.assertIn("avatar_url", content) + self.assertIn("admin", content) + self.assertIn("deactivated", content) + self.assertIn("shadow_banned", content) + self.assertIn("password_hash", content) + self.assertIn("creation_ts", content) + self.assertIn("appservice_id", content) + self.assertIn("consent_server_notice_sent", content) + self.assertIn("consent_version", content) + self.assertIn("external_ids", content) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From bcb0962a7250d6c1430ad42f5ed234ffea8f2468 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:08:53 +0200 Subject: Fix deactivate a user if he does not have a profile (#10252) --- changelog.d/10252.bugfix | 1 + synapse/storage/databases/main/profile.py | 8 +-- tests/rest/admin/test_user.py | 86 ++++++++++++++++++++++++------- 3 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 changelog.d/10252.bugfix (limited to 'tests/rest') diff --git a/changelog.d/10252.bugfix b/changelog.d/10252.bugfix new file mode 100644 index 0000000000..c8ddd14528 --- /dev/null +++ b/changelog.d/10252.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.26.0 where only users who have set profile information could be deactivated with erasure enabled. diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 9b4e95e134..ba7075caa5 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -73,20 +73,20 @@ class ProfileWorkerStore(SQLBaseStore): async def set_profile_displayname( self, user_localpart: str, new_displayname: Optional[str] ) -> None: - await self.db_pool.simple_update_one( + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, + values={"displayname": new_displayname}, desc="set_profile_displayname", ) async def set_profile_avatar_url( self, user_localpart: str, new_avatar_url: Optional[str] ) -> None: - await self.db_pool.simple_update_one( + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a34d051734..4fccce34fd 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -939,7 +939,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -950,7 +950,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -960,7 +960,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): @@ -990,7 +990,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): @@ -1006,7 +1006,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): def test_deactivate_user_erase_true(self): """ - Test deactivating an user and set `erase` to `true` + Test deactivating a user and set `erase` to `true` """ # Get user @@ -1016,24 +1016,22 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) - # Deactivate user - body = json.dumps({"erase": True}) - + # Deactivate and erase user channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"erase": True}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1042,7 +1040,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1053,7 +1051,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): def test_deactivate_user_erase_false(self): """ - Test deactivating an user and set `erase` to `false` + Test deactivating a user and set `erase` to `false` """ # Get user @@ -1063,7 +1061,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1071,13 +1069,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("User1", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"erase": False}) - channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"erase": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1089,7 +1085,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1098,6 +1094,60 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", False) + def test_deactivate_user_erase_true_no_profile(self): + """ + Test deactivating a user and set `erase` to `true` + if user has no profile information (stored in the database table `profiles`). + """ + + # Users normally have an entry in `profiles`, but occasionally they are created without one. + # To test deactivation for users without a profile, we delete the profile information for our user. + self.get_success( + self.store.db_pool.simple_delete_one( + table="profiles", keyvalues={"user_id": "user"} + ) + ) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + # Deactivate and erase user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"erase": True}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self.assertIsNone(channel.json_body["avatar_url"]) + self.assertIsNone(channel.json_body["displayname"]) + + self._is_erased("@user:test", True) + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) -- cgit 1.5.1