From 4b965c862dc66c0da5d3240add70e9b5f0aa720b Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 14 Apr 2021 16:34:27 +0200 Subject: Remove redundant "coding: utf-8" lines (#9786) Part of #9744 Removes all redundant `# -*- coding: utf-8 -*-` lines from files, as python 3 automatically reads source code as utf-8 now. `Signed-off-by: Jonathan de Jong ` --- tests/rest/client/__init__.py | 1 - tests/rest/client/test_consent.py | 1 - tests/rest/client/test_ephemeral_message.py | 1 - tests/rest/client/test_identity.py | 1 - tests/rest/client/test_power_levels.py | 1 - tests/rest/client/test_redactions.py | 1 - tests/rest/client/test_retention.py | 1 - tests/rest/client/test_third_party_rules.py | 1 - tests/rest/client/v1/__init__.py | 1 - tests/rest/client/v1/test_directory.py | 1 - tests/rest/client/v1/test_events.py | 1 - tests/rest/client/v1/test_login.py | 1 - tests/rest/client/v1/test_presence.py | 1 - tests/rest/client/v1/test_profile.py | 1 - tests/rest/client/v1/test_push_rule_attrs.py | 1 - tests/rest/client/v1/test_rooms.py | 1 - tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v1/utils.py | 1 - tests/rest/client/v2_alpha/test_account.py | 1 - tests/rest/client/v2_alpha/test_auth.py | 1 - tests/rest/client/v2_alpha/test_capabilities.py | 1 - tests/rest/client/v2_alpha/test_filter.py | 1 - tests/rest/client/v2_alpha/test_password_policy.py | 1 - tests/rest/client/v2_alpha/test_register.py | 1 - tests/rest/client/v2_alpha/test_relations.py | 1 - tests/rest/client/v2_alpha/test_shared_rooms.py | 1 - tests/rest/client/v2_alpha/test_sync.py | 1 - tests/rest/client/v2_alpha/test_upgrade_room.py | 1 - 28 files changed, 28 deletions(-) (limited to 'tests/rest/client') diff --git a/tests/rest/client/__init__.py b/tests/rest/client/__init__.py index fe0ac3f8e9..629e2df74a 100644 --- a/tests/rest/client/__init__.py +++ b/tests/rest/client/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index c74693e9b2..5cc62a910a 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2018 New Vector # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 56937dcd2e..eec0fc01f9 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index c0a9fc6925..478296ba0e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 5256c11fe6..ba5ad47df5 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index e0c74591b6..dfd85221d0 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index f892a71228..e1a6e73e17 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index a7ebe0c3e9..e1fe72fc5d 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the 'License'); diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py index bfebb0f644..5e83dba2ed 100644 --- a/tests/rest/client/v1/__init__.py +++ b/tests/rest/client/v1/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index edd1d184f8..8ed470490b 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 87a18d2cb9..852bda408c 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index c7b79ab8a7..605b952316 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index c136827f79..3a050659ca 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index f3448c94dd..165ad33fb7 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py index 2bc512d75e..d077616082 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4df20c90fd..92babf65e0 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 0b8f565121..0aad48a162 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector # diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index a6a292b20c..ed55a640af 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index e72b61963d..4ef19145d1 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2015-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index ed433d9333..485e3650c3 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2018 New Vector # Copyright 2020-2021 The Matrix.org Foundation C.I.C # diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index 287a1a485c..874052c61c 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index f761c44936..c7e47725b7 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index 5ebc5707a5..6f07ff6cbb 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index cd60ea7081..054d4e4140 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 21ee436b91..856aa8682f 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index dd83a1f8ff..cedb9614a8 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2020 Half-Shot # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 2dbf42397a..dbcbdf159a 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py index d890d11863..5f3f15fc57 100644 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); -- cgit 1.5.1 From 71f0623de968f07292d5a092e9197f7513ab6cde Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 19 Apr 2021 19:16:34 +0100 Subject: Port "Allow users to click account renewal links multiple times without hitting an 'Invalid Token' page #74" from synapse-dinsic (#9832) This attempts to be a direct port of https://github.com/matrix-org/synapse-dinsic/pull/74 to mainline. There was some fiddling required to deal with the changes that have been made to mainline since (mainly dealing with the split of `RegistrationWorkerStore` from `RegistrationStore`, and the changes made to `self.make_request` in test code). --- UPGRADE.rst | 23 +++ changelog.d/9832.feature | 1 + docs/sample_config.yaml | 148 ++++++++++-------- synapse/api/auth.py | 6 +- synapse/config/_base.pyi | 2 + synapse/config/account_validity.py | 165 +++++++++++++++++++++ synapse/config/emailconfig.py | 2 +- synapse/config/homeserver.py | 3 +- synapse/config/registration.py | 129 ---------------- synapse/handlers/account_validity.py | 101 ++++++++++--- synapse/handlers/deactivate_account.py | 4 +- synapse/push/pusherpool.py | 8 +- .../res/templates/account_previously_renewed.html | 1 + synapse/res/templates/account_renewed.html | 2 +- synapse/rest/client/v2_alpha/account_validity.py | 32 +++- synapse/storage/databases/main/registration.py | 62 ++++++-- .../59/12account_validity_token_used_ts_ms.sql | 18 +++ tests/rest/client/v2_alpha/test_register.py | 52 +++++-- 18 files changed, 496 insertions(+), 263 deletions(-) create mode 100644 changelog.d/9832.feature create mode 100644 synapse/config/account_validity.py create mode 100644 synapse/res/templates/account_previously_renewed.html create mode 100644 synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql (limited to 'tests/rest/client') diff --git a/UPGRADE.rst b/UPGRADE.rst index 665821d4ef..eff976017d 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -85,6 +85,29 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.33.0 +==================== + +Account Validity HTML templates can now display a user's expiration date +------------------------------------------------------------------------ + +This may affect you if you have enabled the account validity feature, and have made use of a +custom HTML template specified by the ``account_validity.template_dir`` or ``account_validity.account_renewed_html_path`` +Synapse config options. + +The template can now accept an ``expiration_ts`` variable, which represents the unix timestamp in milliseconds for the +future date of which their account has been renewed until. See the +`default template `_ +for an example of usage. + +ALso note that a new HTML template, ``account_previously_renewed.html``, has been added. This is is shown to users +when they attempt to renew their account with a valid renewal token that has already been used before. The default +template contents can been found +`here `_, +and can also accept an ``expiration_ts`` variable. This template replaces the error message users would previously see +upon attempting to use a valid renewal token more than once. + + Upgrading to v1.32.0 ==================== diff --git a/changelog.d/9832.feature b/changelog.d/9832.feature new file mode 100644 index 0000000000..e76395fbe8 --- /dev/null +++ b/changelog.d/9832.feature @@ -0,0 +1 @@ +Don't return an error when a user attempts to renew their account multiple times with the same token. Instead, state when their account is set to expire. This change concerns the optional account validity feature. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9182dcd987..d260d76259 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1175,69 +1175,6 @@ url_preview_accept_language: # #enable_registration: false -# Optional account validity configuration. This allows for accounts to be denied -# any request after a given period. -# -# Once this feature is enabled, Synapse will look for registered users without an -# expiration date at startup and will add one to every account it found using the -# current settings at that time. -# This means that, if a validity period is set, and Synapse is restarted (it will -# then derive an expiration date from the current validity period), and some time -# after that the validity period changes and Synapse is restarted, the users' -# expiration dates won't be updated unless their account is manually renewed. This -# date will be randomly selected within a range [now + period - d ; now + period], -# where d is equal to 10% of the validity period. -# -account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - #template_dir: "res/templates" - - # File within 'template_dir' giving the HTML to be displayed to the user after - # they successfully renewed their account. If not set, default text is used. - # - #account_renewed_html_path: "account_renewed.html" - - # File within 'template_dir' giving the HTML to be displayed when the user - # tries to renew an account with an invalid renewal token. If not set, - # default text is used. - # - #invalid_token_html_path: "invalid_token.html" - # Time that a user's session remains valid for, after they log in. # # Note that this is not currently compatible with guest logins. @@ -1432,6 +1369,91 @@ account_threepid_delegates: #auto_join_rooms_for_guests: false +## Account Validity ## + +# Optional account validity configuration. This allows for accounts to be denied +# any request after a given period. +# +# Once this feature is enabled, Synapse will look for registered users without an +# expiration date at startup and will add one to every account it found using the +# current settings at that time. +# This means that, if a validity period is set, and Synapse is restarted (it will +# then derive an expiration date from the current validity period), and some time +# after that the validity period changes and Synapse is restarted, the users' +# expiration dates won't be updated unless their account is manually renewed. This +# date will be randomly selected within a range [now + period - d ; now + period], +# where d is equal to 10% of the validity period. +# +account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + # The currently available templates are: + # + # * account_renewed.html: Displayed to the user after they have successfully + # renewed their account. + # + # * account_previously_renewed.html: Displayed to the user if they attempt to + # renew their account with a token that is valid, but that has already + # been used. In this case the account is not renewed again. + # + # * invalid_token.html: Displayed to the user when they try to renew an account + # with an unknown or invalid renewal token. + # + # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for + # default template contents. + # + # The file name of some of these templates can be configured below for legacy + # reasons. + # + #template_dir: "res/templates" + + # A custom file name for the 'account_renewed.html' template. + # + # If not set, the file is assumed to be named "account_renewed.html". + # + #account_renewed_html_path: "account_renewed.html" + + # A custom file name for the 'invalid_token.html' template. + # + # If not set, the file is assumed to be named "invalid_token.html". + # + #invalid_token_html_path: "invalid_token.html" + + ## Metrics ### # Enable collection and rendering of performance metrics diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 6c13f53957..872fd100cd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -79,7 +79,9 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) - self._account_validity = hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key @@ -222,7 +224,7 @@ class Auth: shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. - if self._account_validity.enabled and not allow_expired: + if self._account_validity_enabled and not allow_expired: if await self.store.is_account_expired( user_info.user_id, self.clock.time_msec() ): diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index e896fd34e2..ddec356a07 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,6 +1,7 @@ from typing import Any, Iterable, List, Optional from synapse.config import ( + account_validity, api, appservice, auth, @@ -59,6 +60,7 @@ class RootConfig: captcha: captcha.CaptchaConfig voip: voip.VoipConfig registration: registration.RegistrationConfig + account_validity: account_validity.AccountValidityConfig metrics: metrics.MetricsConfig api: api.ApiConfig appservice: appservice.AppServiceConfig diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py new file mode 100644 index 0000000000..c58a7d95a7 --- /dev/null +++ b/synapse/config/account_validity.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config._base import Config, ConfigError + + +class AccountValidityConfig(Config): + section = "account_validity" + + def read_config(self, config, **kwargs): + account_validity_config = config.get("account_validity") or {} + self.account_validity_enabled = account_validity_config.get("enabled", False) + self.account_validity_renew_by_email_enabled = ( + "renew_at" in account_validity_config + ) + + if self.account_validity_enabled: + if "period" in account_validity_config: + self.account_validity_period = self.parse_duration( + account_validity_config["period"] + ) + else: + raise ConfigError("'period' is required when using account validity") + + if "renew_at" in account_validity_config: + self.account_validity_renew_at = self.parse_duration( + account_validity_config["renew_at"] + ) + + if "renew_email_subject" in account_validity_config: + self.account_validity_renew_email_subject = account_validity_config[ + "renew_email_subject" + ] + else: + self.account_validity_renew_email_subject = "Renew your %(app)s account" + + self.account_validity_startup_job_max_delta = ( + self.account_validity_period * 10.0 / 100.0 + ) + + if self.account_validity_renew_by_email_enabled: + if not self.public_baseurl: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + + # Load account validity templates. + account_validity_template_dir = account_validity_config.get("template_dir") + + account_renewed_template_filename = account_validity_config.get( + "account_renewed_html_path", "account_renewed.html" + ) + invalid_token_template_filename = account_validity_config.get( + "invalid_token_html_path", "invalid_token.html" + ) + + # Read and store template content + ( + self.account_validity_account_renewed_template, + self.account_validity_account_previously_renewed_template, + self.account_validity_invalid_token_template, + ) = self.read_templates( + [ + account_renewed_template_filename, + "account_previously_renewed.html", + invalid_token_template_filename, + ], + account_validity_template_dir, + ) + + def generate_config_section(self, **kwargs): + return """\ + ## Account Validity ## + + # Optional account validity configuration. This allows for accounts to be denied + # any request after a given period. + # + # Once this feature is enabled, Synapse will look for registered users without an + # expiration date at startup and will add one to every account it found using the + # current settings at that time. + # This means that, if a validity period is set, and Synapse is restarted (it will + # then derive an expiration date from the current validity period), and some time + # after that the validity period changes and Synapse is restarted, the users' + # expiration dates won't be updated unless their account is manually renewed. This + # date will be randomly selected within a range [now + period - d ; now + period], + # where d is equal to 10% of the validity period. + # + account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + # The currently available templates are: + # + # * account_renewed.html: Displayed to the user after they have successfully + # renewed their account. + # + # * account_previously_renewed.html: Displayed to the user if they attempt to + # renew their account with a token that is valid, but that has already + # been used. In this case the account is not renewed again. + # + # * invalid_token.html: Displayed to the user when they try to renew an account + # with an unknown or invalid renewal token. + # + # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for + # default template contents. + # + # The file name of some of these templates can be configured below for legacy + # reasons. + # + #template_dir: "res/templates" + + # A custom file name for the 'account_renewed.html' template. + # + # If not set, the file is assumed to be named "account_renewed.html". + # + #account_renewed_html_path: "account_renewed.html" + + # A custom file name for the 'invalid_token.html' template. + # + # If not set, the file is assumed to be named "invalid_token.html". + # + #invalid_token_html_path: "invalid_token.html" + """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index c587939c7a..5564d7d097 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -299,7 +299,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity.renew_by_email_enabled: + if self.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 1309535068..58e3bcd511 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ._base import RootConfig +from .account_validity import AccountValidityConfig from .api import ApiConfig from .appservice import AppServiceConfig from .auth import AuthConfig @@ -68,6 +68,7 @@ class HomeServerConfig(RootConfig): CaptchaConfig, VoipConfig, RegistrationConfig, + AccountValidityConfig, MetricsConfig, ApiConfig, AppServiceConfig, diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f8a2768af8..e6f52b4f40 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -12,74 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -import pkg_resources - from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool -class AccountValidityConfig(Config): - section = "accountvalidity" - - def __init__(self, config, synapse_config): - if config is None: - return - super().__init__() - self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = "renew_at" in config - - if self.enabled: - if "period" in config: - self.period = self.parse_duration(config["period"]) - else: - raise ConfigError("'period' is required when using account validity") - - if "renew_at" in config: - self.renew_at = self.parse_duration(config["renew_at"]) - - if "renew_email_subject" in config: - self.renew_email_subject = config["renew_email_subject"] - else: - self.renew_email_subject = "Renew your %(app)s account" - - self.startup_job_max_delta = self.period * 10.0 / 100.0 - - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - - template_dir = config.get("template_dir") - - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - if "account_renewed_html_path" in config: - file_path = os.path.join(template_dir, config["account_renewed_html_path"]) - - self.account_renewed_html_content = self.read_file( - file_path, "account_validity.account_renewed_html_path" - ) - else: - self.account_renewed_html_content = ( - "Your account has been successfully renewed." - ) - - if "invalid_token_html_path" in config: - file_path = os.path.join(template_dir, config["invalid_token_html_path"]) - - self.invalid_token_html_content = self.read_file( - file_path, "account_validity.invalid_token_html_path" - ) - else: - self.invalid_token_html_content = ( - "Invalid renewal token." - ) - - class RegistrationConfig(Config): section = "registration" @@ -92,10 +30,6 @@ class RegistrationConfig(Config): str(config["disable_registration"]) ) - self.account_validity = AccountValidityConfig( - config.get("account_validity") or {}, config - ) - self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) @@ -207,69 +141,6 @@ class RegistrationConfig(Config): # #enable_registration: false - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10%% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %%(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - #template_dir: "res/templates" - - # File within 'template_dir' giving the HTML to be displayed to the user after - # they successfully renewed their account. If not set, default text is used. - # - #account_renewed_html_path: "account_renewed.html" - - # File within 'template_dir' giving the HTML to be displayed when the user - # tries to renew an account with an invalid renewal token. If not set, - # default text is used. - # - #invalid_token_html_path: "invalid_token.html" - # Time that a user's session remains valid for, after they log in. # # Note that this is not currently compatible with guest logins. diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 66ce7e8b83..5b927f10b3 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -17,7 +17,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable @@ -39,28 +39,44 @@ class AccountValidityHandler: self.sendmail = self.hs.get_sendmail() self.clock = self.hs.get_clock() - self._account_validity = self.hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) + self._account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) + + self._account_validity_period = None + if self._account_validity_enabled: + self._account_validity_period = ( + hs.config.account_validity.account_validity_period + ) if ( - self._account_validity.enabled - and self._account_validity.renew_by_email_enabled + self._account_validity_enabled + and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = self.config.account_validity_template_html - self._template_text = self.config.account_validity_template_text + self._template_html = ( + hs.config.account_validity.account_validity_template_html + ) + self._template_text = ( + hs.config.account_validity.account_validity_template_text + ) + account_validity_renew_email_subject = ( + hs.config.account_validity.account_validity_renew_email_subject + ) try: - app_name = self.hs.config.email_app_name + app_name = hs.config.email_app_name - self._subject = self._account_validity.renew_email_subject % { - "app": app_name - } + self._subject = account_validity_renew_email_subject % {"app": app_name} - self._from_string = self.hs.config.email_notif_from % {"app": app_name} + self._from_string = hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. - self._subject = self._account_validity.renew_email_subject - self._from_string = self.hs.config.email_notif_from + self._subject = account_validity_renew_email_subject + self._from_string = hs.config.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] @@ -220,50 +236,87 @@ class AccountValidityHandler: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - async def renew_account(self, renewal_token: str) -> bool: + async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. + If it turns out that the token is valid but has already been used, then the + token is considered stale. A token is stale if the 'token_used_ts_ms' db column + is non-null. + Args: renewal_token: Token sent with the renewal request. Returns: - Whether the provided token is valid. + A tuple containing: + * A bool representing whether the token is valid and unused. + * A bool which is `True` if the token is valid, but stale. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. """ try: - user_id = await self.store.get_user_from_renewal_token(renewal_token) + ( + user_id, + current_expiration_ts, + token_used_ts, + ) = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - return False + return False, False, 0 + + # Check whether this token has already been used. + if token_used_ts: + logger.info( + "User '%s' attempted to use previously used token '%s' to renew account", + user_id, + renewal_token, + ) + return False, True, current_expiration_ts logger.debug("Renewing an account for user %s", user_id) - await self.renew_account_for_user(user_id) - return True + # Renew the account. Pass the renewal_token here so that it is not cleared. + # We want to keep the token around in case the user attempts to renew their + # account with the same token twice (clicking the email link twice). + # + # In that case, the token will be accepted, but the account's expiration ts + # will remain unchanged. + new_expiration_ts = await self.renew_account_for_user( + user_id, renewal_token=renewal_token + ) + + return True, False, new_expiration_ts async def renew_account_for_user( self, user_id: str, expiration_ts: Optional[int] = None, email_sent: bool = False, + renewal_token: Optional[str] = None, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token: Token sent with the renewal request. + user_id: The ID of the user to renew. expiration_ts: New expiration date. Defaults to now + validity period. - email_sen: Whether an email has been sent for this validity period. - Defaults to False. + email_sent: Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. The user's token + will be cleared if this is None. Returns: New expiration date for this account, as a timestamp in milliseconds since epoch. """ + now = self.clock.time_msec() if expiration_ts is None: - expiration_ts = self.clock.time_msec() + self._account_validity.period + expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( - user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent + user_id=user_id, + expiration_ts=expiration_ts, + email_sent=email_sent, + renewal_token=renewal_token, + token_used_ts=now, ) return expiration_ts diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 3f6f9f7f3d..45d2404dde 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -49,7 +49,9 @@ class DeactivateAccountHandler(BaseHandler): if hs.config.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) - self._account_validity_enabled = hs.config.account_validity.enabled + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) async def deactivate_account( self, diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 564a5ed0df..579fcdf472 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,7 +62,9 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity = hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config @@ -236,7 +238,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) @@ -266,7 +268,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html new file mode 100644 index 0000000000..b751359bdf --- /dev/null +++ b/synapse/res/templates/account_previously_renewed.html @@ -0,0 +1 @@ +Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html index 894da030af..e8c0f52f05 100644 --- a/synapse/res/templates/account_renewed.html +++ b/synapse/res/templates/account_renewed.html @@ -1 +1 @@ -Your account has been successfully renewed. +Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}. diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 0ad07fb895..2d1ad3d3fb 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -36,24 +36,40 @@ class AccountValidityRenewServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.success_html = hs.config.account_validity.account_renewed_html_content - self.failure_html = hs.config.account_validity.invalid_token_html_content + self.account_renewed_template = ( + hs.config.account_validity.account_validity_account_renewed_template + ) + self.account_previously_renewed_template = ( + hs.config.account_validity.account_validity_account_previously_renewed_template + ) + self.invalid_token_template = ( + hs.config.account_validity.account_validity_invalid_token_template + ) async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = await self.account_activity_handler.renew_account( + ( + token_valid, + token_stale, + expiration_ts, + ) = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) if token_valid: status_code = 200 - response = self.success_html + response = self.account_renewed_template.render(expiration_ts=expiration_ts) + elif token_stale: + status_code = 200 + response = self.account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) else: status_code = 404 - response = self.failure_html + response = self.invalid_token_template.render(expiration_ts=expiration_ts) respond_with_html(request, status_code, response) @@ -71,10 +87,12 @@ class AccountValiditySendMailServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.account_validity = self.hs.config.account_validity + self.account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) async def on_POST(self, request): - if not self.account_validity.renew_by_email_enabled: + if not self.account_validity_renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 833214b7e0..6e5ee557d2 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -91,13 +91,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): id_column=None, ) - self._account_validity = hs.config.account_validity - if hs.config.run_background_tasks and self._account_validity.enabled: - self._clock.call_later( - 0.0, - self._set_expiration_date_when_missing, + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) + self._account_validity_period = None + self._account_validity_startup_job_max_delta = None + if self._account_validity_enabled: + self._account_validity_period = ( + hs.config.account_validity.account_validity_period + ) + self._account_validity_startup_job_max_delta = ( + hs.config.account_validity.account_validity_startup_job_max_delta ) + if hs.config.run_background_tasks: + self._clock.call_later( + 0.0, + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: self._clock.looping_call( @@ -194,6 +206,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts: int, email_sent: bool, renewal_token: Optional[str] = None, + token_used_ts: Optional[int] = None, ) -> None: """Updates the account validity properties of the given account, with the given values. @@ -207,6 +220,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): period. renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. + token_used_ts: A timestamp of when the current token was used to renew + the account. """ def set_account_validity_for_user_txn(txn): @@ -218,6 +233,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "expiration_ts_ms": expiration_ts, "email_sent": email_sent, "renewal_token": renewal_token, + "token_used_ts_ms": token_used_ts, }, ) self._invalidate_cache_and_stream( @@ -231,7 +247,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: - """Defines a renewal token for a given user. + """Defines a renewal token for a given user, and clears the token_used timestamp. Args: user_id: ID of the user to set the renewal token for. @@ -244,26 +260,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, + updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None}, desc="set_renewal_token_for_user", ) - async def get_user_from_renewal_token(self, renewal_token: str) -> str: - """Get a user ID from a renewal token. + async def get_user_from_renewal_token( + self, renewal_token: str + ) -> Tuple[str, int, Optional[int]]: + """Get a user ID and renewal status from a renewal token. Args: renewal_token: The renewal token to perform the lookup with. Returns: - The ID of the user to which the token belongs. + A tuple of containing the following values: + * The ID of a user to which the token belongs. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. + * An optional int representing the timestamp of when the user renewed their + account timestamp as milliseconds since the epoch. None if the account + has not been renewed using the current token yet. """ - return await self.db_pool.simple_select_one_onecol( + ret_dict = await self.db_pool.simple_select_one( table="account_validity", keyvalues={"renewal_token": renewal_token}, - retcol="user_id", + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], desc="get_user_from_renewal_token", ) + return ( + ret_dict["user_id"], + ret_dict["expiration_ts_ms"], + ret_dict["token_used_ts_ms"], + ) + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. @@ -302,7 +332,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity.renew_at, + self.config.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: @@ -964,11 +994,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period + expiration_ts = now_ms + self._account_validity_period if use_delta: expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) @@ -1412,7 +1442,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): except self.database_engine.module.IntegrityError: raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - if self._account_validity.enabled: + if self._account_validity_enabled: self.set_expiration_date_for_user_txn(txn, user_id) if create_profile_with_displayname: diff --git a/synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql b/synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql new file mode 100644 index 0000000000..4836dac16e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/12account_validity_token_used_ts_ms.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track when users renew their account using the value of the 'renewal_token' column. +-- This field should be set to NULL after a fresh token is generated. +ALTER TABLE account_validity ADD token_used_ts_ms BIGINT; diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 054d4e4140..98695b05d5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -492,8 +492,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - # Move 6 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=6).total_seconds()) + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) self.assertEqual(len(self.email_attempts), 1) # Retrieving the URL from the email is too much pain for now, so we @@ -504,14 +504,32 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect on a successful renewal. - expected_html = self.hs.config.account_validity.account_renewed_html_content + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -531,15 +549,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect when using an # invalid/unknown token. - expected_html = self.hs.config.account_validity.invalid_token_html_content + expected_html = ( + self.hs.config.account_validity.account_validity_invalid_token_template.render() + ) self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -647,7 +664,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) - self.hs.config.account_validity.period = self.validity_period + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta self.store = self.hs.get_datastore() -- cgit 1.5.1 From 495b214f4f8f45d16ffee851c8ab7a380dd0e2b2 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 20 Apr 2021 12:50:49 +0200 Subject: Fix (final) Bugbear violations (#9838) --- changelog.d/9838.misc | 1 + scripts-dev/definitions.py | 2 +- scripts-dev/list_url_patterns.py | 2 +- setup.cfg | 3 +-- synapse/event_auth.py | 2 +- synapse/federation/send_queue.py | 4 ++-- synapse/handlers/auth.py | 2 +- synapse/handlers/device.py | 13 +++++-------- synapse/handlers/federation.py | 2 +- synapse/logging/_remote.py | 4 ++-- synapse/rest/key/v2/remote_key_resource.py | 4 ++-- synapse/storage/databases/main/events.py | 10 +++++----- tests/handlers/test_federation.py | 2 +- tests/replication/tcp/streams/test_events.py | 4 ++-- tests/rest/admin/test_device.py | 4 ++-- tests/rest/admin/test_event_reports.py | 8 ++++---- tests/rest/admin/test_room.py | 8 ++++---- tests/rest/admin/test_statistics.py | 2 +- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/v1/test_rooms.py | 6 +++--- tests/storage/test_event_metrics.py | 4 ++-- tests/unittest.py | 2 +- tests/utils.py | 2 +- 23 files changed, 46 insertions(+), 49 deletions(-) create mode 100644 changelog.d/9838.misc (limited to 'tests/rest/client') diff --git a/changelog.d/9838.misc b/changelog.d/9838.misc new file mode 100644 index 0000000000..b98ce56309 --- /dev/null +++ b/changelog.d/9838.misc @@ -0,0 +1 @@ +Introduce flake8-bugbear to the test suite and fix some of its lint violations. \ No newline at end of file diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 313860df13..c82ddd9677 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -140,7 +140,7 @@ if __name__ == "__main__": definitions = {} for directory in args.directories: - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for filename in files: if filename.endswith(".py"): filepath = os.path.join(root, filename) diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index 26ad7c67f4..e85420dea8 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -48,7 +48,7 @@ args = parser.parse_args() for directory in args.directories: - for root, dirs, files in os.walk(directory): + for root, _, files in os.walk(directory): for filename in files: if filename.endswith(".py"): filepath = os.path.join(root, filename) diff --git a/setup.cfg b/setup.cfg index 33601b71d5..e5ceb7ed19 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,8 +18,7 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -# B007: Subsection of the bugbear suite (TODO: add in remaining fixes) -ignore=W503,W504,E203,E731,E501,B007 +ignore=W503,W504,E203,E731,E501 [isort] line_length = 88 diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5234e3f81e..c831d9f73c 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -670,7 +670,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase public_key = public_key_object["public_key"] try: for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): + for key_name in signature_block.keys(): if not key_name.startswith("ed25519:"): continue verify_key = decode_verify_key_bytes( diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index d71f04e43e..65d76ea974 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -501,10 +501,10 @@ def process_rows_for_federation( states=[state], destinations=destinations ) - for destination, edu_map in buff.keyed_edus.items(): + for edu_map in buff.keyed_edus.values(): for key, edu in edu_map.items(): transaction_queue.send_edu(edu, key) - for destination, edu_list in buff.edus.items(): + for edu_list in buff.edus.values(): for edu in edu_list: transaction_queue.send_edu(edu, None) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b8a37b6477..36f2450e2e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1248,7 +1248,7 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - for token, token_id, device_id in tokens_and_devices: + for token, _, device_id in tokens_and_devices: await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d75edb184b..c1d7800981 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -156,8 +156,7 @@ class DeviceWorkerHandler(BaseHandler): # The user may have left the room # TODO: Check if they actually did or if we were just invited. if room_id not in room_ids: - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_left.add(state_key) @@ -179,8 +178,7 @@ class DeviceWorkerHandler(BaseHandler): log_kv( {"event": "encountered empty previous state", "room_id": room_id} ) - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_changed.add(state_key) @@ -198,8 +196,7 @@ class DeviceWorkerHandler(BaseHandler): for state_dict in prev_state_ids.values(): member_event = state_dict.get((EventTypes.Member, user_id), None) if not member_event or member_event != current_member_id: - for key, event_id in current_state_ids.items(): - etype, state_key = key + for etype, state_key in current_state_ids.keys(): if etype != EventTypes.Member: continue possibly_changed.add(state_key) @@ -714,7 +711,7 @@ class DeviceListUpdater: # This can happen since we batch updates return - for device_id, stream_id, prev_ids, content in pending_updates: + for device_id, stream_id, prev_ids, _ in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", user_id, @@ -740,7 +737,7 @@ class DeviceListUpdater: else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) - for device_id, stream_id, prev_ids, content in pending_updates: + for device_id, stream_id, _, content in pending_updates: await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4b3730aa3b..dbdd7d2db3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2956,7 +2956,7 @@ class FederationHandler(BaseHandler): try: # for each sig on the third_party_invite block of the actual invite for server, signature_block in signed["signatures"].items(): - for key_name, encoded_signature in signature_block.items(): + for key_name in signature_block.keys(): if not key_name.startswith("ed25519:"): continue diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 4e8b0f8d10..c515690b38 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -226,11 +226,11 @@ class RemoteHandler(logging.Handler): old_buffer = self._buffer self._buffer = deque() - for i in range(buffer_split): + for _ in range(buffer_split): self._buffer.append(old_buffer.popleft()) end_buffer = [] - for i in range(buffer_split): + for _ in range(buffer_split): end_buffer.append(old_buffer.pop()) self._buffer.extend(reversed(end_buffer)) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index c57ac22e58..f648678b09 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -144,7 +144,7 @@ class RemoteKey(DirectServeJsonResource): # Note that the value is unused. cache_misses = {} # type: Dict[str, Dict[str, int]] - for (server_name, key_id, from_server), results in cached.items(): + for (server_name, key_id, _), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: @@ -206,7 +206,7 @@ class RemoteKey(DirectServeJsonResource): # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: - for ts_added, result in results: + for _, result in results: # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a362521e20..fd25c8112d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -170,7 +170,7 @@ class PersistEventsStore: ) async with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(events_and_contexts, stream_orderings): + for (event, _), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream await self.db_pool.runInteraction( @@ -297,7 +297,7 @@ class PersistEventsStore: txn.execute(sql + clause, args) to_recursively_check = [] - for event_id, prev_event_id, metadata, rejected in txn: + for _, prev_event_id, metadata, rejected in txn: if prev_event_id in existing_prevs: continue @@ -1127,7 +1127,7 @@ class PersistEventsStore: def _update_forward_extremities_txn( self, txn, new_forward_extremities, max_stream_order ): - for room_id, new_extrem in new_forward_extremities.items(): + for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) @@ -1399,7 +1399,7 @@ class PersistEventsStore: ] state_values = [] - for event, context in state_events_and_contexts: + for event, _ in state_events_and_contexts: vals = { "event_id": event.event_id, "room_id": event.room_id, @@ -1468,7 +1468,7 @@ class PersistEventsStore: # nothing to do here return - for event, context in events_and_contexts: + for event, _ in events_and_contexts: if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c7b0975a19..8796af45ed 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -222,7 +222,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - for i in range(3): + for _ in range(3): event = create_invite() self.get_success( self.handler.on_invite_request( diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 323237c1bb..f51fa0a79e 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -239,7 +239,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # the state rows are unsorted state_rows = [] # type: List[EventsStreamCurrentStateRow] - for stream_name, token, row in received_rows: + for stream_name, _, row in received_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) self.assertEqual(row.type, "state") @@ -356,7 +356,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # the state rows are unsorted state_rows = [] # type: List[EventsStreamCurrentStateRow] - for j in range(STATES_PER_USER + 1): + for _ in range(STATES_PER_USER + 1): stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index ecbee30bb5..120730b764 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -430,7 +430,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ # Create devices number_devices = 5 - for n in range(number_devices): + for _ in range(number_devices): self.login("user", "pass") # Get devices @@ -547,7 +547,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Create devices number_devices = 5 - for n in range(number_devices): + for _ in range(number_devices): self.login("user", "pass") # Get devices diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 8c66da3af4..29341bc6e9 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -48,22 +48,22 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok) # Two rooms and two users. Every user sends and reports every room event - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id1, user_tok=self.other_user_tok, ) - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id2, user_tok=self.other_user_tok, ) - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id1, user_tok=self.admin_user_tok, ) - for i in range(5): + for _ in range(5): self._create_event_and_report( room_id=self.room_id2, user_tok=self.admin_user_tok, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 6bcd997085..6b84188120 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -615,7 +615,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create 3 test rooms total_rooms = 3 room_ids = [] - for x in range(total_rooms): + for _ in range(total_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) @@ -679,7 +679,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create 5 test rooms total_rooms = 5 room_ids = [] - for x in range(total_rooms): + for _ in range(total_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) @@ -1577,7 +1577,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel.json_body["event"]["event_id"], events[midway]["event_id"] ) - for i, found_event in enumerate(channel.json_body["events_before"]): + for found_event in channel.json_body["events_before"]: for j, posted_event in enumerate(events): if found_event["event_id"] == posted_event["event_id"]: self.assertTrue(j < midway) @@ -1585,7 +1585,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): else: self.fail("Event %s from events_before not found" % j) - for i, found_event in enumerate(channel.json_body["events_after"]): + for found_event in channel.json_body["events_after"]: for j, posted_event in enumerate(events): if found_event["event_id"] == posted_event["event_id"]: self.assertTrue(j > midway) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 363bdeeb2d..79cac4266b 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -467,7 +467,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): number_media: Number of media to be created for the user """ upload_resource = self.media_repo.children[b"upload"] - for i in range(number_media): + for _ in range(number_media): # file size is 67 Byte image_data = unhexlify( b"89504e470d0a1a0a0000000d4948445200000001000000010806" diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 2844c493fc..b3afd51522 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1937,7 +1937,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): # Create rooms and join other_user_tok = self.login("user", "pass") number_rooms = 5 - for n in range(number_rooms): + for _ in range(number_rooms): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms @@ -2517,7 +2517,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): user_token: Access token of the user number_media: Number of media to be created for the user """ - for i in range(number_media): + for _ in range(number_media): # file size is 67 Byte image_data = unhexlify( b"89504e470d0a1a0a0000000d4948445200000001000000010806" diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 92babf65e0..a3694f3d02 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -646,7 +646,7 @@ class RoomInviteRatelimitTestCase(RoomBase): def test_invites_by_users_ratelimit(self): """Tests that invites to a specific user are actually rate-limited.""" - for i in range(3): + for _ in range(3): room_id = self.helper.create_room_as(self.user_id) self.helper.invite(room_id, self.user_id, "@other-users:red") @@ -668,7 +668,7 @@ class RoomJoinRatelimitTestCase(RoomBase): ) def test_join_local_ratelimit(self): """Tests that local joins are actually rate-limited.""" - for i in range(3): + for _ in range(3): self.helper.create_room_as(self.user_id) self.helper.create_room_as(self.user_id, expect_code=429) @@ -733,7 +733,7 @@ class RoomJoinRatelimitTestCase(RoomBase): for path in paths_to_test: # Make sure we send more requests than the rate-limiting config would allow # if all of these requests ended up joining the user to a room. - for i in range(4): + for _ in range(4): channel = self.make_request("POST", path % room_id, {}) self.assertEquals(channel.code, 200) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 397e68fe0a..088fbb247b 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -38,12 +38,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase): last_event = None # Make a real event chain - for i in range(event_count): + for _ in range(event_count): ev = self.create_and_send_event(room_id, user, False, last_event) last_event = [ev] # Sprinkle in some extremities - for i in range(extrems): + for _ in range(extrems): ev = self.create_and_send_event(room_id, user, False, last_event) # Let it run for a while, then pull out the statistics from the diff --git a/tests/unittest.py b/tests/unittest.py index d890ad981f..ee22a53849 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -133,7 +133,7 @@ class TestCase(unittest.TestCase): def assertObjectHasAttributes(self, attrs, obj): """Asserts that the given object has each of the attributes given, and that the value of each matches according to assertEquals.""" - for (key, value) in attrs.items(): + for key in attrs.keys(): if not hasattr(obj, key): raise AssertionError("Expected obj to have a '.%s'" % key) try: diff --git a/tests/utils.py b/tests/utils.py index af6b32fc66..63d52b9140 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -303,7 +303,7 @@ def setup_test_homeserver( # database for a few more seconds due to flakiness, preventing # us from dropping it when the test is over. If we can't drop # it, warn and move on. - for x in range(5): + for _ in range(5): try: cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) db_conn.commit() -- cgit 1.5.1 From 177dae270420ee4b4c8fa5e2c74c5081d98da320 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Apr 2021 17:49:11 +0100 Subject: Limit length of accepted email addresses (#9855) --- changelog.d/9855.misc | 1 + synapse/push/emailpusher.py | 9 ++++- synapse/rest/client/v2_alpha/account.py | 8 ++--- synapse/rest/client/v2_alpha/register.py | 8 +++-- synapse/util/threepids.py | 30 +++++++++++++++++ tests/rest/client/v2_alpha/test_register.py | 51 +++++++++++++++++++++++++++++ 6 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9855.misc (limited to 'tests/rest/client') diff --git a/changelog.d/9855.misc b/changelog.d/9855.misc new file mode 100644 index 0000000000..6a3d700fde --- /dev/null +++ b/changelog.d/9855.misc @@ -0,0 +1 @@ +Limit length of accepted email addresses. diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index cd89b54305..99a18874d1 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -19,8 +19,9 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher, PusherConfig, ThrottleParams +from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer +from synapse.util.threepids import validate_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -71,6 +72,12 @@ class EmailPusher(Pusher): self._is_processing = False + # Make sure that the email is valid. + try: + validate_email(self.email) + except ValueError: + raise PusherConfigException("Invalid email") + def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3aad15132d..085561d3e9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -39,7 +39,7 @@ from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import canonicalise_email, check_3pid_allowed +from synapse.util.threepids import check_3pid_allowed, validate_email from ._base import client_patterns, interactive_auth_handler @@ -92,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Stored in the database "foo@bar.com" # User requests with "FOO@bar.com" would raise a Not Found error try: - email = canonicalise_email(body["email"]) + email = validate_email(body["email"]) except ValueError as e: raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] @@ -247,7 +247,7 @@ class PasswordRestServlet(RestServlet): # We store all email addresses canonicalised in the DB. # (See add_threepid in synapse/handlers/auth.py) try: - threepid["address"] = canonicalise_email(threepid["address"]) + threepid["address"] = validate_email(threepid["address"]) except ValueError as e: raise SynapseError(400, str(e)) # if using email, we must know about the email they're authing with! @@ -375,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # Otherwise the email will be sent to "FOO@bar.com" and stored as # "foo@bar.com" in database. try: - email = canonicalise_email(body["email"]) + email = validate_email(body["email"]) except ValueError as e: raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c5a6800b8a..a30a5df1b1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,7 +49,11 @@ from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import canonicalise_email, check_3pid_allowed +from synapse.util.threepids import ( + canonicalise_email, + check_3pid_allowed, + validate_email, +) from ._base import client_patterns, interactive_auth_handler @@ -111,7 +115,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # (See on_POST in EmailThreepidRequestTokenRestServlet # in synapse/rest/client/v2_alpha/account.py) try: - email = canonicalise_email(body["email"]) + email = validate_email(body["email"]) except ValueError as e: raise SynapseError(400, str(e)) send_attempt = body["send_attempt"] diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 281c5be4fb..a1cf1960b0 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -18,6 +18,16 @@ import re logger = logging.getLogger(__name__) +# it's unclear what the maximum length of an email address is. RFC3696 (as corrected +# by errata) says: +# the upper limit on address lengths should normally be considered to be 254. +# +# In practice, mail servers appear to be more tolerant and allow 400 characters +# or so. Let's allow 500, which should be plenty for everyone. +# +MAX_EMAIL_ADDRESS_LENGTH = 500 + + def check_3pid_allowed(hs, medium, address): """Checks whether a given format of 3PID is allowed to be used on this HS @@ -70,3 +80,23 @@ def canonicalise_email(address: str) -> str: raise ValueError("Unable to parse email address") return parts[0].casefold() + "@" + parts[1].lower() + + +def validate_email(address: str) -> str: + """Does some basic validation on an email address. + + Returns the canonicalised email, as returned by `canonicalise_email`. + + Raises a ValueError if the email is invalid. + """ + # First we try canonicalising in case that fails + address = canonicalise_email(address) + + # Email addresses have to be at least 3 characters. + if len(address) < 3: + raise ValueError("Unable to parse email address") + + if len(address) > MAX_EMAIL_ADDRESS_LENGTH: + raise ValueError("Unable to parse email address") + + return address diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 98695b05d5..1cad5f00eb 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -310,6 +310,57 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(channel.json_body.get("sid")) + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_reject_invalid_email(self): + """Check that bad emails are rejected""" + + # Test for email with multiple @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + # Check error to ensure that we're not erroring due to a bug in the test. + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for email with no @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for super long email + email = "a@" + "a" * 1000 + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 9d25a0ae65ce8728d0fda1eebaf0b469316f84d7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Apr 2021 12:21:55 +0100 Subject: Split presence out of master (#9820) --- changelog.d/9820.feature | 1 + scripts/synapse_port_db | 7 +- synapse/app/generic_worker.py | 31 +------- synapse/config/workers.py | 27 ++++++- synapse/handlers/presence.py | 56 ++++++++----- synapse/replication/http/_base.py | 5 +- synapse/replication/slave/storage/presence.py | 50 ------------ synapse/replication/tcp/handler.py | 18 ++++- synapse/replication/tcp/streams/_base.py | 17 ++-- synapse/rest/client/v1/presence.py | 7 +- synapse/server.py | 6 +- synapse/storage/databases/main/__init__.py | 47 +---------- synapse/storage/databases/main/presence.py | 92 +++++++++++++++++++++- .../schema/delta/59/12presence_stream_instance.sql | 18 +++++ .../59/12presence_stream_instance_seq.sql.postgres | 20 +++++ tests/app/test_frontend_proxy.py | 83 ------------------- tests/rest/client/v1/test_presence.py | 5 +- 17 files changed, 245 insertions(+), 245 deletions(-) create mode 100644 changelog.d/9820.feature delete mode 100644 synapse/replication/slave/storage/presence.py create mode 100644 synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres delete mode 100644 tests/app/test_frontend_proxy.py (limited to 'tests/rest/client') diff --git a/changelog.d/9820.feature b/changelog.d/9820.feature new file mode 100644 index 0000000000..f56b0bb3bd --- /dev/null +++ b/changelog.d/9820.feature @@ -0,0 +1 @@ +Add experimental support for handling presence on a worker. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index b7c1ffc956..f0c93d5226 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -634,8 +634,11 @@ class Porter(object): "device_inbox_sequence", ("device_inbox", "device_federation_outbox") ) await self._setup_sequence( - "account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data")) - await self._setup_sequence("receipts_sequence", ("receipts_linearized", )) + "account_data_sequence", + ("room_account_data", "room_tags_revisions", "account_data"), + ) + await self._setup_sequence("receipts_sequence", ("receipts_linearized",)) + await self._setup_sequence("presence_stream_sequence", ("presence_stream",)) await self._setup_auth_chain_sequence() # Step 3. Get tables. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 26c458dbb6..7b2ac3ca64 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -55,7 +55,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.filtering import SlavedFilteringStore from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.profile import SlavedProfileStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.pushers import SlavedPusherStore @@ -64,7 +63,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client.v1 import events, login, presence, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.profile import ( ProfileAvatarURLRestServlet, @@ -110,6 +109,7 @@ from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) +from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.search import SearchWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore @@ -121,26 +121,6 @@ from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") -class PresenceStatusStubServlet(RestServlet): - """If presence is disabled this servlet can be used to stub out setting - presence status. - """ - - PATTERNS = client_patterns("/presence/(?P[^/]*)/status") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - await self.auth.get_user_by_req(request) - return 200, {"presence": "offline"} - - async def on_PUT(self, request, user_id): - await self.auth.get_user_by_req(request) - return 200, {} - - class KeyUploadServlet(RestServlet): """An implementation of the `KeyUploadServlet` that responds to read only requests, but otherwise proxies through to the master instance. @@ -241,6 +221,7 @@ class GenericWorkerSlavedStore( StatsStore, UIAuthWorkerStore, EndToEndRoomKeyStore, + PresenceStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, @@ -259,7 +240,6 @@ class GenericWorkerSlavedStore( SlavedTransactionStore, SlavedProfileStore, SlavedClientIpStore, - SlavedPresenceStore, SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, @@ -327,10 +307,7 @@ class GenericWorkerServer(HomeServer): user_directory.register_servlets(self, resource) - # If presence is disabled, use the stub servlet that does - # not allow sending presence - if not self.config.use_presence: - PresenceStatusStubServlet(self).register(resource) + presence.register_servlets(self, resource) groups.register_servlets(self, resource) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index b2540163d1..462630201d 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -64,6 +64,14 @@ class WriterLocations: Attributes: events: The instances that write to the event and backfill streams. typing: The instance that writes to the typing stream. + to_device: The instances that write to the to_device stream. Currently + can only be a single instance. + account_data: The instances that write to the account data streams. Currently + can only be a single instance. + receipts: The instances that write to the receipts stream. Currently + can only be a single instance. + presence: The instances that write to the presence stream. Currently + can only be a single instance. """ events = attr.ib( @@ -85,6 +93,11 @@ class WriterLocations: type=List[str], converter=_instance_to_list_converter, ) + presence = attr.ib( + default=["master"], + type=List[str], + converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -188,7 +201,14 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing", "to_device", "account_data", "receipts"): + for stream in ( + "events", + "typing", + "to_device", + "account_data", + "receipts", + "presence", + ): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -215,6 +235,11 @@ class WorkerConfig(Config): if len(self.writers.events) == 0: raise ConfigError("Must specify at least one instance to handle `events`.") + if len(self.writers.presence) != 1: + raise ConfigError( + "Must only specify one instance to handle `presence` messages." + ) + self.events_shard_config = RoutableShardedWorkerHandlingConfig( self.writers.events ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 7fd28ffa54..9938be3821 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -122,7 +122,8 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class BasePresenceHandler(abc.ABC): - """Parts of the PresenceHandler that are shared between workers and master""" + """Parts of the PresenceHandler that are shared between workers and presence + writer""" def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() @@ -309,8 +310,16 @@ class WorkerPresenceHandler(BasePresenceHandler): super().__init__(hs) self.hs = hs + self._presence_writer_instance = hs.config.worker.writers.presence[0] + self._presence_enabled = hs.config.use_presence + # Route presence EDUs to the right worker + hs.get_federation_registry().register_instances_for_edu( + "m.presence", + hs.config.worker.writers.presence, + ) + # The number of ongoing syncs on this process, by user id. # Empty if _presence_enabled is false. self._user_to_num_current_syncs = {} # type: Dict[str, int] @@ -318,8 +327,8 @@ class WorkerPresenceHandler(BasePresenceHandler): self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # user_id -> last_sync_ms. Lists the users that have stopped syncing - # but we haven't notified the master of that yet + # user_id -> last_sync_ms. Lists the users that have stopped syncing but + # we haven't notified the presence writer of that yet self.users_going_offline = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) @@ -352,22 +361,23 @@ class WorkerPresenceHandler(BasePresenceHandler): ) def mark_as_coming_online(self, user_id): - """A user has started syncing. Send a UserSync to the master, unless they - had recently stopped syncing. + """A user has started syncing. Send a UserSync to the presence writer, + unless they had recently stopped syncing. Args: user_id (str) """ going_offline = self.users_going_offline.pop(user_id, None) if not going_offline: - # Safe to skip because we haven't yet told the master they were offline + # Safe to skip because we haven't yet told the presence writer they + # were offline self.send_user_sync(user_id, True, self.clock.time_msec()) def mark_as_going_offline(self, user_id): - """A user has stopped syncing. We wait before notifying the master as - its likely they'll come back soon. This allows us to avoid sending - a stopped syncing immediately followed by a started syncing notification - to the master + """A user has stopped syncing. We wait before notifying the presence + writer as its likely they'll come back soon. This allows us to avoid + sending a stopped syncing immediately followed by a started syncing + notification to the presence writer Args: user_id (str) @@ -375,8 +385,8 @@ class WorkerPresenceHandler(BasePresenceHandler): self.users_going_offline[user_id] = self.clock.time_msec() def send_stop_syncing(self): - """Check if there are any users who have stopped syncing a while ago - and haven't come back yet. If there are poke the master about them. + """Check if there are any users who have stopped syncing a while ago and + haven't come back yet. If there are poke the presence writer about them. """ now = self.clock.time_msec() for user_id, last_sync_ms in list(self.users_going_offline.items()): @@ -492,9 +502,12 @@ class WorkerPresenceHandler(BasePresenceHandler): if not self.hs.config.use_presence: return - # Proxy request to master + # Proxy request to instance that writes presence await self._set_state_client( - user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + instance_name=self._presence_writer_instance, + user_id=user_id, + state=state, + ignore_status_msg=ignore_status_msg, ) async def bump_presence_active_time(self, user): @@ -505,9 +518,11 @@ class WorkerPresenceHandler(BasePresenceHandler): if not self.hs.config.use_presence: return - # Proxy request to master + # Proxy request to instance that writes presence user_id = user.to_string() - await self._bump_active_client(user_id=user_id) + await self._bump_active_client( + instance_name=self._presence_writer_instance, user_id=user_id + ) class PresenceHandler(BasePresenceHandler): @@ -1909,7 +1924,7 @@ class PresenceFederationQueue: self._queue_presence_updates = True # Whether this instance is a presence writer. - self._presence_writer = hs.config.worker.worker_app is None + self._presence_writer = self._instance_name in hs.config.worker.writers.presence # The FederationSender instance, if this process sends federation traffic directly. self._federation = None @@ -1957,7 +1972,7 @@ class PresenceFederationQueue: Will forward to the local federation sender (if there is one) and queue to send over replication (if there are other federation sender instances.). - Must only be called on the master process. + Must only be called on the presence writer process. """ # This should only be called on a presence writer. @@ -2003,10 +2018,11 @@ class PresenceFederationQueue: We return rows in the form of `(destination, user_id)` to keep the size of each row bounded (rather than returning the sets in a row). - On workers this will query the master process via HTTP replication. + On workers this will query the presence writer process via HTTP replication. """ if instance_name != self._instance_name: - # If not local we query over http replication from the master + # If not local we query over http replication from the presence + # writer result = await self._repl_client( instance_name=instance_name, stream_name=PresenceFederationStream.NAME, diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index ece03467b5..5685cf2121 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -158,7 +158,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): def make_client(cls, hs): """Create a client that makes requests. - Returns a callable that accepts the same parameters as `_serialize_payload`. + Returns a callable that accepts the same parameters as + `_serialize_payload`, and also accepts an optional `instance_name` + parameter to specify which instance to hit (the instance must be in + the `instance_map` config). """ clock = hs.get_clock() client = hs.get_simple_http_client() diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py deleted file mode 100644 index 57327d910d..0000000000 --- a/synapse/replication/slave/storage/presence.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.replication.tcp.streams import PresenceStream -from synapse.storage import DataStore -from synapse.storage.database import DatabasePool -from synapse.storage.databases.main.presence import PresenceStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - - -class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") - - self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore - - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() - ) - - _get_active_presence = DataStore._get_active_presence - take_presence_startup_info = DataStore.take_presence_startup_info - _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] - get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] - - def get_current_presence_token(self): - return self._presence_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) - for row in rows: - self.presence_stream_cache.entity_has_changed(row.user_id, token) - self._get_presence_for_user.invalidate((row.user_id,)) - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2ce1b9f222..7ced4c543c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -55,6 +55,8 @@ from synapse.replication.tcp.streams import ( CachesStream, EventsStream, FederationStream, + PresenceFederationStream, + PresenceStream, ReceiptsStream, Stream, TagAccountDataStream, @@ -99,6 +101,10 @@ class ReplicationCommandHandler: self._instance_id = hs.get_instance_id() self._instance_name = hs.get_instance_name() + self._is_presence_writer = ( + hs.get_instance_name() in hs.config.worker.writers.presence + ) + self._streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] @@ -153,6 +159,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, (PresenceStream, PresenceFederationStream)): + # Only add PresenceStream as a source on the instance in charge + # of presence. + if self._is_presence_writer: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -350,7 +364,7 @@ class ReplicationCommandHandler: ) -> Optional[Awaitable[None]]: user_sync_counter.inc() - if self._is_master: + if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) @@ -360,7 +374,7 @@ class ReplicationCommandHandler: def on_CLEAR_USER_SYNC( self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand ) -> Optional[Awaitable[None]]: - if self._is_master: + if self._is_presence_writer: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: return None diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9d75a89f1c..b03824925a 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -272,15 +272,22 @@ class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() - if hs.config.worker_app is None: - # on the master, query the presence handler + if hs.get_instance_name() in hs.config.worker.writers.presence: + # on the presence writer, query the presence handler presence_handler = hs.get_presence_handler() - update_function = presence_handler.get_all_presence_updates + + from synapse.handlers.presence import PresenceHandler + + assert isinstance(presence_handler, PresenceHandler) + + update_function = ( + presence_handler.get_all_presence_updates + ) # type: UpdateFunction else: - # Query master process + # Query presence writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index c232484f29..2b24fe5aa6 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -35,10 +35,15 @@ class PresenceStatusRestServlet(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() + self._use_presence = hs.config.server.use_presence + async def on_GET(self, request, user_id): requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + if not self._use_presence: + return 200, {"presence": "offline"} + if requester.user != user: allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user @@ -80,7 +85,7 @@ class PresenceStatusRestServlet(RestServlet): except Exception: raise SynapseError(400, "Unable to parse state") - if self.hs.config.use_presence: + if self._use_presence: await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/server.py b/synapse/server.py index 67598fffe3..8c147be2b3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -418,10 +418,10 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_presence_handler(self) -> BasePresenceHandler: - if self.config.worker_app: - return WorkerPresenceHandler(self) - else: + if self.get_instance_name() in self.config.worker.writers.presence: return PresenceHandler(self) + else: + return WorkerPresenceHandler(self) @cache_in_self def get_typing_writer_handler(self) -> TypingWriterHandler: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5c50f5f950..49c7606d51 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -17,7 +17,6 @@ import logging from typing import List, Optional, Tuple -from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import DatabasePool from synapse.storage.databases.main.stats import UserSortOrder @@ -51,7 +50,7 @@ from .media_repository import MediaRepositoryStore from .metrics import ServerMetricsStore from .monthly_active_users import MonthlyActiveUsersStore from .openid import OpenIdStore -from .presence import PresenceStore, UserPresenceState +from .presence import PresenceStore from .profile import ProfileStore from .purge_events import PurgeEventsStore from .push_rule import PushRuleStore @@ -126,9 +125,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" - ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) @@ -177,21 +173,6 @@ class DataStore( super().__init__(database, db_conn, hs) - self._presence_on_startup = self._get_active_presence(db_conn) - - presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( - db_conn, - "presence_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._presence_id_gen.get_current_token(), - ) - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", - min_presence_val, - prefilled_cache=presence_cache_prefill, - ) - device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max @@ -238,32 +219,6 @@ class DataStore( def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - def take_presence_startup_info(self): - active_on_startup = self._presence_on_startup - self._presence_on_startup = None - return active_on_startup - - def _get_active_presence(self, db_conn): - """Fetch non-offline presence from the database so that we can register - the appropriate time outs. - """ - - sql = ( - "SELECT user_id, state, last_active_ts, last_federation_update_ts," - " last_user_sync_ts, status_msg, currently_active FROM presence_stream" - " WHERE state != ?" - ) - - txn = db_conn.cursor() - txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.db_pool.cursor_to_dict(txn) - txn.close() - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] - async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c207d917b1..db22fab23e 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,16 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple -from synapse.api.presence import UserPresenceState +from synapse.api.presence import PresenceState, UserPresenceState +from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + class PresenceStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._can_persist_presence = ( + hs.get_instance_name() in hs.config.worker.writers.presence + ) + + if isinstance(database.engine, PostgresEngine): + self._presence_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="presence_stream", + instance_name=self._instance_name, + tables=[("presence_stream", "instance_name", "stream_id")], + sequence_name="presence_stream_sequence", + writers=hs.config.worker.writers.to_device, + ) + else: + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) + + self._presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( + db_conn, + "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_current_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, + ) + async def update_presence(self, presence_states): + assert self._can_persist_presence + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) @@ -57,6 +110,7 @@ class PresenceStore(SQLBaseStore): "last_user_sync_ts": state.last_user_sync_ts, "status_msg": state.status_msg, "currently_active": state.currently_active, + "instance_name": self._instance_name, } for stream_id, state in zip(stream_orderings, presence_states) ], @@ -216,3 +270,37 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() + + def _get_active_presence(self, db_conn: Connection): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.db_pool.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + for row in rows: + self.presence_stream_cache.entity_has_changed(row.user_id, token) + self._get_presence_for_user.invalidate((row.user_id,)) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql new file mode 100644 index 0000000000..b6ba0bda1a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to specify which instance wrote the row. Historic rows have +-- `NULL`, which indicates that the master instance wrote them. +ALTER TABLE presence_stream ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres new file mode 100644 index 0000000000..02b182adf9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/12presence_stream_instance_seq.sql.postgres @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS presence_stream_sequence; + +SELECT setval('presence_stream_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM presence_stream +)); diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py deleted file mode 100644 index 3d45da38ab..0000000000 --- a/tests/app/test_frontend_proxy.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.app.generic_worker import GenericWorkerServer - -from tests.server import make_request -from tests.unittest import HomeserverTestCase - - -class FrontendProxyTests(HomeserverTestCase): - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - federation_http_client=None, homeserver_to_use=GenericWorkerServer - ) - - return hs - - def default_config(self): - c = super().default_config() - c["worker_app"] = "synapse.app.frontend_proxy" - - c["worker_listeners"] = [ - { - "type": "http", - "port": 8080, - "bind_addresses": ["0.0.0.0"], - "resources": [{"names": ["client"]}], - } - ] - - return c - - def test_listen_http_with_presence_enabled(self): - """ - When presence is on, the stub servlet will not register. - """ - # Presence is on - self.hs.config.use_presence = True - - # Listen with the config - self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) - - # Grab the resource from the site that was told to listen - self.assertEqual(len(self.reactor.tcpServers), 1) - site = self.reactor.tcpServers[0][1] - - channel = make_request(self.reactor, site, "PUT", "presence/a/status") - - # 400 + unrecognised, because nothing is registered - self.assertEqual(channel.code, 400) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - def test_listen_http_with_presence_disabled(self): - """ - When presence is off, the stub servlet will register. - """ - # Presence is off - self.hs.config.use_presence = False - - # Listen with the config - self.hs._listen_http(self.hs.config.worker.worker_listeners[0]) - - # Grab the resource from the site that was told to listen - self.assertEqual(len(self.reactor.tcpServers), 1) - site = self.reactor.tcpServers[0][1] - - channel = make_request(self.reactor, site, "PUT", "presence/a/status") - - # 401, because the stub servlet still checks authentication - self.assertEqual(channel.code, 401) - self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 3a050659ca..409f3949dc 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -16,6 +16,7 @@ from unittest.mock import Mock from twisted.internet import defer +from synapse.handlers.presence import PresenceHandler from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -32,7 +33,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - presence_handler = Mock() + presence_handler = Mock(spec=PresenceHandler) presence_handler.set_state.return_value = defer.succeed(None) hs = self.setup_test_homeserver( @@ -59,12 +60,12 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + @unittest.override_config({"use_presence": False}) def test_put_presence_disabled(self): """ PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. """ - self.hs.config.use_presence = False body = {"presence": "here", "status_msg": "beep boop"} channel = self.make_request( -- cgit 1.5.1