From a4a5ec4096f8de938f4a6e4264aeaaa0e0b26463 Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Thu, 7 May 2020 21:33:07 +0200 Subject: Add room details admin endpoint (#7317) --- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) (limited to 'synapse/rest') diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ed70d448a1..6b85148a32 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -32,6 +32,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( JoinRoomAliasServlet, ListRoomRestServlet, + RoomRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -193,6 +194,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index d1bdb64111..7d40001988 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -26,6 +26,7 @@ from synapse.http.servlet import ( ) from synapse.rest.admin._base import ( admin_patterns, + assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, ) @@ -169,7 +170,7 @@ class ListRoomRestServlet(RestServlet): in a dictionary containing room information. Supports pagination. """ - PATTERNS = admin_patterns("/rooms") + PATTERNS = admin_patterns("/rooms$") def __init__(self, hs): self.store = hs.get_datastore() @@ -253,6 +254,29 @@ class ListRoomRestServlet(RestServlet): return 200, response +class RoomRestServlet(RestServlet): + """Get room details. + + TODO: Add on_POST to allow room creation without joining the room + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room_with_stats(room_id) + if not ret: + raise NotFoundError("Room not found") + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") -- cgit 1.5.1 From 616af44137c78d481024da83bb51ed0d50a49522 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 8 May 2020 14:30:40 +0200 Subject: Implement OpenID Connect-based login (#7256) --- changelog.d/7256.feature | 1 + docs/dev/oidc.md | 175 ++++++ docs/sample_config.yaml | 95 ++++ mypy.ini | 3 + synapse/app/homeserver.py | 12 + synapse/config/_base.pyi | 2 + synapse/config/homeserver.py | 2 + synapse/config/oidc_config.py | 177 ++++++ synapse/config/sso.py | 17 +- synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 998 +++++++++++++++++++++++++++++++++ synapse/http/client.py | 7 + synapse/python_dependencies.py | 1 + synapse/res/templates/sso_error.html | 18 + synapse/rest/client/v1/login.py | 28 +- synapse/rest/oidc/__init__.py | 27 + synapse/rest/oidc/callback_resource.py | 31 + synapse/server.py | 6 + synapse/server.pyi | 5 + tests/handlers/test_oidc.py | 565 +++++++++++++++++++ tox.ini | 1 + 21 files changed, 2163 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7256.feature create mode 100644 docs/dev/oidc.md create mode 100644 synapse/config/oidc_config.py create mode 100644 synapse/handlers/oidc_handler.py create mode 100644 synapse/res/templates/sso_error.html create mode 100644 synapse/rest/oidc/__init__.py create mode 100644 synapse/rest/oidc/callback_resource.py create mode 100644 tests/handlers/test_oidc.py (limited to 'synapse/rest') diff --git a/changelog.d/7256.feature b/changelog.d/7256.feature new file mode 100644 index 0000000000..7ad767bf71 --- /dev/null +++ b/changelog.d/7256.feature @@ -0,0 +1 @@ +Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs). diff --git a/docs/dev/oidc.md b/docs/dev/oidc.md new file mode 100644 index 0000000000..a90c5d2441 --- /dev/null +++ b/docs/dev/oidc.md @@ -0,0 +1,175 @@ +# How to test OpenID Connect + +Any OpenID Connect Provider (OP) should work with Synapse, as long as it supports the authorization code flow. +There are a few options for that: + + - start a local OP. Synapse has been tested with [Hydra][hydra] and [Dex][dex-idp]. + Note that for an OP to work, it should be served under a secure (HTTPS) origin. + A certificate signed with a self-signed, locally trusted CA should work. In that case, start Synapse with a `SSL_CERT_FILE` environment variable set to the path of the CA. + - use a publicly available OP. Synapse has been tested with [Google][google-idp]. + - setup a SaaS OP, like [Auth0][auth0] and [Okta][okta]. Auth0 has a free tier which has been tested with Synapse. + +[google-idp]: https://developers.google.com/identity/protocols/OpenIDConnect#authenticatingtheuser +[auth0]: https://auth0.com/ +[okta]: https://www.okta.com/ +[dex-idp]: https://github.com/dexidp/dex +[hydra]: https://www.ory.sh/docs/hydra/ + + +## Sample configs + +Here are a few configs for providers that should work with Synapse. + +### [Dex][dex-idp] + +[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. +Although it is designed to help building a full-blown provider, with some external database, it can be configured with static passwords in a config file. + +Follow the [Getting Started guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) to install Dex. + +Edit `examples/config-dev.yaml` config file from the Dex repo to add a client: + +```yaml +staticClients: +- id: synapse + secret: secret + redirectURIs: + - '[synapse base url]/_synapse/oidc/callback' + name: 'Synapse' +``` + +Run with `dex serve examples/config-dex.yaml` + +Synapse config: + +```yaml +oidc_config: + enabled: true + skip_verification: true # This is needed as Dex is served on an insecure endpoint + issuer: "http://127.0.0.1:5556/dex" + discover: true + client_id: "synapse" + client_secret: "secret" + scopes: + - openid + - profile + user_mapping_provider: + config: + localpart_template: '{{ user.name }}' + display_name_template: '{{ user.name|capitalize }}' +``` + +### [Auth0][auth0] + +1. Create a regular web application for Synapse +2. Set the Allowed Callback URLs to `[synapse base url]/_synapse/oidc/callback` +3. Add a rule to add the `preferred_username` claim. +
+ Code sample + + ```js + function addPersistenceAttribute(user, context, callback) { + user.user_metadata = user.user_metadata || {}; + user.user_metadata.preferred_username = user.user_metadata.preferred_username || user.user_id; + context.idToken.preferred_username = user.user_metadata.preferred_username; + + auth0.users.updateUserMetadata(user.user_id, user.user_metadata) + .then(function(){ + callback(null, user, context); + }) + .catch(function(err){ + callback(err); + }); + } + ``` + +
+ + +```yaml +oidc_config: + enabled: true + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + discover: true + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: + - openid + - profile + user_mapping_provider: + config: + localpart_template: '{{ user.preferred_username }}' + display_name_template: '{{ user.name }}' +``` + +### GitHub + +GitHub is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider. +The `/user` API endpoint can be used to retrieve informations from the user. +As the OIDC login mechanism needs an attribute to uniquely identify users and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set. + +1. Create a new OAuth application: https://github.com/settings/applications/new +2. Set the callback URL to `[synapse base url]/_synapse/oidc/callback` + +```yaml +oidc_config: + enabled: true + issuer: "https://github.com/" + discover: false + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: + - read:user + user_mapping_provider: + config: + subject_claim: 'id' + localpart_template: '{{ user.login }}' + display_name_template: '{{ user.name }}' +``` + +### Google + +1. Setup a project in the Google API Console +2. Obtain the OAuth 2.0 credentials (see ) +3. Add this Authorized redirect URI: `[synapse base url]/_synapse/oidc/callback` + +```yaml +oidc_config: + enabled: true + issuer: "https://accounts.google.com/" + discover: true + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: + - openid + - profile + user_mapping_provider: + config: + localpart_template: '{{ user.given_name|lower }}' + display_name_template: '{{ user.name }}' +``` + +### Twitch + +1. Setup a developer account on [Twitch](https://dev.twitch.tv/) +2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/) +3. Add this OAuth Redirect URL: `[synapse base url]/_synapse/oidc/callback` + +```yaml +oidc_config: + enabled: true + issuer: "https://id.twitch.tv/oauth2/" + discover: true + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: + - openid + user_mapping_provider: + config: + localpart_template: '{{ user.preferred_username }}' + display_name_template: '{{ user.name }}' +``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 98ead7dc0e..1e397f7734 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1470,6 +1470,94 @@ saml2_config: #template_dir: "res/templates" +# Enable OpenID Connect for registration and login. Uses authlib. +# +oidc_config: + # enable OpenID Connect. Defaults to false. + # + #enabled: true + + # use the OIDC discovery mechanism to discover endpoints. Defaults to true. + # + #discover: true + + # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. Required. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. Required. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are "client_secret_basic" (default), "client_secret_post" and "none". + # + #client_auth_method: "client_auth_basic" + + # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"]. + # + #scopes: ["openid"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # skip metadata verification. Defaults to false. + # Use this if you are connecting to a provider that is not OpenID Connect compliant. + # Avoid this in production. + # + #skip_verification: false + + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. + # + #module: mapping_provider.OidcMappingProvider + + # Custom configuration values for the module. Below options are intended + # for the built-in provider, they should be changed if using a custom + # module. This section will be passed as a Python dictionary to the + # module's `parse_config` method. + # + # Below is the config of the default mapping provider, based on Jinja2 + # templates. Those templates are used to render user attributes, where the + # userinfo object is available through the `user` variable. + # + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" + + # Jinja2 template for the localpart of the MXID + # + localpart_template: "{{ user.preferred_username }}" + + # Jinja2 template for the display name to set on first login. Optional. + # + #display_name_template: "{{ user.given_name }} {{ user.last_name }}" + + # Enable CAS for registration and login. # @@ -1554,6 +1642,13 @@ sso: # # This template has no additional variables. # + # * HTML page to display to users if something goes wrong during the + # OpenID Connect authentication process: 'sso_error.html'. + # + # When rendering, this template is given two variables: + # * error: the technical name of the error + # * error_description: a human-readable message for the error + # # You can see the default templates at: # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # diff --git a/mypy.ini b/mypy.ini index 69be2f67ad..3533797d68 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,3 +75,6 @@ ignore_missing_imports = True [mypy-jwt.*] ignore_missing_imports = True + +[mypy-authlib.*] +ignore_missing_imports = True diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index cbd1ea475a..bc8695d8dd 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -192,6 +192,11 @@ class SynapseHomeServer(HomeServer): } ) + if self.get_config().oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(self) + if self.get_config().saml2_enabled: from synapse.rest.saml2 import SAML2Resource @@ -422,6 +427,13 @@ def setup(config_options): # Check if it needs to be reprovisioned every day. hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) + # Load the OIDC provider metadatas, if OIDC is enabled. + if hs.config.oidc_enabled: + oidc = hs.get_oidc_handler() + # Loading the provider metadata also ensures the provider config is valid. + yield defer.ensureDeferred(oidc.load_metadata()) + yield defer.ensureDeferred(oidc.load_jwks()) + _base.start(hs, config.listeners) hs.get_datastore().db.updates.start_doing_background_updates() diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 3053fc9d27..9e576060d4 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -13,6 +13,7 @@ from synapse.config import ( key, logger, metrics, + oidc_config, password, password_auth_providers, push, @@ -59,6 +60,7 @@ class RootConfig: saml2: saml2_config.SAML2Config cas: cas.CasConfig sso: sso.SSOConfig + oidc: oidc_config.OIDCConfig jwt: jwt_config.JWTConfig password: password.PasswordConfig email: emailconfig.EmailConfig diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index be6c6afa74..996d3e6bf7 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -27,6 +27,7 @@ from .jwt_config import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig +from .oidc_config import OIDCConfig from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig @@ -66,6 +67,7 @@ class HomeServerConfig(RootConfig): AppServiceConfig, KeyConfig, SAML2Config, + OIDCConfig, CasConfig, SSOConfig, JWTConfig, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py new file mode 100644 index 0000000000..5af110745e --- /dev/null +++ b/synapse/config/oidc_config.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.module_loader import load_module + +from ._base import Config, ConfigError + +DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" + + +class OIDCConfig(Config): + section = "oidc" + + def read_config(self, config, **kwargs): + self.oidc_enabled = False + + oidc_config = config.get("oidc_config") + + if not oidc_config or not oidc_config.get("enabled", False): + return + + try: + check_requirements("oidc") + except DependencyException as e: + raise ConfigError(e.message) + + public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("oidc_config requires a public_baseurl to be set") + self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + + self.oidc_enabled = True + self.oidc_discover = oidc_config.get("discover", True) + self.oidc_issuer = oidc_config["issuer"] + self.oidc_client_id = oidc_config["client_id"] + self.oidc_client_secret = oidc_config["client_secret"] + self.oidc_client_auth_method = oidc_config.get( + "client_auth_method", "client_secret_basic" + ) + self.oidc_scopes = oidc_config.get("scopes", ["openid"]) + self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") + self.oidc_token_endpoint = oidc_config.get("token_endpoint") + self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") + self.oidc_jwks_uri = oidc_config.get("jwks_uri") + self.oidc_subject_claim = oidc_config.get("subject_claim", "sub") + self.oidc_skip_verification = oidc_config.get("skip_verification", False) + + ump_config = oidc_config.get("user_mapping_provider", {}) + ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + ump_config.setdefault("config", {}) + + ( + self.oidc_user_mapping_provider_class, + self.oidc_user_mapping_provider_config, + ) = load_module(ump_config) + + # Ensure loaded user mapping module has defined all necessary methods + required_methods = [ + "get_remote_user_id", + "map_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.oidc_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by oidc_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Enable OpenID Connect for registration and login. Uses authlib. + # + oidc_config: + # enable OpenID Connect. Defaults to false. + # + #enabled: true + + # use the OIDC discovery mechanism to discover endpoints. Defaults to true. + # + #discover: true + + # the OIDC issuer. Used to validate tokens and discover the providers endpoints. Required. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. Required. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. Required. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are "client_secret_basic" (default), "client_secret_post" and "none". + # + #client_auth_method: "client_auth_basic" + + # list of scopes to ask. This should include the "openid" scope. Defaults to ["openid"]. + # + #scopes: ["openid"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the "openid" scope is not asked. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # skip metadata verification. Defaults to false. + # Use this if you are connecting to a provider that is not OpenID Connect compliant. + # Avoid this in production. + # + #skip_verification: false + + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is {mapping_provider!r}. + # + #module: mapping_provider.OidcMappingProvider + + # Custom configuration values for the module. Below options are intended + # for the built-in provider, they should be changed if using a custom + # module. This section will be passed as a Python dictionary to the + # module's `parse_config` method. + # + # Below is the config of the default mapping provider, based on Jinja2 + # templates. Those templates are used to render user attributes, where the + # userinfo object is available through the `user` variable. + # + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" + + # Jinja2 template for the localpart of the MXID + # + localpart_template: "{{{{ user.preferred_username }}}}" + + # Jinja2 template for the display name to set on first login. Optional. + # + #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + """.format( + mapping_provider=DEFAULT_USER_MAPPING_PROVIDER + ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index cac6bc0139..aff642f015 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -36,17 +36,13 @@ class SSOConfig(Config): if not template_dir: template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_redirect_confirm_template_dir = template_dir + self.sso_template_dir = template_dir self.sso_account_deactivated_template = self.read_file( - os.path.join( - self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html" - ), + os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), "sso_account_deactivated_template", ) self.sso_auth_success_template = self.read_file( - os.path.join( - self.sso_redirect_confirm_template_dir, "sso_auth_success.html" - ), + os.path.join(self.sso_template_dir, "sso_auth_success.html"), "sso_auth_success_template", ) @@ -137,6 +133,13 @@ class SSOConfig(Config): # # This template has no additional variables. # + # * HTML page to display to users if something goes wrong during the + # OpenID Connect authentication process: 'sso_error.html'. + # + # When rendering, this template is given two variables: + # * error: the technical name of the error + # * error_description: a human-readable message for the error + # # You can see the default templates at: # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7613e5b6ab..f8d2331bf1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -126,13 +126,13 @@ class AuthHandler(BaseHandler): # It notifies the user they are about to give access to their matrix account # to the client. self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], + hs.config.sso_template_dir, ["sso_redirect_confirm.html"], )[0] # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], + hs.config.sso_template_dir, ["sso_auth_confirm.html"], )[0] # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py new file mode 100644 index 0000000000..178f263439 --- /dev/null +++ b/synapse/handlers/oidc_handler.py @@ -0,0 +1,998 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +from typing import Dict, Generic, List, Optional, Tuple, TypeVar +from urllib.parse import urlencode + +import attr +import pymacaroons +from authlib.common.security import generate_token +from authlib.jose import JsonWebToken +from authlib.oauth2.auth import ClientAuth +from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url +from jinja2 import Environment, Template +from pymacaroons.exceptions import ( + MacaroonDeserializationException, + MacaroonInvalidSignatureException, +) +from typing_extensions import TypedDict + +from twisted.web.client import readBody + +from synapse.config import ConfigError +from synapse.http.server import finish_request +from synapse.http.site import SynapseRequest +from synapse.push.mailer import load_jinja2_templates +from synapse.server import HomeServer +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + +SESSION_COOKIE_NAME = b"oidc_session" + +#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and +#: OpenID.Core sec 3.1.3.3. +Token = TypedDict( + "Token", + { + "access_token": str, + "token_type": str, + "id_token": Optional[str], + "refresh_token": Optional[str], + "expires_in": int, + "scope": Optional[str], + }, +) + +#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but +#: there is no real point of doing this in our case. +JWK = Dict[str, str] + +#: A JWK Set, as per RFC7517 sec 5. +JWKS = TypedDict("JWKS", {"keys": List[JWK]}) + + +class OidcError(Exception): + """Used to catch errors when calling the token_endpoint + """ + + def __init__(self, error, error_description=None): + self.error = error + self.error_description = error_description + + def __str__(self): + if self.error_description: + return "{}: {}".format(self.error, self.error_description) + return self.error + + +class MappingException(Exception): + """Used to catch errors when mapping the UserInfo object + """ + + +class OidcHandler: + """Handles requests related to the OpenID Connect login flow. + """ + + def __init__(self, hs: HomeServer): + self._callback_url = hs.config.oidc_callback_url # type: str + self._scopes = hs.config.oidc_scopes # type: List[str] + self._client_auth = ClientAuth( + hs.config.oidc_client_id, + hs.config.oidc_client_secret, + hs.config.oidc_client_auth_method, + ) # type: ClientAuth + self._client_auth_method = hs.config.oidc_client_auth_method # type: str + self._subject_claim = hs.config.oidc_subject_claim + self._provider_metadata = OpenIDProviderMetadata( + issuer=hs.config.oidc_issuer, + authorization_endpoint=hs.config.oidc_authorization_endpoint, + token_endpoint=hs.config.oidc_token_endpoint, + userinfo_endpoint=hs.config.oidc_userinfo_endpoint, + jwks_uri=hs.config.oidc_jwks_uri, + ) # type: OpenIDProviderMetadata + self._provider_needs_discovery = hs.config.oidc_discover # type: bool + self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( + hs.config.oidc_user_mapping_provider_config + ) # type: OidcMappingProvider + self._skip_verification = hs.config.oidc_skip_verification # type: bool + + self._http_client = hs.get_proxied_http_client() + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + self._datastore = hs.get_datastore() + self._clock = hs.get_clock() + self._hostname = hs.hostname # type: str + self._server_name = hs.config.server_name # type: str + self._macaroon_secret_key = hs.config.macaroon_secret_key + self._error_template = load_jinja2_templates( + hs.config.sso_template_dir, ["sso_error.html"] + )[0] + + # identifier for the external_ids table + self._auth_provider_id = "oidc" + + def _render_error( + self, request, error: str, error_description: Optional[str] = None + ) -> None: + """Renders the error template and respond with it. + + This is used to show errors to the user. The template of this page can + be found under ``synapse/res/templates/sso_error.html``. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. Those include + well-known OAuth2/OIDC error types like invalid_request or + access_denied. + error_description: A human-readable description of the error. + """ + html_bytes = self._error_template.render( + error=error, error_description=error_description + ).encode("utf-8") + + request.setResponseCode(400) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) + request.write(html_bytes) + finish_request(request) + + def _validate_metadata(self): + """Verifies the provider metadata. + + This checks the validity of the currently loaded provider. Not + everything is checked, only: + + - ``issuer`` + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``response_types_supported`` (checks if "code" is in it) + - ``jwks_uri`` + + Raises: + ValueError: if something in the provider is not valid + """ + # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin) + if self._skip_verification is True: + return + + m = self._provider_metadata + m.validate_issuer() + m.validate_authorization_endpoint() + m.validate_token_endpoint() + + if m.get("token_endpoint_auth_methods_supported") is not None: + m.validate_token_endpoint_auth_methods_supported() + if ( + self._client_auth_method + not in m["token_endpoint_auth_methods_supported"] + ): + raise ValueError( + '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format( + auth_method=self._client_auth_method, + supported=m["token_endpoint_auth_methods_supported"], + ) + ) + + if m.get("response_types_supported") is not None: + m.validate_response_types_supported() + + if "code" not in m["response_types_supported"]: + raise ValueError( + '"code" not in "response_types_supported" (%r)' + % (m["response_types_supported"],) + ) + + # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos + if self._uses_userinfo: + if m.get("userinfo_endpoint") is None: + raise ValueError( + 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested' + ) + else: + # If we're not using userinfo, we need a valid jwks to validate the ID token + if m.get("jwks") is None: + if m.get("jwks_uri") is not None: + m.validate_jwks_uri() + else: + raise ValueError('"jwks_uri" must be set') + + @property + def _uses_userinfo(self) -> bool: + """Returns True if the ``userinfo_endpoint`` should be used. + + This is based on the requested scopes: if the scopes include + ``openid``, the provider should give use an ID token containing the + user informations. If not, we should fetch them using the + ``access_token`` with the ``userinfo_endpoint``. + """ + + # Maybe that should be user-configurable and not inferred? + return "openid" not in self._scopes + + async def load_metadata(self) -> OpenIDProviderMetadata: + """Load and validate the provider metadata. + + The values metadatas are discovered if ``oidc_config.discovery`` is + ``True`` and then cached. + + Raises: + ValueError: if something in the provider is not valid + + Returns: + The provider's metadata. + """ + # If we are using the OpenID Discovery documents, it needs to be loaded once + # FIXME: should there be a lock here? + if self._provider_needs_discovery: + url = get_well_known_url(self._provider_metadata["issuer"], external=True) + metadata_response = await self._http_client.get_json(url) + # TODO: maybe update the other way around to let user override some values? + self._provider_metadata.update(metadata_response) + self._provider_needs_discovery = False + + self._validate_metadata() + + return self._provider_metadata + + async def load_jwks(self, force: bool = False) -> JWKS: + """Load the JSON Web Key Set used to sign ID tokens. + + If we're not using the ``userinfo_endpoint``, user infos are extracted + from the ID token, which is a JWT signed by keys given by the provider. + The keys are then cached. + + Args: + force: Force reloading the keys. + + Returns: + The key set + + Looks like this:: + + { + 'keys': [ + { + 'kid': 'abcdef', + 'kty': 'RSA', + 'alg': 'RS256', + 'use': 'sig', + 'e': 'XXXX', + 'n': 'XXXX', + } + ] + } + """ + if self._uses_userinfo: + # We're not using jwt signing, return an empty jwk set + return {"keys": []} + + # First check if the JWKS are loaded in the provider metadata. + # It can happen either if the provider gives its JWKS in the discovery + # document directly or if it was already loaded once. + metadata = await self.load_metadata() + jwk_set = metadata.get("jwks") + if jwk_set is not None and not force: + return jwk_set + + # Loading the JWKS using the `jwks_uri` metadata + uri = metadata.get("jwks_uri") + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + jwk_set = await self._http_client.get_json(uri) + + # Caching the JWKS in the provider's metadata + self._provider_metadata["jwks"] = jwk_set + return jwk_set + + async def _exchange_code(self, code: str) -> Token: + """Exchange an authorization code for a token. + + This calls the ``token_endpoint`` with the authorization code we + received in the callback to exchange it for a token. The call uses the + ``ClientAuth`` to authenticate with the client with its ID and secret. + + Args: + code: The autorization code we got from the callback. + + Returns: + A dict containing various tokens. + + May look like this:: + + { + 'token_type': 'bearer', + 'access_token': 'abcdef', + 'expires_in': 3599, + 'id_token': 'ghijkl', + 'refresh_token': 'mnopqr', + } + + Raises: + OidcError: when the ``token_endpoint`` returned an error. + """ + metadata = await self.load_metadata() + token_endpoint = metadata.get("token_endpoint") + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": self._http_client.user_agent, + "Accept": "application/json", + } + + args = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self._callback_url, + } + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, headers, body = self._client_auth.prepare( + method="POST", uri=token_endpoint, headers=headers, body=body + ) + headers = {k: [v] for (k, v) in headers.items()} + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code and we do the body encoding ourself. + response = await self._http_client.request( + method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + ) + + # This is used in multiple error messages below + status = "{code} {phrase}".format( + code=response.code, phrase=response.phrase.decode("utf-8") + ) + + resp_body = await readBody(response) + + if response.code >= 500: + # In case of a server error, we should first try to decode the body + # and check for an error field. If not, we respond with a generic + # error message. + try: + resp = json.loads(resp_body.decode("utf-8")) + error = resp["error"] + description = resp.get("error_description", error) + except (ValueError, KeyError): + # Catch ValueError for the JSON decoding and KeyError for the "error" field + error = "server_error" + description = ( + ( + 'Authorization server responded with a "{status}" error ' + "while exchanging the authorization code." + ).format(status=status), + ) + + raise OidcError(error, description) + + # Since it is a not a 5xx code, body should be a valid JSON. It will + # raise if not. + resp = json.loads(resp_body.decode("utf-8")) + + if "error" in resp: + error = resp["error"] + # In case the authorization server responded with an error field, + # it should be a 4xx code. If not, warn about it but don't do + # anything special and report the original error message. + if response.code < 400: + logger.debug( + "Invalid response from the authorization server: " + 'responded with a "{status}" ' + "but body has an error field: {error!r}".format( + status=status, error=resp["error"] + ) + ) + + description = resp.get("error_description", error) + raise OidcError(error, description) + + # Now, this should not be an error. According to RFC6749 sec 5.1, it + # should be a 200 code. We're a bit more flexible than that, and will + # only throw on a 4xx code. + if response.code >= 400: + description = ( + 'Authorization server responded with a "{status}" error ' + 'but did not include an "error" field in its response.'.format( + status=status + ) + ) + logger.warning(description) + # Body was still valid JSON. Might be useful to log it for debugging. + logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + raise OidcError("server_error", description) + + return resp + + async def _fetch_userinfo(self, token: Token) -> UserInfo: + """Fetch user informations from the ``userinfo_endpoint``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``access_token`` field. + + Returns: + UserInfo: an object representing the user. + """ + metadata = await self.load_metadata() + + resp = await self._http_client.get_json( + metadata["userinfo_endpoint"], + headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, + ) + + return UserInfo(resp) + + async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + """Return an instance of UserInfo from token's ``id_token``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``id_token`` field. + nonce: the nonce value originally sent in the initial authorization + request. This value should match the one inside the token. + + Returns: + An object representing the user. + """ + metadata = await self.load_metadata() + claims_params = { + "nonce": nonce, + "client_id": self._client_auth.client_id, + } + if "access_token" in token: + # If we got an `access_token`, there should be an `at_hash` claim + # in the `id_token` that we can check against. + claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + jwt = JsonWebToken(alg_values) + + claim_options = {"iss": {"values": [metadata["issuer"]]}} + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token["id_token"], + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + except ValueError: + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token["id_token"], + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + + claims.validate(leeway=120) # allows 2 min of clock skew + return UserInfo(claims) + + async def handle_redirect_request( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> None: + """Handle an incoming request to /login/sso/redirect + + It redirects the browser to the authorization endpoint with a few + parameters: + + - ``client_id``: the client ID set in ``oidc_config.client_id`` + - ``response_type``: ``code`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``scope``: the list of scopes set in ``oidc_config.scopes`` + - ``state``: a random string + - ``nonce``: a random string + + In addition to redirecting the client, we are setting a cookie with + a signed macaroon token containing the state, the nonce and the + client_redirect_url params. Those are then checked when the client + comes back from the provider. + + + Args: + request: the incoming request from the browser. + We'll respond to it with a redirect and a cookie. + client_redirect_url: the URL that we should redirect the client to + when everything is done + """ + + state = generate_token() + nonce = generate_token() + + cookie = self._generate_oidc_session_token( + state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(), + ) + request.addCookie( + SESSION_COOKIE_NAME, + cookie, + path="/_synapse/oidc", + max_age="3600", + httpOnly=True, + sameSite="lax", + ) + + metadata = await self.load_metadata() + authorization_endpoint = metadata.get("authorization_endpoint") + uri = prepare_grant_uri( + authorization_endpoint, + client_id=self._client_auth.client_id, + response_type="code", + redirect_uri=self._callback_url, + scope=self._scopes, + state=state, + nonce=nonce, + ) + request.redirect(uri) + finish_request(request) + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._render_error`` which displays an HTML page for the error. + + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match with the one stored in the + session cookie + - once we known this session is legit, exchange the code with the + provider using the ``token_endpoint`` (see ``_exchange_code``) + - once we have the token, use it to either extract the UserInfo from + the ``id_token`` (``_parse_id_token``), or use the ``access_token`` + to fetch UserInfo from the ``userinfo_endpoint`` + (``_fetch_userinfo``) + - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and + finish the login + + Args: + request: the incoming request from the browser. + """ + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + error = request.args[b"error"][0].decode() + description = request.args.get(b"error_description", [b""])[0].decode() + + # Most of the errors returned by the provider could be due by + # either the provider misbehaving or Synapse being misconfigured. + # The only exception of that is "access_denied", where the user + # probably cancelled the login flow. In other cases, log those errors. + if error != "access_denied": + logger.error("Error from the OIDC provider: %s %s", error, description) + + self._render_error(request, error, description) + return + + # Fetch the session cookie + session = request.getCookie(SESSION_COOKIE_NAME) + if session is None: + logger.info("No session cookie found") + self._render_error(request, "missing_session", "No session cookie found") + return + + # Remove the cookie. There is a good chance that if the callback failed + # once, it will fail next time and the code will already be exchanged. + # Removing it early avoids spamming the provider with token requests. + request.addCookie( + SESSION_COOKIE_NAME, + b"", + path="/_synapse/oidc", + expires="Thu, Jan 01 1970 00:00:00 UTC", + httpOnly=True, + sameSite="lax", + ) + + # Check for the state query parameter + if b"state" not in request.args: + logger.info("State parameter is missing") + self._render_error(request, "invalid_request", "State parameter is missing") + return + + state = request.args[b"state"][0].decode() + + # Deserialize the session token and verify it. + try: + nonce, client_redirect_url = self._verify_oidc_session_token(session, state) + except MacaroonDeserializationException as e: + logger.exception("Invalid session") + self._render_error(request, "invalid_session", str(e)) + return + except MacaroonInvalidSignatureException as e: + logger.exception("Could not verify session") + self._render_error(request, "mismatching_session", str(e)) + return + + # Exchange the code with the provider + if b"code" not in request.args: + logger.info("Code parameter is missing") + self._render_error(request, "invalid_request", "Code parameter is missing") + return + + logger.info("Exchanging code") + code = request.args[b"code"][0].decode() + try: + token = await self._exchange_code(code) + except OidcError as e: + logger.exception("Could not exchange code") + self._render_error(request, e.error, e.error_description) + return + + # Now that we have a token, get the userinfo, either by decoding the + # `id_token` or by fetching the `userinfo_endpoint`. + if self._uses_userinfo: + logger.info("Fetching userinfo") + try: + userinfo = await self._fetch_userinfo(token) + except Exception as e: + logger.exception("Could not fetch userinfo") + self._render_error(request, "fetch_error", str(e)) + return + else: + logger.info("Extracting userinfo from id_token") + try: + userinfo = await self._parse_id_token(token, nonce=nonce) + except Exception as e: + logger.exception("Invalid id_token") + self._render_error(request, "invalid_token", str(e)) + return + + # Call the mapper to register/login the user + try: + user_id = await self._map_userinfo_to_user(userinfo, token) + except MappingException as e: + logger.exception("Could not map user") + self._render_error(request, "mapping_error", str(e)) + return + + # and finally complete the login + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) + + def _generate_oidc_session_token( + self, + state: str, + nonce: str, + client_redirect_url: str, + duration_in_ms: int = (60 * 60 * 1000), + ) -> str: + """Generates a signed token storing data about an OIDC session. + + When Synapse initiates an authorization flow, it creates a random state + and a random nonce. Those parameters are given to the provider and + should be verified when the client comes back from the provider. + It is also used to store the client_redirect_url, which is used to + complete the SSO login flow. + + Args: + state: The ``state`` parameter passed to the OIDC provider. + nonce: The ``nonce`` parameter passed to the OIDC provider. + client_redirect_url: The URL the client gave when it initiated the + flow. + duration_in_ms: An optional duration for the token in milliseconds. + Defaults to an hour. + + Returns: + A signed macaroon token with the session informations. + """ + macaroon = pymacaroons.Macaroon( + location=self._server_name, identifier="key", key=self._macaroon_secret_key, + ) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = session") + macaroon.add_first_party_caveat("state = %s" % (state,)) + macaroon.add_first_party_caveat("nonce = %s" % (nonce,)) + macaroon.add_first_party_caveat( + "client_redirect_url = %s" % (client_redirect_url,) + ) + now = self._clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() + + def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]: + """Verifies and extract an OIDC session token. + + This verifies that a given session token was issued by this homeserver + and extract the nonce and client_redirect_url caveats. + + Args: + session: The session token to verify + state: The state the OIDC provider gave back + + Returns: + The nonce and the client_redirect_url for this session + """ + macaroon = pymacaroons.Macaroon.deserialize(session) + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = session") + v.satisfy_exact("state = %s" % (state,)) + v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + v.satisfy_general(self._verify_expiry) + + v.verify(macaroon, self._macaroon_secret_key) + + # Extract the `nonce` and `client_redirect_url` from the token + nonce = self._get_value_from_macaroon(macaroon, "nonce") + client_redirect_url = self._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + return nonce, client_redirect_url + + def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + Exception: if the caveat was not in the macaroon + """ + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise Exception("No %s caveat in macaroon" % (key,)) + + def _verify_expiry(self, caveat: str) -> bool: + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + now = self._clock.time_msec() + return now < expiry + + async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + """Maps a UserInfo object to a mxid. + + UserInfo should have a claim that uniquely identifies users. This claim + is usually `sub`, but can be configured with `oidc_config.subject_claim`. + It is then used as an `external_id`. + + If we don't find the user that way, we should register the user, + mapping the localpart and the display name from the UserInfo. + + If a user already exists with the mxid we've mapped, raise an exception. + + Args: + userinfo: an object representing the user + token: a dict with the tokens obtained from the provider + + Raises: + MappingException: if there was an error while mapping some properties + + Returns: + The mxid of the user + """ + try: + remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) + except Exception as e: + raise MappingException( + "Failed to extract subject from OIDC response: %s" % (e,) + ) + + logger.info( + "Looking for existing mapping for user %s:%s", + self._auth_provider_id, + remote_user_id, + ) + + registered_user_id = await self._datastore.get_user_by_external_id( + self._auth_provider_id, remote_user_id, + ) + + if registered_user_id is not None: + logger.info("Found existing mapping %s", registered_user_id) + return registered_user_id + + try: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token + ) + except Exception as e: + raise MappingException( + "Could not extract user attributes from OIDC response: " + str(e) + ) + + logger.debug( + "Retrieved user attributes from user mapping provider: %r", attributes + ) + + if not attributes["localpart"]: + raise MappingException("localpart is empty") + + localpart = map_username_to_mxid_localpart(attributes["localpart"]) + + user_id = UserID(localpart, self._hostname) + if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): + # This mxid is taken + raise MappingException( + "mxid '{}' is already taken".format(user_id.to_string()) + ) + + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=attributes["display_name"], + ) + + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id, + ) + return registered_user_id + + +UserAttribute = TypedDict( + "UserAttribute", {"localpart": str, "display_name": Optional[str]} +) +C = TypeVar("C") + + +class OidcMappingProvider(Generic[C]): + """A mapping provider maps a UserInfo object to user attributes. + + It should provide the API described by this class. + """ + + def __init__(self, config: C): + """ + Args: + config: A custom config object from this module, parsed by ``parse_config()`` + """ + + @staticmethod + def parse_config(config: dict) -> C: + """Parse the dict provided by the homeserver's config + + Args: + config: A dictionary containing configuration options for this provider + + Returns: + A custom config object for this module + """ + raise NotImplementedError() + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + """Get a unique user ID for this user. + + Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object. + + Args: + userinfo: An object representing the user given by the OIDC provider + + Returns: + A unique user ID + """ + raise NotImplementedError() + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: + """Map a ``UserInfo`` objects into user attributes. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing the ``localpart`` and (optionally) the ``display_name`` + """ + raise NotImplementedError() + + +# Used to clear out "None" values in templates +def jinja_finalize(thing): + return thing if thing is not None else "" + + +env = Environment(finalize=jinja_finalize) + + +@attr.s +class JinjaOidcMappingConfig: + subject_claim = attr.ib() # type: str + localpart_template = attr.ib() # type: Template + display_name_template = attr.ib() # type: Optional[Template] + + +class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): + """An implementation of a mapping provider based on Jinja templates. + + This is the default mapping provider. + """ + + def __init__(self, config: JinjaOidcMappingConfig): + self._config = config + + @staticmethod + def parse_config(config: dict) -> JinjaOidcMappingConfig: + subject_claim = config.get("subject_claim", "sub") + + if "localpart_template" not in config: + raise ConfigError( + "missing key: oidc_config.user_mapping_provider.config.localpart_template" + ) + + try: + localpart_template = env.from_string(config["localpart_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r" + % (e,) + ) + + display_name_template = None # type: Optional[Template] + if "display_name_template" in config: + try: + display_name_template = env.from_string(config["display_name_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" + % (e,) + ) + + return JinjaOidcMappingConfig( + subject_claim=subject_claim, + localpart_template=localpart_template, + display_name_template=display_name_template, + ) + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + return userinfo[self._config.subject_claim] + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: + localpart = self._config.localpart_template.render(user=userinfo).strip() + + display_name = None # type: Optional[str] + if self._config.display_name_template is not None: + display_name = self._config.display_name_template.render( + user=userinfo + ).strip() + + if display_name == "": + display_name = None + + return UserAttribute(localpart=localpart, display_name=display_name) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3797545824..58eb47c69c 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -359,6 +359,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -399,6 +400,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/json"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -434,6 +436,10 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ + actual_headers = {b"Accept": [b"application/json"]} + if headers: + actual_headers.update(headers) + body = yield self.get_raw(uri, args, headers=headers) return json.loads(body) @@ -467,6 +473,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/json"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 39c99a2802..8b4312e5a3 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -92,6 +92,7 @@ CONDITIONAL_REQUIREMENTS = { 'eliot<1.8.0;python_version<"3.5.3"', ], "saml2": ["pysaml2>=4.5.0"], + "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0", "parameterized"], diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html new file mode 100644 index 0000000000..43a211386b --- /dev/null +++ b/synapse/res/templates/sso_error.html @@ -0,0 +1,18 @@ + + + + + SSO error + + +

Oops! Something went wrong during authentication.

+

+ Try logging in again from your Matrix client and if the problem persists + please contact the server's administrator. +

+

Error: {{ error }}

+ {% if error_description %} +
{{ error_description }}
+ {% endif %} + + diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4de2f97d06..de7eca21f8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet): self.jwt_algorithm = hs.config.jwt_algorithm self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled + self.oidc_enabled = hs.config.oidc_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) - if self.saml2_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + if self.cas_enabled: flows.append({"type": LoginRestServlet.SSO_TYPE}) @@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet): # fall back to the fallback API if they don't understand one of the # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + elif self.saml2_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + elif self.oidc_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) flows.extend( ({"type": t} for t in self.auth_handler.get_supported_login_types()) @@ -465,6 +469,22 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): return self._saml_handler.handle_redirect_request(client_redirect_url) +class OIDCRedirectServlet(RestServlet): + """Implementation for /login/sso/redirect for the OIDC login flow.""" + + PATTERNS = client_patterns("/login/sso/redirect", v1=True) + + def __init__(self, hs): + self._oidc_handler = hs.get_oidc_handler() + + async def on_GET(self, request): + args = request.args + if b"redirectUrl" not in args: + return 400, "Redirect URL not specified for SSO auth" + client_redirect_url = args[b"redirectUrl"][0] + await self._oidc_handler.handle_redirect_request(request, client_redirect_url) + + def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.cas_enabled: @@ -472,3 +492,5 @@ def register_servlets(hs, http_server): CasTicketServlet(hs).register(http_server) elif hs.config.saml2_enabled: SAMLRedirectServlet(hs).register(http_server) + elif hs.config.oidc_enabled: + OIDCRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py new file mode 100644 index 0000000000..d958dd65bb --- /dev/null +++ b/synapse/rest/oidc/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.web.resource import Resource + +from synapse.rest.oidc.callback_resource import OIDCCallbackResource + +logger = logging.getLogger(__name__) + + +class OIDCResource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py new file mode 100644 index 0000000000..c03194f001 --- /dev/null +++ b/synapse/rest/oidc/callback_resource.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.http.server import DirectServeResource, wrap_html_request_handler + +logger = logging.getLogger(__name__) + + +class OIDCCallbackResource(DirectServeResource): + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + @wrap_html_request_handler + async def _async_render_GET(self, request): + return await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/server.py b/synapse/server.py index bf97a16c09..b4aea81e24 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -204,6 +204,7 @@ class HomeServer(object): "account_validity_handler", "cas_handler", "saml_handler", + "oidc_handler", "event_client_serializer", "password_policy_handler", "storage", @@ -562,6 +563,11 @@ class HomeServer(object): return SamlHandler(self) + def build_oidc_handler(self): + from synapse.handlers.oidc_handler import OidcHandler + + return OidcHandler(self) + def build_event_client_serializer(self): return EventClientSerializer(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 18043a2593..31a9cc0389 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -13,6 +13,7 @@ import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message import synapse.handlers.presence +import synapse.handlers.register import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password @@ -128,3 +129,7 @@ class HomeServer(object): pass def get_storage(self) -> synapse.storage.Storage: pass + def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler: + pass + def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator: + pass diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py new file mode 100644 index 0000000000..61963aa90d --- /dev/null +++ b/tests/handlers/test_oidc.py @@ -0,0 +1,565 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from urllib.parse import parse_qs, urlparse + +from mock import Mock, patch + +import attr +import pymacaroons + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone + +from synapse.handlers.oidc_handler import ( + MappingException, + OidcError, + OidcHandler, + OidcMappingProvider, +) +from synapse.types import UserID + +from tests.unittest import HomeserverTestCase, override_config + + +@attr.s +class FakeResponse: + code = attr.ib() + body = attr.ib() + phrase = attr.ib() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) + + +# These are a few constants that are used as config parameters in the tests. +ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" +CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +SCOPES = ["openid"] + +AUTHORIZATION_ENDPOINT = ISSUER + "authorize" +TOKEN_ENDPOINT = ISSUER + "token" +USERINFO_ENDPOINT = ISSUER + "userinfo" +WELL_KNOWN = ISSUER + ".well-known/openid-configuration" +JWKS_URI = ISSUER + ".well-known/jwks.json" + +# config for common cases +COMMON_CONFIG = { + "discover": False, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, +} + + +# The cookie name and path don't really matter, just that it has to be coherent +# between the callback & redirect handlers. +COOKIE_NAME = b"oidc_session" +COOKIE_PATH = "/_synapse/oidc" + +MockedMappingProvider = Mock(OidcMappingProvider) + + +def simple_async_mock(return_value=None, raises=None): + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + +async def get_json(url): + # Mock get_json calls to handle jwks & oidc discovery endpoints + if url == WELL_KNOWN: + # Minimal discovery document, as defined in OpenID.Discovery + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + return { + "issuer": ISSUER, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, + "userinfo_endpoint": USERINFO_ENDPOINT, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + elif url == JWKS_URI: + return {"keys": []} + + +class OidcHandlerTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + config = self.default_config() + config["public_baseurl"] = BASE_URL + oidc_config = config.get("oidc_config", {}) + oidc_config["enabled"] = True + oidc_config["client_id"] = CLIENT_ID + oidc_config["client_secret"] = CLIENT_SECRET + oidc_config["issuer"] = ISSUER + oidc_config["scopes"] = SCOPES + oidc_config["user_mapping_provider"] = { + "module": __name__ + ".MockedMappingProvider" + } + config["oidc_config"] = oidc_config + + hs = self.setup_test_homeserver( + http_client=self.http_client, + proxied_http_client=self.http_client, + config=config, + ) + + self.handler = OidcHandler(hs) + + return hs + + def metadata_edit(self, values): + return patch.dict(self.handler._provider_metadata, values) + + def assertRenderedError(self, error, error_description=None): + args = self.handler._render_error.call_args[0] + self.assertEqual(args[1], error) + if error_description is not None: + self.assertEqual(args[2], error_description) + # Reset the render_error mock + self.handler._render_error.reset_mock() + + def test_config(self): + """Basic config correctly sets up the callback URL and client auth correctly.""" + self.assertEqual(self.handler._callback_url, CALLBACK_URL) + self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + + @override_config({"oidc_config": {"discover": True}}) + @defer.inlineCallbacks + def test_discovery(self): + """The handler should discover the endpoints from OIDC discovery document.""" + # This would throw if some metadata were invalid + metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + + self.assertEqual(metadata.issuer, ISSUER) + self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) + self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) + self.assertEqual(metadata.jwks_uri, JWKS_URI) + # FIXME: it seems like authlib does not have that defined in its metadata models + # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + + # subsequent calls should be cached + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_no_discovery(self): + """When discovery is disabled, it should not try to load from discovery document.""" + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_load_jwks(self): + """JWKS loading is done once (then cached) if used.""" + jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.assertEqual(jwks, {"keys": []}) + + # subsequent calls should be cached… + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_not_called() + + # …unless forced + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + + # Throw if the JWKS uri is missing + with self.metadata_edit({"jwks_uri": None}): + with self.assertRaises(RuntimeError): + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + + # Return empty key set if JWKS are not used + self.handler._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + + @override_config({"oidc_config": COMMON_CONFIG}) + def test_validate_config(self): + """Provider metadatas are extensively validated.""" + h = self.handler + + # Default test config does not throw + h._validate_metadata() + + with self.metadata_edit({"issuer": None}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "http://insecure/"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "https://invalid/?because=query"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"authorization_endpoint": None}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"token_endpoint": None}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"token_endpoint": "http://insecure/token"}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": None}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"response_types_supported": ["id_token"]}): + self.assertRaisesRegex( + ValueError, "response_types_supported", h._validate_metadata + ) + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + ): + # should not throw, as client_secret_basic is the default auth method + h._validate_metadata() + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_post"]} + ): + self.assertRaisesRegex( + ValueError, + "token_endpoint_auth_methods_supported", + h._validate_metadata, + ) + + # Tests for configs that the userinfo endpoint + self.assertFalse(h._uses_userinfo) + h._scopes = [] # do not request the openid scope + self.assertTrue(h._uses_userinfo) + self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + + with self.metadata_edit( + {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} + ): + # Shouldn't raise with a valid userinfo, even without + h._validate_metadata() + + @override_config({"oidc_config": {"skip_verification": True}}) + def test_skip_verification(self): + """Provider metadata validation can be disabled by config.""" + with self.metadata_edit({"issuer": "http://insecure"}): + # This should not throw + self.handler._validate_metadata() + + @defer.inlineCallbacks + def test_redirect_request(self): + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["addCookie", "redirect", "finish"]) + yield defer.ensureDeferred( + self.handler.handle_redirect_request(req, b"http://client/redirect") + ) + url = req.redirect.call_args[0][0] + url = urlparse(url) + auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + + self.assertEqual(url.scheme, auth_endpoint.scheme) + self.assertEqual(url.netloc, auth_endpoint.netloc) + self.assertEqual(url.path, auth_endpoint.path) + + params = parse_qs(url.query) + self.assertEqual(params["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(params["response_type"], ["code"]) + self.assertEqual(params["scope"], [" ".join(SCOPES)]) + self.assertEqual(params["client_id"], [CLIENT_ID]) + self.assertEqual(len(params["state"]), 1) + self.assertEqual(len(params["nonce"]), 1) + + # Check what is in the cookie + # note: python3.5 mock does not have the .called_once() method + calls = req.addCookie.call_args_list + self.assertEqual(len(calls), 1) # called once + # For some reason, call.args does not work with python3.5 + args = calls[0][0] + kwargs = calls[0][1] + self.assertEqual(args[0], COOKIE_NAME) + self.assertEqual(kwargs["path"], COOKIE_PATH) + cookie = args[1] + + macaroon = pymacaroons.Macaroon.deserialize(cookie) + state = self.handler._get_value_from_macaroon(macaroon, "state") + nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") + redirect = self.handler._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + self.assertEqual(params["state"], [state]) + self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(redirect, "http://client/redirect") + + @defer.inlineCallbacks + def test_callback_error(self): + """Errors from the provider returned in the callback are displayed.""" + self.handler._render_error = Mock() + request = Mock(args={}) + request.args[b"error"] = [b"invalid_client"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "") + + request.args[b"error_description"] = [b"some description"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "some description") + + @defer.inlineCallbacks + def test_callback(self): + """Code callback works and display errors if something went wrong. + + A lot of scenarios are tested here: + - when the callback works, with userinfo from ID token + - when the user mapping fails + - when ID token verification fails + - when the callback works, with userinfo fetched from the userinfo endpoint + - when the userinfo fetching fails + - when the code exchange fails + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "preferred_username": "bar", + } + user_id = UserID("foo", "domain.org") + self.handler._render_error = Mock(return_value=None) + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock(spec=["args", "getCookie", "addCookie"]) + + code = "code" + state = "state" + nonce = "nonce" + client_redirect_url = "http://client/redirect" + session = self.handler._generate_oidc_session_token( + state=state, nonce=nonce, client_redirect_url=client_redirect_url, + ) + request.getCookie.return_value = session + + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_not_called() + self.handler._render_error.assert_not_called() + + # Handle mapping errors + self.handler._map_userinfo_to_user = simple_async_mock( + raises=MappingException() + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + + # Handle ID token errors + self.handler._parse_id_token = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_token") + + self.handler._auth_handler.complete_sso_login.reset_mock() + self.handler._exchange_code.reset_mock() + self.handler._parse_id_token.reset_mock() + self.handler._map_userinfo_to_user.reset_mock() + self.handler._fetch_userinfo.reset_mock() + + # With userinfo fetching + self.handler._scopes = [] # do not ask the "openid" scope + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_not_called() + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_called_once_with(token) + self.handler._render_error.assert_not_called() + + # Handle userinfo fetching error + self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("fetch_error") + + # Handle code exchange failure + self.handler._exchange_code = simple_async_mock( + raises=OidcError("invalid_request") + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @defer.inlineCallbacks + def test_callback_session(self): + """The callback verifies the session presence and validity""" + self.handler._render_error = Mock(return_value=None) + request = Mock(spec=["args", "getCookie", "addCookie"]) + + # Missing cookie + request.args = {} + request.getCookie.return_value = None + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("missing_session", "No session cookie found") + + # Missing session parameter + request.args = {} + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request", "State parameter is missing") + + # Invalid cookie + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_session") + + # Mismatching session + session = self.handler._generate_oidc_session_token( + state="state", nonce="nonce", client_redirect_url="http://client/redirect", + ) + request.args = {} + request.args[b"state"] = [b"mismatching state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mismatching_session") + + # Valid session + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @defer.inlineCallbacks + def test_exchange_code(self): + """Code exchange behaves correctly and handles various error scenarios.""" + token = {"type": "bearer"} + token_json = json.dumps(token).encode("utf-8") + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + ) + code = "code" + ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + kwargs = self.http_client.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + # Test error handling + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=400, + phrase=b"Bad Request", + body=b'{"error": "foo", "error_description": "bar"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "foo") + self.assertEqual(exc.exception.error_description, "bar") + + # Internal server error with no JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, phrase=b"Internal Server Error", body=b"Not JSON", + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # Internal server error with JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, + phrase=b"Internal Server Error", + body=b'{"error": "internal_server_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "internal_server_error") + + # 4xx error without "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # 2xx error with "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=b'{"error": "some_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "some_error") diff --git a/tox.ini b/tox.ini index c699f3e46a..ad1902d47d 100644 --- a/tox.ini +++ b/tox.ini @@ -185,6 +185,7 @@ commands = mypy \ synapse/handlers/auth.py \ synapse/handlers/cas_handler.py \ synapse/handlers/directory.py \ + synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ synapse/handlers/saml_handler.py \ synapse/handlers/sync.py \ -- cgit 1.5.1 From 225c16508705ecfdde44e3c90060609fab020e32 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 14 May 2020 16:32:49 +0100 Subject: Allow expired accounts to logout (#7443) --- changelog.d/7443.bugfix | 1 + synapse/api/auth.py | 50 ++++++++++++++------- synapse/rest/client/v1/logout.py | 6 +-- tests/rest/client/v1/test_login.py | 69 ++++++++++++++++++++++++++++- tests/rest/client/v2_alpha/test_register.py | 36 ++++++++++++++- 5 files changed, 140 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7443.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/7443.bugfix b/changelog.d/7443.bugfix new file mode 100644 index 0000000000..1dddec59ed --- /dev/null +++ b/changelog.d/7443.bugfix @@ -0,0 +1 @@ +Allow expired user accounts to log out their device sessions. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e009b1a760..3c660318fc 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -22,6 +22,7 @@ import pymacaroons from netaddr import IPAddress from twisted.internet import defer +from twisted.web.server import Request import synapse.logging.opentracing as opentracing import synapse.types @@ -162,19 +163,25 @@ class Auth(object): @defer.inlineCallbacks def get_user_by_req( - self, request, allow_guest=False, rights="access", allow_expired=False + self, + request: Request, + allow_guest: bool = False, + rights: str = "access", + allow_expired: bool = False, ): """ Get a registered user's ID. Args: - request - An HTTP request with an access_token query parameter. - allow_expired - Whether to allow the request through even if the account is - expired. If true, Synapse will still require an access token to be - provided but won't check if the account it belongs to has expired. This - works thanks to /login delivering access tokens regardless of accounts' - expiration. + request: An HTTP request with an access_token query parameter. + allow_guest: If False, will raise an AuthError if the user making the + request is a guest. + rights: The operation being performed; the access token must allow this + allow_expired: If True, allow the request through even if the account + is expired, or session token lifetime has ended. Note that + /login will deliver access tokens regardless of expiration. + Returns: - defer.Deferred: resolves to a ``synapse.types.Requester`` object + defer.Deferred: resolves to a `synapse.types.Requester` object Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -205,7 +212,9 @@ class Auth(object): return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token(access_token, rights) + user_info = yield self.get_user_by_access_token( + access_token, rights, allow_expired=allow_expired + ) user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] @@ -280,13 +289,17 @@ class Auth(object): return user_id, app_service @defer.inlineCallbacks - def get_user_by_access_token(self, token, rights="access"): + def get_user_by_access_token( + self, token: str, rights: str = "access", allow_expired: bool = False, + ): """ Validate access token and get user_id from it Args: - token (str): The access token to get the user by. - rights (str): The operation being performed; the access token must - allow this. + token: The access token to get the user by + rights: The operation being performed; the access token must + allow this + allow_expired: If False, raises an InvalidClientTokenError + if the token is expired Returns: Deferred[dict]: dict that includes: `user` (UserID) @@ -294,8 +307,10 @@ class Auth(object): `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: + InvalidClientTokenError if a user by that token exists, but the token is + expired InvalidClientCredentialsError if no user by that token exists or the token - is invalid. + is invalid """ if rights == "access": @@ -304,7 +319,8 @@ class Auth(object): if r: valid_until_ms = r["valid_until_ms"] if ( - valid_until_ms is not None + not allow_expired + and valid_until_ms is not None and valid_until_ms < self.clock.time_msec() ): # there was a valid access token, but it has expired. @@ -575,7 +591,7 @@ class Auth(object): return user_level >= send_level @staticmethod - def has_access_token(request): + def has_access_token(request: Request): """Checks if the request has an access_token. Returns: @@ -586,7 +602,7 @@ class Auth(object): return bool(query_params) or bool(auth_headers) @staticmethod - def get_access_token_from_request(request): + def get_access_token_from_request(request: Request): """Extracts the access_token from the request. Args: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 1cf3caf832..b0c30b65be 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -34,10 +34,10 @@ class LogoutRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_expired=True) if requester.device_id is None: - # the acccess token wasn't associated with a device. + # The access token wasn't associated with a device. # Just delete the access token access_token = self.auth.get_access_token_from_request(request) await self._auth_handler.delete_access_token(access_token) @@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() # first delete all of the user's devices diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 1856c7ffd5..eb8f6264fd 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -4,7 +4,7 @@ import urllib.parse from mock import Mock import synapse.rest.admin -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha.account import WhoamiRestServlet @@ -20,6 +20,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, + logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] @@ -256,6 +257,72 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.code, 200, channel.result) + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + request, channel = self.make_request( + b"POST", "/logout/all", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + class CASTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index a68a96f618..5637ce2090 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -25,7 +25,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest @@ -313,6 +313,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, + logout.register_servlets, account_validity.register_servlets, ] @@ -405,6 +406,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + request, channel = self.make_request(b"POST", "/logout", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From a3cf36f76ed41222241393adf608d0e640bb51b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 May 2020 12:26:02 -0400 Subject: Support UI Authentication for OpenID Connect accounts (#7457) --- changelog.d/7457.feature | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 76 +++++++++++++++++++++++++++--------- synapse/rest/client/v1/login.py | 31 +++++++++------ synapse/rest/client/v2_alpha/auth.py | 19 +++++++-- tests/handlers/test_oidc.py | 15 ++++--- 6 files changed, 105 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7457.feature (limited to 'synapse/rest') diff --git a/changelog.d/7457.feature b/changelog.d/7457.feature new file mode 100644 index 0000000000..7ad767bf71 --- /dev/null +++ b/changelog.d/7457.feature @@ -0,0 +1 @@ +Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs). diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 524281d2f1..75b39e878c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -80,7 +80,9 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled + self._sso_enabled = ( + hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled + ) # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 178f263439..4ba8c7fda5 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -311,7 +311,7 @@ class OidcHandler: ``ClientAuth`` to authenticate with the client with its ID and secret. Args: - code: The autorization code we got from the callback. + code: The authorization code we got from the callback. Returns: A dict containing various tokens. @@ -497,11 +497,14 @@ class OidcHandler: return UserInfo(claims) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> None: + self, + request: SynapseRequest, + client_redirect_url: bytes, + ui_auth_session_id: Optional[str] = None, + ) -> str: """Handle an incoming request to /login/sso/redirect - It redirects the browser to the authorization endpoint with a few + It returns a redirect to the authorization endpoint with a few parameters: - ``client_id``: the client ID set in ``oidc_config.client_id`` @@ -511,24 +514,32 @@ class OidcHandler: - ``state``: a random string - ``nonce``: a random string - In addition to redirecting the client, we are setting a cookie with + In addition generating a redirect URL, we are setting a cookie with a signed macaroon token containing the state, the nonce and the client_redirect_url params. Those are then checked when the client comes back from the provider. - Args: request: the incoming request from the browser. We'll respond to it with a redirect and a cookie. client_redirect_url: the URL that we should redirect the client to when everything is done + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + The redirect URL to the authorization endpoint. + """ state = generate_token() nonce = generate_token() cookie = self._generate_oidc_session_token( - state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(), + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id, ) request.addCookie( SESSION_COOKIE_NAME, @@ -541,7 +552,7 @@ class OidcHandler: metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") - uri = prepare_grant_uri( + return prepare_grant_uri( authorization_endpoint, client_id=self._client_auth.client_id, response_type="code", @@ -550,8 +561,6 @@ class OidcHandler: state=state, nonce=nonce, ) - request.redirect(uri) - finish_request(request) async def handle_oidc_callback(self, request: SynapseRequest) -> None: """Handle an incoming request to /_synapse/oidc/callback @@ -625,7 +634,11 @@ class OidcHandler: # Deserialize the session token and verify it. try: - nonce, client_redirect_url = self._verify_oidc_session_token(session, state) + ( + nonce, + client_redirect_url, + ui_auth_session_id, + ) = self._verify_oidc_session_token(session, state) except MacaroonDeserializationException as e: logger.exception("Invalid session") self._render_error(request, "invalid_session", str(e)) @@ -678,15 +691,21 @@ class OidcHandler: return # and finally complete the login - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url - ) + if ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, ui_auth_session_id, request + ) + else: + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) def _generate_oidc_session_token( self, state: str, nonce: str, client_redirect_url: str, + ui_auth_session_id: Optional[str], duration_in_ms: int = (60 * 60 * 1000), ) -> str: """Generates a signed token storing data about an OIDC session. @@ -702,6 +721,8 @@ class OidcHandler: nonce: The ``nonce`` parameter passed to the OIDC provider. client_redirect_url: The URL the client gave when it initiated the flow. + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). duration_in_ms: An optional duration for the token in milliseconds. Defaults to an hour. @@ -718,12 +739,19 @@ class OidcHandler: macaroon.add_first_party_caveat( "client_redirect_url = %s" % (client_redirect_url,) ) + if ui_auth_session_id: + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (ui_auth_session_id,) + ) now = self._clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() - def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]: + def _verify_oidc_session_token( + self, session: str, state: str + ) -> Tuple[str, str, Optional[str]]: """Verifies and extract an OIDC session token. This verifies that a given session token was issued by this homeserver @@ -734,7 +762,7 @@ class OidcHandler: state: The state the OIDC provider gave back Returns: - The nonce and the client_redirect_url for this session + The nonce, client_redirect_url, and ui_auth_session_id for this session """ macaroon = pymacaroons.Macaroon.deserialize(session) @@ -744,17 +772,27 @@ class OidcHandler: v.satisfy_exact("state = %s" % (state,)) v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + # Sometimes there's a UI auth session ID, it seems to be OK to attempt + # to always satisfy this. + v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) v.satisfy_general(self._verify_expiry) v.verify(macaroon, self._macaroon_secret_key) - # Extract the `nonce` and `client_redirect_url` from the token + # Extract the `nonce`, `client_redirect_url`, and maybe the + # `ui_auth_session_id` from the token. nonce = self._get_value_from_macaroon(macaroon, "nonce") client_redirect_url = self._get_value_from_macaroon( macaroon, "client_redirect_url" ) + try: + ui_auth_session_id = self._get_value_from_macaroon( + macaroon, "ui_auth_session_id" + ) # type: Optional[str] + except ValueError: + ui_auth_session_id = None - return nonce, client_redirect_url + return nonce, client_redirect_url, ui_auth_session_id def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: """Extracts a caveat value from a macaroon token. @@ -773,7 +811,7 @@ class OidcHandler: for caveat in macaroon.caveats: if caveat.caveat_id.startswith(prefix): return caveat.caveat_id[len(prefix) :] - raise Exception("No %s caveat in macaroon" % (key,)) + raise ValueError("No %s caveat in macaroon" % (key,)) def _verify_expiry(self, caveat: str) -> bool: prefix = "time < " diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index de7eca21f8..d89b2e5532 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request: SynapseRequest): + async def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" client_redirect_url = args[b"redirectUrl"][0] - sso_url = self.get_sso_url(client_redirect_url) + sso_url = await self.get_sso_url(request, client_redirect_url) request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: + request: The client request to redirect. client_redirect_url: the URL that we should redirect the client to when everything is done @@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._cas_handler.get_redirect_url( {"redirectUrl": client_redirect_url} ).encode("ascii") @@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class OIDCRedirectServlet(RestServlet): +class OIDCRedirectServlet(BaseSSORedirectServlet): """Implementation for /login/sso/redirect for the OIDC login flow.""" PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet): def __init__(self, hs): self._oidc_handler = hs.get_oidc_handler() - async def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - await self._oidc_handler.handle_redirect_request(request, client_redirect_url) + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: + return await self._oidc_handler.handle_redirect_request( + request, client_redirect_url + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 24dd3d3e96..7bca1326d5 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() # SSO configuration. - self._saml_enabled = hs.config.saml2_enabled - if self._saml_enabled: - self._saml_handler = hs.get_saml_handler() self._cas_enabled = hs.config.cas_enabled if self._cas_enabled: self._cas_handler = hs.get_cas_handler() self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + self._oidc_enabled = hs.config.oidc_enabled + if self._oidc_enabled: + self._oidc_handler = hs.get_oidc_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url async def on_GET(self, request, stagetype): session = parse_string(request, "session") @@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet): ) elif self._saml_enabled: - client_redirect_url = "" + client_redirect_url = b"" sso_redirect_url = self._saml_handler.handle_redirect_request( client_redirect_url, session ) + elif self._oidc_enabled: + client_redirect_url = b"" + sso_redirect_url = await self._oidc_handler.handle_redirect_request( + request, client_redirect_url, session + ) + else: raise SynapseError(400, "Homeserver not configured for SSO.") diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 61963aa90d..1bb25ab684 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase): @defer.inlineCallbacks def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" - req = Mock(spec=["addCookie", "redirect", "finish"]) - yield defer.ensureDeferred( + req = Mock(spec=["addCookie"]) + url = yield defer.ensureDeferred( self.handler.handle_redirect_request(req, b"http://client/redirect") ) - url = req.redirect.call_args[0][0] url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) @@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase): nonce = "nonce" client_redirect_url = "http://client/redirect" session = self.handler._generate_oidc_session_token( - state=state, nonce=nonce, client_redirect_url=client_redirect_url, + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, ) request.getCookie.return_value = session @@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # Mismatching session session = self.handler._generate_oidc_session_token( - state="state", nonce="nonce", client_redirect_url="http://client/redirect", + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", + ui_auth_session_id=None, ) request.args = {} request.args[b"state"] = [b"mismatching state"] -- cgit 1.5.1 From d4676910c91dd492ca5cc7c207969fa7bfe1bbee Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:17:06 +0100 Subject: remove miscellaneous PY2 code --- synapse/http/matrixfederationclient.py | 8 ++------ synapse/logging/utils.py | 10 ++-------- synapse/push/httppusher.py | 11 +++-------- synapse/rest/media/v1/_base.py | 27 +++++++++------------------ synapse/util/caches/__init__.py | 7 +------ synapse/util/stringutils.py | 28 +++++++--------------------- 6 files changed, 24 insertions(+), 67 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 225a47e3c3..44077f5349 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,7 +19,7 @@ import random import sys from io import BytesIO -from six import PY3, raise_from, string_types +from six import raise_from, string_types from six.moves import urllib import attr @@ -70,11 +70,7 @@ incoming_responses_counter = Counter( MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 - -if PY3: - MAXINT = sys.maxsize -else: - MAXINT = sys.maxint +MAXINT = sys.maxsize _next_id = 1 diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 0c2527bd86..99049bb5d8 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -20,8 +20,6 @@ import time from functools import wraps from inspect import getcallargs -from six import PY3 - _TIME_FUNC_ID = 0 @@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args): logger = logging.getLogger(name) if logger.isEnabledFor(logging.DEBUG): - if PY3: - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - else: - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename + lineno = f.__code__.co_firstlineno + pathname = f.__code__.co_filename record = logging.LogRecord( name=name, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 5bb17d1228..eaaa7afc91 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -import six - from prometheus_client import Counter from twisted.internet import defer @@ -28,9 +26,6 @@ from synapse.push import PusherConfigException from . import push_rule_evaluator, push_tools -if six.PY3: - long = int - logger = logging.getLogger(__name__) http_push_processed_counter = Counter( @@ -318,7 +313,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], @@ -347,7 +342,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, "tweaks": tweaks, } @@ -409,7 +404,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 503f2bed98..3689777266 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,6 @@ import logging import os -from six import PY3 from six.moves import urllib from twisted.internet import defer @@ -324,23 +323,15 @@ def get_filename_from_headers(headers): upload_name_utf8 = upload_name_utf8[7:] # We have a filename*= section. This MUST be ASCII, and any UTF-8 # bytes are %-quoted. - if PY3: - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - else: - # On Python 2, we first unquote the %-encoded parts and then - # decode it strictly using UTF-8. - try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") - except UnicodeDecodeError: - pass + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass # If there isn't check for an ascii name. if not upload_name: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 4b8a0c7a8f..dd356bf156 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -15,11 +15,9 @@ # limitations under the License. import logging +from sys import intern from typing import Callable, Dict, Optional -import six -from six.moves import intern - import attr from prometheus_client.core import Gauge @@ -154,9 +152,6 @@ def intern_string(string): return None try: - if six.PY2: - string = string.encode("ascii") - return intern(string) except UnicodeEncodeError: return string diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 81a44184ca..08c86e92b8 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,9 +19,6 @@ import re import string from collections import Iterable -from six import PY3 -from six.moves import range - from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" @@ -46,24 +43,13 @@ def random_string_with_symbols(length): def is_ascii(s): - - if PY3: - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True - - try: - s.encode("ascii") - except UnicodeEncodeError: - return False - except UnicodeDecodeError: - return False - else: + if isinstance(s, bytes): + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False return True -- cgit 1.5.1 From 9dc6f3075aea7c76c3d6a201f8a78ace76f99a3e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 20 May 2020 09:48:03 -0400 Subject: Hash passwords earlier in the password reset process (#7538) This now matches the logic of the registration process as modified in 56db0b1365965c02ff539193e26c333b7f70d101 / #7523. --- changelog.d/7538.bugfix | 1 + synapse/handlers/set_password.py | 5 +---- synapse/rest/admin/users.py | 13 +++++++++++-- synapse/rest/client/v2_alpha/account.py | 21 ++++++++++++++++++--- synapse/rest/client/v2_alpha/register.py | 4 ++-- 5 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7538.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/7538.bugfix b/changelog.d/7538.bugfix new file mode 100644 index 0000000000..4a614a9e61 --- /dev/null +++ b/changelog.d/7538.bugfix @@ -0,0 +1 @@ + Hash passwords as early as possible during password reset. diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 63d8f9aa0d..4d245b618b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -35,16 +35,13 @@ class SetPasswordHandler(BaseHandler): async def set_password( self, user_id: str, - new_password: str, + password_hash: str, logout_devices: bool, requester: Optional[Requester] = None, ): if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) - self._password_policy_handler.validate_password(new_password) - password_hash = await self._auth_handler.hash(new_password) - try: await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 326682fbdb..e7f6928c85 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -222,8 +222,14 @@ class UserRestServletV2(RestServlet): else: new_password = body["password"] logout_devices = True + + new_password_hash = await self.auth_handler.hash(new_password) + await self.set_password_handler.set_password( - target_user.to_string(), new_password, logout_devices, requester + target_user.to_string(), + new_password_hash, + logout_devices, + requester, ) if "deactivated" in body: @@ -523,6 +529,7 @@ class ResetPasswordRestServlet(RestServlet): self.store = hs.get_datastore() self.hs = hs self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() async def on_POST(self, request, target_user_id): @@ -539,8 +546,10 @@ class ResetPasswordRestServlet(RestServlet): new_password = params["new_password"] logout_devices = params.get("logout_devices", True) + new_password_hash = await self.auth_handler.hash(new_password) + await self._set_password_handler.set_password( - target_user_id, new_password, logout_devices, requester + target_user_id, new_password_hash, logout_devices, requester ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1bd0234779..d4f721b6b9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -220,12 +220,27 @@ class PasswordRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() + self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler async def on_POST(self, request): body = parse_json_object_from_request(request) + # we do basic sanity checks here because the auth layer will store these + # in sessions. Pull out the new password provided to us. + if "new_password" in body: + new_password = body.pop("new_password") + if not isinstance(new_password, str) or len(new_password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(new_password) + + # If the password is valid, hash it and store it back on the body. + # This ensures that only the hashed password is handled everywhere. + if "new_password_hash" in body: + raise SynapseError(400, "Unexpected property: new_password_hash") + body["new_password_hash"] = await self.auth_handler.hash(new_password) + # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -276,12 +291,12 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password"]) - new_password = params["new_password"] + assert_params_in_dict(params, ["new_password_hash"]) + new_password_hash = params["new_password_hash"] logout_devices = params.get("logout_devices", True) await self._set_password_handler.set_password( - user_id, new_password, logout_devices, requester + user_id, new_password_hash, logout_devices, requester ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c26927f27b..addd4cae19 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -431,8 +431,8 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(password) - # If the password is valid, hash it and store it back on the request. - # This ensures the hashed password is handled everywhere. + # If the password is valid, hash it and store it back on the body. + # This ensures that only the hashed password is handled everywhere. if "password_hash" in body: raise SynapseError(400, "Unexpected property: password_hash") body["password_hash"] = await self.auth_handler.hash(password) -- cgit 1.5.1 From 66f2ebc22fec01b4673fabae22f2c94dfeac58e3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 22 May 2020 07:17:30 -0400 Subject: Use a non-empty RelayState for user interactive auth with SAML. (#7552) --- changelog.d/7552.bugfix | 1 + synapse/rest/client/v2_alpha/auth.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7552.bugfix (limited to 'synapse/rest') diff --git a/changelog.d/7552.bugfix b/changelog.d/7552.bugfix new file mode 100644 index 0000000000..60b31d6d31 --- /dev/null +++ b/changelog.d/7552.bugfix @@ -0,0 +1 @@ +Fix "Missing RelayState parameter" error when using user interactive authentication with SAML for some SAML providers. diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 7bca1326d5..75590ebaeb 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -177,7 +177,10 @@ class AuthRestServlet(RestServlet): ) elif self._saml_enabled: - client_redirect_url = b"" + # Some SAML identity providers (e.g. Google) require a + # RelayState parameter on requests. It is not necessary here, so + # pass in a dummy redirect URL (which will never get used). + client_redirect_url = b"unused" sso_redirect_url = self._saml_handler.handle_redirect_request( client_redirect_url, session ) -- cgit 1.5.1 From 1531b214fc57714c14046a8f66c7b5fe5ec5dcdd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 14:21:54 +0100 Subject: Add ability to wait for replication streams (#7542) The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room). Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on. People probably want to look at this commit by commit. --- changelog.d/7542.misc | 1 + synapse/handlers/federation.py | 33 ++++++--- synapse/handlers/message.py | 36 ++++++--- synapse/handlers/room.py | 65 +++++++++++----- synapse/handlers/room_member.py | 65 ++++++++++------ synapse/handlers/room_member_worker.py | 11 +-- synapse/replication/http/federation.py | 13 +++- synapse/replication/http/membership.py | 14 ++-- synapse/replication/http/send_event.py | 4 +- synapse/replication/http/streams.py | 5 +- synapse/replication/tcp/client.py | 90 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 10 ++- synapse/rest/client/v1/room.py | 20 +++-- synapse/rest/client/v2_alpha/relations.py | 2 +- synapse/server.py | 5 ++ synapse/server.pyi | 5 ++ synapse/server_notices/server_notices_manager.py | 6 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/roommember.py | 2 + tests/federation/test_complexity.py | 8 +- tests/handlers/test_typing.py | 5 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_event_metrics.py | 2 +- tests/test_federation.py | 4 +- 24 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 changelog.d/7542.misc (limited to 'synapse/rest') diff --git a/changelog.d/7542.misc b/changelog.d/7542.misc new file mode 100644 index 0000000000..7dd9b4823b --- /dev/null +++ b/changelog.d/7542.misc @@ -0,0 +1 @@ +Add ability to wait for replication streams. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bb03cc9add..e354c803db 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,6 +126,7 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._replication = hs.get_replication_data_handler() self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( hs @@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict - ) -> None: + ) -> Tuple[str, int]: """ Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. @@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler): room_id=room_id, room_version=room_version_obj, ) - await self._persist_auth_tree( + max_stream_id = await self._persist_auth_tree( origin, auth_chain, state, event, room_version_obj ) + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + "master", "events", max_stream_id + ) + # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = await self.store.get_room_predecessor(room_id) if not predecessor or not isinstance(predecessor.get("room_id"), str): - return + return event.event_id, max_stream_id old_room_id = predecessor["room_id"] logger.debug( "Found predecessor for %s during remote join: %s", room_id, old_room_id @@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler): ) logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] @@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler): async def do_remotely_reject_invite( self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict - ) -> EventBase: + ) -> Tuple[EventBase, int]: origin, event, room_version = await self._make_and_verify_event( target_hosts, room_id, user_id, "leave", content=content ) @@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(target_hosts, event) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify([(event, context)]) - return event + return event, stream_id async def _make_and_verify_event( self, @@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler): state: List[EventBase], event: EventBase, room_version: RoomVersion, - ) -> None: + ) -> int: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler): self, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> None: + ) -> int: """Persists events and tells the notifier/pushers about them, if necessary. @@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler): backfilling or not """ if self.config.worker_app: - await self._send_events_to_master( + result = await self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, ) + return result["max_stream_id"] else: max_stream_id = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled @@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler): for event, _ in event_and_contexts: await self._notify_persisted_event(event, max_stream_id) + return max_stream_id + async def _notify_persisted_event( self, event: EventBase, max_stream_id: int ) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f362896a2..f445e2aa2a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple from six import iteritems, itervalues, string_types @@ -42,6 +42,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -630,7 +631,9 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - async def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event( + self, requester, event, context, ratelimit=True + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -639,6 +642,9 @@ class EventCreationHandler(object): context (Context) the context of the event. ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. + + Return: + The stream_id of the persisted event. """ if event.type == EventTypes.Member: raise SynapseError( @@ -659,7 +665,7 @@ class EventCreationHandler(object): ) return prev_state - await self.handle_new_client_event( + return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -688,7 +694,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None - ): + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -711,10 +717,10 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - await self.send_nonmember_event( + stream_id = await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - return event + return event, stream_id @measure_func("create_new_client_event") @defer.inlineCallbacks @@ -774,7 +780,7 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -787,6 +793,9 @@ class EventCreationHandler(object): context (EventContext) ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event + + Return: + The stream_id of the persisted event. """ if event.is_state() and (event.type, event.state_key) == ( @@ -827,7 +836,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - await self.send_event_to_master( + result = await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -836,14 +845,17 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + stream_id = result["stream_id"] + event.internal_metadata.stream_ordering = stream_id success = True - return + return stream_id - await self.persist_and_notify_client_event( + stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True + return stream_id finally: if not success: # Ensure that we actually remove the entries in the push actions @@ -886,7 +898,7 @@ class EventCreationHandler(object): async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1076,6 +1088,8 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + return event_stream_id + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 13850ba672..2698a129ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,6 +22,7 @@ import logging import math import string from collections import OrderedDict +from typing import Tuple from six import iteritems, string_types @@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler): async def create_room( self, requester, config, ratelimit=True, creator_join_profile=None - ): + ) -> Tuple[dict, int]: """ Creates a new room. Args: @@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler): `avatar_url` and/or `displayname`. Returns: - Deferred[dict]: - a dict containing the keys `room_id` and, if an alias was - requested, `room_alias`. + First, a dict containing the keys `room_id` and, if an alias + was, requested, `room_alias`. Secondly, the stream_id of the + last persisted event. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. @@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - await self._send_events_for_new_room( + last_stream_id = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - await self.room_member_handler.update_membership( + _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - await self.hs.get_room_member_handler().do_3pid_invite( + last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - return result + return result, last_stream_id async def _send_events_for_new_room( self, @@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler): room_alias=None, power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, - ): + ) -> int: + """Sends the initial events into a new room. + + Returns: + The stream_id of the last event persisted. + """ + def create(etype, content, **kwargs): e = {"type": etype, "content": content} @@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype, content, **kwargs): + async def send(etype, content, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) + return last_stream_id config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - await send(etype=EventTypes.PowerLevels, content=pl_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=pl_content + ) else: power_level_content = { "users": {creator_id: 100}, @@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - await send(etype=EventTypes.PowerLevels, content=power_level_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=power_level_content + ) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - await send(etype=etype, state_key=state_key, content=content) + last_sent_stream_id = await send( + etype=etype, state_key=state_key, content=content + ) + + return last_sent_stream_id async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e51e1c32fe..691b6705b2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple from six.moves import http_client @@ -84,7 +84,7 @@ class RoomMemberHandler(object): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Try and join a room that this server is not in Args: @@ -104,7 +104,7 @@ class RoomMemberHandler(object): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -154,7 +154,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> EventBase: + ) -> Tuple[str, int]: user_id = target.to_string() if content is None: @@ -187,9 +187,10 @@ class RoomMemberHandler(object): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return duplicate + _, stream_id = await self.store.get_event_ordering(duplicate.event_id) + return duplicate.event_id, stream_id - await self.event_creation_handler.handle_new_client_event( + stream_id = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) @@ -213,7 +214,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) - return event + return event.event_id, stream_id async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id @@ -263,7 +264,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -294,7 +295,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: content_specified = bool(content) if content is None: content = {} @@ -398,7 +399,13 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - return old_state + _, stream_id = await self.store.get_event_ordering( + old_state.event_id + ) + return ( + old_state.event_id, + stream_id, + ) if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -705,7 +712,7 @@ class RoomMemberHandler(object): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -737,11 +744,11 @@ class RoomMemberHandler(object): ) if invitee: - await self.update_membership( + _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - await self._make_and_store_3pid_invite( + stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -752,6 +759,8 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) + return stream_id + async def _make_and_store_3pid_invite( self, requester: Requester, @@ -762,7 +771,7 @@ class RoomMemberHandler(object): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -817,7 +826,10 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -835,6 +847,7 @@ class RoomMemberHandler(object): ratelimit=False, txn_id=txn_id, ) + return stream_id async def _is_host_in_room( self, current_state_ids: Dict[Tuple[str, str], str] @@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> None: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) await self._user_joined_room(user, room_id) @@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: if too_complex is False: # We checked, and we're under the limit. - return + return event_id, stream_id # Check again, but with the local state events too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. - return + return event_id, stream_id # The room is too large. Leave. requester = types.create_requester(user, None, False, None) @@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) + return event_id, stream_id + async def _remote_reject_invite( self, requester: Requester, @@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = await fed_handler.do_remotely_reject_invite( + event, stream_id = await fed_handler.do_remotely_reject_invite( remote_room_hosts, room_id, target.to_string(), content=content, ) - return ret + return event.event_id, stream_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(target.to_string(), room_id) - return {} + stream_id = await self.store.locally_reject_invite( + target.to_string(), room_id + ) + return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 5c776cc0be..02e0c4103d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._user_joined_room(user, room_id) - return ret + return ret["event_id"], ret["stream_id"] async def _remote_reject_invite( self, @@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ - return await self._remote_reject_client( + ret = await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), content=content, ) + return ret["event_id"], ret["stream_id"] async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7e23b565b9..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 3577611fd7..050fd34562 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = await self.federation_handler.do_remotely_reject_invite( + event, stream_id = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.store.locally_reject_invite(user_id, room_id) + event_id = None - return 200, ret + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index b74b088ff4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - await self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index b705a8e16c..bde97eef32 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): super().__init__(hs) self._instance_name = hs.get_instance_name() - - # We pull the streams from the replication handler (if we try and make - # them ourselves we end up in an import loop). - self.streams = hs.get_tcp_replication().get_streams() + self.streams = hs.get_replication_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 28826302f5..508ad1b720 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,19 +14,23 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,6 +39,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. @@ -92,6 +100,16 @@ class ReplicationDataHandler: self.store = hs.get_datastore() self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -131,8 +149,76 @@ class ReplicationDataHandler: await self.pusher_pool.on_new_notifications(token, token) + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any futher entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. + """ + + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) + + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. + heapq.heappush(waiting_list, (position, deferred)) + + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7d40001988..0a13e1ed34 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self._replication = hs.get_replication_data_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet): message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = await self._room_creation_handler.create_room( + info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet): # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position("master", "events", stream_id) + users = await self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6b5830cc3f..105e0cf4d2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request): requester = await self.auth.get_user_by_req(request) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = await self.room_member_handler.update_membership( + event_id, _ = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + event_id = event.event_id ret = {} # type: dict - if event: - set_tag("event_id", event.event_id) - ret = {"event_id": event.event_id} + if event_id: + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 63f07b63da..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) diff --git a/synapse/server.py b/synapse/server.py index c530f1aa1a..ca2deb49bb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -210,6 +211,7 @@ class HomeServer(object): "storage", "replication_streamer", "replication_data_handler", + "replication_streams", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -583,6 +585,9 @@ class HomeServer(object): def build_replication_data_handler(self): return ReplicationDataHandler(self) + def build_replication_streams(self): + return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9e7fad7e6e..fe8024d2d4 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +from typing import Dict + import twisted.internet import synapse.api.auth @@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.replication.tcp.streams import Stream class HomeServer(object): @property @@ -136,3 +139,5 @@ class HomeServer(object): pass def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: pass + def get_replication_streams(self) -> Dict[str, Stream]: + pass diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 999c621b92..bf2454c01c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -83,10 +83,10 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = await self._event_creation_handler.create_and_send_nonmember_event( + event, _ = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - return res + return event @cached() async def get_or_create_notice_room_for_user(self, user_id): @@ -143,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9130b74eb5..b880a71782 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore): async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): + def get_event_ordering(self, event_id): res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 1e9c850152..7c5ca81ae0 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) + return stream_ordering + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 94980733c4..0c9987be54 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) d = handler._remote_join( None, @@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) # Artificially raise the complexity self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 51e2b37218..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): reactor.pump((1000,)) hs = self.setup_test_homeserver( - notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, + replication_streams={}, ) hs.datastores = datastores diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 0e04b2cf92..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): @@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b7fd36d3..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None diff --git a/tests/test_federation.py b/tests/test_federation.py index 13ff14863e..c5099dd039 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = ensureDeferred( + room_deferred = ensureDeferred( room_creator.create_room( our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False ) ) self.reactor.advance(0.1) - self.room_id = self.successResultOf(room)["room_id"] + self.room_id = self.successResultOf(room_deferred)[0]["room_id"] self.store = self.homeserver.get_datastore() -- cgit 1.5.1 From e5c67d04dbe5ed45d659e826a5dfcd5044a4e374 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 16:11:35 +0100 Subject: Add option to move event persistence off master (#7517) --- changelog.d/7517.feature | 1 + scripts/synapse_port_db | 3 + synapse/app/generic_worker.py | 53 +++++++++- synapse/config/logger.py | 1 + synapse/config/workers.py | 30 +++++- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 12 ++- synapse/handlers/presence.py | 6 ++ synapse/handlers/room.py | 7 ++ synapse/handlers/room_member.py | 39 ++++++-- synapse/replication/http/__init__.py | 4 +- synapse/replication/http/_base.py | 3 + synapse/replication/http/membership.py | 40 +++++++- synapse/replication/http/presence.py | 116 ++++++++++++++++++++++ synapse/replication/tcp/handler.py | 10 ++ synapse/rest/admin/rooms.py | 11 +- synapse/storage/data_stores/__init__.py | 6 +- synapse/storage/data_stores/main/devices.py | 30 ++++-- synapse/storage/data_stores/main/events.py | 34 ++++++- synapse/storage/data_stores/main/events_worker.py | 2 +- synapse/storage/data_stores/main/roommember.py | 25 ----- synapse/storage/persist_events.py | 6 ++ 22 files changed, 382 insertions(+), 73 deletions(-) create mode 100644 changelog.d/7517.feature create mode 100644 synapse/replication/http/presence.py (limited to 'synapse/rest') diff --git a/changelog.d/7517.feature b/changelog.d/7517.feature new file mode 100644 index 0000000000..6062f0061c --- /dev/null +++ b/changelog.d/7517.feature @@ -0,0 +1 @@ +Add option to move event persistence off master. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index acd9ac4b75..9a0fbc61d8 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -196,6 +196,9 @@ class MockHomeserver: def get_reactor(self): return reactor + def get_instance_name(self): + return "master" + class Porter(object): def __init__(self, **kwargs): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index a37520000a..2906b93f6a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -39,7 +39,11 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.presence import BasePresenceHandler, get_interested_parties +from synapse.handlers.presence import ( + BasePresenceHandler, + PresenceState, + get_interested_parties, +) from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite @@ -47,6 +51,10 @@ from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource +from synapse.replication.http.presence import ( + ReplicationBumpPresenceActiveTime, + ReplicationPresenceSetState, +) from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore @@ -247,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler): # but we haven't notified the master of that yet self.users_going_offline = {} + self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) + self._set_state_client = ReplicationPresenceSetState.make_client(hs) + self._send_stop_syncing_loop = self.clock.looping_call( self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) @@ -304,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler): self.users_going_offline.pop(user_id, None) self.send_user_sync(user_id, False, last_sync_ms) - def set_state(self, user, state, ignore_status_msg=False): - # TODO Hows this supposed to work? - return defer.succeed(None) - async def user_syncing( self, user_id: str, affect_presence: bool ) -> ContextManager[None]: @@ -386,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler): if count > 0 ] + async def set_state(self, target_user, state, ignore_status_msg=False): + """Set the presence state of the user. + """ + presence = state["presence"] + + valid_presence = ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ) + if presence not in valid_presence: + raise SynapseError(400, "Invalid presence state") + + user_id = target_user.to_string() + + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + await self._set_state_client( + user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + ) + + async def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. + """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + user_id = user.to_string() + await self._bump_active_client(user_id=user_id) + class GenericWorkerTyping(object): def __init__(self, hs): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index a25c70e928..49f6c32beb 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -257,5 +257,6 @@ def setup_logging( logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) + logging.info("Instance name: %s", hs.get_instance_name()) return logger diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c80c338584..ed06b91a54 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -15,7 +15,7 @@ import attr -from ._base import Config +from ._base import Config, ConfigError @attr.s @@ -27,6 +27,17 @@ class InstanceLocationConfig: port = attr.ib(type=int) +@attr.s +class WriterLocations: + """Specifies the instances that write various streams. + + Attributes: + events: The instance that writes to the event and backfill streams. + """ + + events = attr.ib(default="master", type=str) + + class WorkerConfig(Config): """The workers are processes run separately to the main synapse process. They have their own pid_file and listener configuration. They use the @@ -83,11 +94,26 @@ class WorkerConfig(Config): bind_addresses.append("") # A map from instance name to host/port of their HTTP replication endpoint. - instance_map = config.get("instance_map", {}) or {} + instance_map = config.get("instance_map") or {} self.instance_map = { name: InstanceLocationConfig(**c) for name, c in instance_map.items() } + # Map from type of streams to source, c.f. WriterLocations. + writers = config.get("stream_writers") or {} + self.writers = WriterLocations(**writers) + + # Check that the configured writer for events also appears in + # `instance_map`. + if ( + self.writers.events != "master" + and self.writers.events not in self.instance_map + ): + raise ConfigError( + "Instance %r is configured to write events but does not appear in `instance_map` config." + % (self.writers.events,) + ) + def read_arguments(self, args): # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e354c803db..75ec90d267 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,11 +126,10 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() - self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( - hs - ) + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( hs ) @@ -1243,6 +1242,10 @@ class FederationHandler(BaseHandler): content: The event content to use for the join event. """ + # TODO: We should be able to call this on workers, but the upgrading of + # room stuff after join currently doesn't work on workers. + assert self.config.worker.worker_app is None + logger.debug("Joining %s to %s", joinee, room_id) origin, event, room_version_obj = await self._make_and_verify_event( @@ -1314,7 +1317,7 @@ class FederationHandler(BaseHandler): # # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - "master", "events", max_stream_id + self.config.worker.writers.events, "events", max_stream_id ) # Check whether this room is the result of an upgrade of a room we already know @@ -2854,8 +2857,9 @@ class FederationHandler(BaseHandler): backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker_app: - result = await self._send_events_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self._send_events( + instance_name=self.config.worker.writers.events, store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f445e2aa2a..ea25f0515a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -366,10 +366,11 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types - self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) + self.send_event = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users self.base_handler = BaseHandler(hs) @@ -835,8 +836,9 @@ class EventCreationHandler(object): success = False try: # If we're a worker we need to hit out to the master. - if self.config.worker_app: - result = await self.send_event_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self.send_event( + instance_name=self.config.worker.writers.events, event_id=event.event_id, store=self.store, requester=requester, @@ -902,9 +904,9 @@ class EventCreationHandler(object): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. - This should only be run on master. + This should only be run on the instance in charge of persisting events. """ - assert not self.config.worker_app + assert self.config.worker.writers.events == self._instance_name if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9ea11c0754..3594f3b00f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC): ) -> None: """Set the presence state of the user. """ + @abc.abstractmethod + async def bump_presence_active_time(self, user: UserID): + """We've seen the user do something that indicates they're interacting + with the app. + """ + class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "synapse.server.HomeServer"): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2698a129ca..61db3ccc43 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -89,6 +89,8 @@ class RoomCreationHandler(BaseHandler): self.room_member_handler = hs.get_room_member_handler() self.config = hs.config + self._replication = hs.get_replication_data_handler() + # linearizer to stop two upgrades happening at once self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") @@ -752,6 +754,11 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() + # Always wait for room creation to progate before returning + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", last_stream_id + ) + return result, last_stream_id async def _send_events_for_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 691b6705b2..0f7af982f0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,6 +26,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.replication.http.membership import ( + ReplicationLocallyRejectInviteRestServlet, +) from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -44,11 +47,6 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -71,6 +69,17 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._event_stream_writer_instance = hs.config.worker.writers.events + self._is_on_event_persistence_instance = ( + self._event_stream_writer_instance == hs.get_instance_name() + ) + if self._is_on_event_persistence_instance: + self.persist_event_storage = hs.get_storage().persistence + else: + self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client( + hs + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -121,6 +130,22 @@ class RoomMemberHandler(object): """ raise NotImplementedError() + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + if self._is_on_event_persistence_instance: + return await self.persist_event_storage.locally_reject_invite( + user_id, room_id + ) + else: + result = await self._locally_reject_client( + instance_name=self._event_stream_writer_instance, + user_id=user_id, + room_id=room_id, + ) + return result["stream_id"] + @abc.abstractmethod async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the @@ -1015,9 +1040,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.store.locally_reject_invite( - target.to_string(), room_id - ) + stream_id = await self.locally_reject_invite(target.to_string(), room_id) return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a909744e93..19b69e0e11 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -19,6 +19,7 @@ from synapse.replication.http import ( federation, login, membership, + presence, register, send_event, streams, @@ -35,10 +36,11 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) federation.register_servlets(hs, self) + presence.register_servlets(hs, self) + membership.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: - membership.register_servlets(hs, self) login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index c3136a4eb9..793cef6c26 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -142,6 +142,7 @@ class ReplicationEndpoint(object): """ clock = hs.get_clock() client = hs.get_simple_http_client() + local_instance_name = hs.get_instance_name() master_host = hs.config.worker_replication_host master_port = hs.config.worker_replication_http_port @@ -151,6 +152,8 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks def send_request(instance_name="master", **kwargs): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") if instance_name == "master": host = master_host port = master_port diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 050fd34562..a7174c4a8f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -14,12 +14,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + self.member_handler = hs.get_room_member_handler() @staticmethod def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): @@ -149,12 +154,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.store.locally_reject_invite(user_id, room_id) + stream_id = await self.member_handler.locally_reject_invite( + user_id, room_id + ) event_id = None return 200, {"event_id": event_id, "stream_id": stream_id} +class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint): + """Rejects the invite for the user and room locally. + + Request format: + + POST /_synapse/replication/locally_reject_invite/:room_id/:user_id + + {} + """ + + NAME = "locally_reject_invite" + PATH_ARGS = ("room_id", "user_id") + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.member_handler = hs.get_room_member_handler() + + @staticmethod + def _serialize_payload(room_id, user_id): + return {} + + async def _handle_request(self, request, room_id, user_id): + logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id) + + stream_id = await self.member_handler.locally_reject_invite(user_id, room_id) + + return 200, {"stream_id": stream_id} + + class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): """Notifies that a user has joined or left the room @@ -208,3 +245,4 @@ def register_servlets(hs, http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) + ReplicationLocallyRejectInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py new file mode 100644 index 0000000000..ea1b33331b --- /dev/null +++ b/synapse/replication/http/presence.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): + """We've seen the user do something that indicates they're interacting + with the app. + + The POST looks like: + + POST /_synapse/replication/bump_presence_active_time/ + + 200 OK + + {} + """ + + NAME = "bump_presence_active_time" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + await self._presence_handler.bump_presence_active_time( + UserID.from_string(user_id) + ) + + return ( + 200, + {}, + ) + + +class ReplicationPresenceSetState(ReplicationEndpoint): + """Set the presence state for a user. + + The POST looks like: + + POST /_synapse/replication/presence_set_state/ + + { + "state": { ... }, + "ignore_status_msg": false, + } + + 200 OK + + {} + """ + + NAME = "presence_set_state" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id, state, ignore_status_msg=False): + return { + "state": state, + "ignore_status_msg": ignore_status_msg, + } + + async def _handle_request(self, request, user_id): + content = parse_json_object_from_request(request) + + await self._presence_handler.set_state( + UserID.from_string(user_id), content["state"], content["ignore_status_msg"] + ) + + return ( + 200, + {}, + ) + + +def register_servlets(hs, http_server): + ReplicationBumpPresenceActiveTime(hs).register(http_server) + ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index acfa66a7a8..03300e5336 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -38,7 +38,9 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + BackfillStream, CachesStream, + EventsStream, FederationStream, Stream, ) @@ -87,6 +89,14 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) continue + if isinstance(stream, (EventsStream, BackfillStream)): + # Only add EventStream and BackfillStream as a source on the + # instance in charge of event persistence. + if hs.config.worker.writers.events == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 0a13e1ed34..8173baef8f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -100,7 +100,9 @@ class ShutdownRoomRestServlet(RestServlet): # we try and auto join below. # # TODO: Currently the events stream is written to from master - await self._replication.wait_for_stream_position("master", "events", stream_id) + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) users = await self.state.get_current_users_in_room(room_id) kicked_users = [] @@ -113,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet): try: target_requester = create_requester(user_id) - await self.room_member_handler.update_membership( + _, stream_id = await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -123,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet): require_consent=False, ) + # Wait for leave to come in over replication before trying to forget. + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + await self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.update_membership( diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 791961b296..599ee470d4 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -66,9 +66,9 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) - # If we're on a process that can persist events (currently - # master), also instantiate a `PersistEventsStore` - if hs.config.worker.worker_app is None: + # If we're on a process that can persist events also + # instantiate a `PersistEventsStore` + if hs.config.worker.writers.events == hs.get_instance_name(): self.persist_events = PersistEventsStore( hs, database, self.main ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 0e8378714a..417ac8dc7c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -689,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + + def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + self.db.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) + ) + + return self.db.runInteraction( + "mark_remote_user_device_list_as_unsubscribed", + _mark_remote_user_device_list_as_unsubscribed_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): @@ -969,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. - """ - yield self.db.simple_delete( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - desc="mark_remote_user_device_list_as_unsubscribed", - ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry( self, user_id, device_id, content, stream_id ): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index a97f8b3934..a6572571b4 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -138,10 +138,10 @@ class PersistEventsStore: self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator - # This should only exist on master for now + # This should only exist on instances that are configured to write assert ( - hs.config.worker.worker_app is None - ), "Can only instantiate PersistEventsStore on master" + hs.config.worker.writers.events == hs.get_instance_name() + ), "Can only instantiate EventsStore on master" @_retry_on_integrity_error @defer.inlineCallbacks @@ -1590,3 +1590,31 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + # We also clear this entry from `local_current_membership`. + # Ideally we'd point to a leave event, but we don't have one, so + # nevermind. + self.db.simple_delete_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + ) + + with self._stream_id_gen.get_next() as stream_ordering: + await self.db.runInteraction("locally_reject_invite", f, stream_ordering) + + return stream_ordering diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index b880a71782..213d69100a 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -76,7 +76,7 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker_app is None: + if hs.config.worker.writers.events == hs.get_instance_name(): # We are the process in charge of generating stream ids for events, # so instantiate ID generators based on the database self._stream_id_gen = StreamIdGenerator( diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 7c5ca81ae0..137ebac833 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1046,31 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - # We also clear this entry from `local_current_membership`. - # Ideally we'd point to a leave event, but we don't have one, so - # nevermind. - self.db.simple_delete_txn( - txn, - table="local_current_membership", - keyvalues={"room_id": room_id, "user_id": user_id}, - ) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) - - return stream_ordering - def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 12e1ffb9a2..f159400a87 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -786,3 +786,9 @@ class EventsPersistenceStorage(object): for user_id in left_users: await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + return await self.persist_events_store.locally_reject_invite(user_id, room_id) -- cgit 1.5.1